diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..a9007c8db3078 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,121 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
+def Linalg_WinogradFilterTransformOp :
+    Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
+  let summary = "Winograd filter transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, it converts the
+    filter and the input into a format suitable for the batched matrix
+    multiply. After the matrix multiply, it converts the output to the final
+    result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high-level concept of filter
+    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
+                       TensorRankOf<[AnyType], [4]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradInputTransformOp :
+    Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
+  let summary = "Winograd input transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, it converts the
+    filter and the input into a format suitable for the batched matrix
+    multiply. After the matrix multiply, it converts the output to the final
+    result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high-level concept of input
+    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
+                       TensorRankOf<[AnyType], [6]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [6]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradOutputTransformOp :
+    Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
+  let summary = "Winograd output transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, it converts the
+    filter and the input into a format suitable for the batched matrix
+    multiply. After the matrix multiply, it converts the output to the final
+    result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high-level concept of output
+    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
+                       TensorRankOf<[AnyType], [4]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 693fca4f63502..80b1f2ec363eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1735,6 +1735,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
 /// Adds patterns that reduce the rank of named contraction ops that have
 /// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
 /// `<corresponding linalg op>`, `expand_shape` (if on tensors). For example a
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 0754bd95a90f7..cefaad9b22653 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2739,6 +2739,122 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t r = getR();
+  int64_t m = getM();
+
+  if (filterH != r && filterH != 1)
+    return emitOpError("expect filter height either equals to r or 1");
+  if (filterW != r && filterW != 1)
+    return emitOpError("expect filter width either equals to r or 1");
+  if (filterH == 1 && filterW == 1)
+    return emitOpError("expect either filter height or width equals to r");
+
+  SmallVector<int64_t> expectedOutputShape;
+  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
+  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
+  expectedOutputShape.push_back(filterShape[3]);
+  expectedOutputShape.push_back(filterShape[0]);
+
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
+    return emitOpError("the output shape is not expected");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradInputTransformOp::verify() {
+  auto inputType = cast<ShapedType>(getInput().getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int m = getM();
+  int r = getR();
+  int64_t tileSize = m + r - 1;
+  bool leftTransform = inputH != 1;
+  bool rightTransform = inputW != 1;
+
+  SmallVector<int64_t> expectedOutputShape(6, inputH);
+  if (ShapedType::isDynamic(inputH)) {
+    expectedOutputShape[0] = tileSize;
+    expectedOutputShape[2] = ShapedType::kDynamic;
+  } else {
+    expectedOutputShape[0] = leftTransform ? tileSize : 1;
+    expectedOutputShape[2] = leftTransform ? (inputH - (r - 1)) / m : 1;
+  }
+  if (ShapedType::isDynamic(inputW)) {
+    expectedOutputShape[1] = tileSize;
+    expectedOutputShape[3] = ShapedType::kDynamic;
+  } else {
+    expectedOutputShape[1] = rightTransform ? tileSize : 1;
+    expectedOutputShape[3] = rightTransform ? (inputW - (r - 1)) / m : 1;
+  }
+  expectedOutputShape[4] = inputShape[0];
+  expectedOutputShape[5] = inputShape[3];
+
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
+    return emitOpError("the output shape is not expected");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradOutputTransformOp::verify() {
+  auto valueType = cast<ShapedType>(getValue().getType());
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t valueTileH = valueShape[2];
+  int64_t valueTileW = valueShape[3];
+  int m = getM();
+  int r = getR();
+  bool leftTransform = valueH != 1;
+  bool rightTransform = valueW != 1;
+
+  SmallVector<int64_t> expectedOutputShape(4, valueH);
+  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
+    expectedOutputShape[1] = ShapedType::kDynamic;
+  } else {
+    if (valueH != (leftTransform ? m + r - 1 : 1))
+      return emitOpError("expect input height equals to input tile size");
+    expectedOutputShape[1] = (leftTransform ? m : 1) * valueTileH;
+  }
+  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
+    expectedOutputShape[2] = ShapedType::kDynamic;
+  } else {
+    if (valueW != (rightTransform ? m + r - 1 : 1))
+      return emitOpError("expect input width equals to input tile size");
+    expectedOutputShape[2] = (rightTransform ? m : 1) * valueTileW;
+  }
+  expectedOutputShape[0] = valueShape[4];
+  expectedOutputShape[3] = valueShape[5];
+
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
+    return emitOpError("the output shape is not expected");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..351549bf2b434
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,329 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement the Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+using TransformMapKeyTy = std::pair<int64_t, int64_t>;
+
+/// We use F(m, r) to define the size of minimal filtering algorithms.
+/// m is the output dimension and r is the filter dimension. We can get
+/// the input dimension, alpha, from the formula, alpha = m + r - 1.
+///
+/// For example, when m = 2 and r = 3, we know its input size is 4.
+/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+/// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
+/// This function generates linalg.batch_matmul to multiply input with filter.
+/// linalg.batch_matmul only supports 3-dimensional inputs. We can treat
+/// tileH x tileW x H x W data as the 1-dimensional data array. That is to
+/// convert [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this
+/// way, we can convert 6-dimensional inputs to 3-dimensional representation
+/// that is suitable for linalg.batch_matmul.
+///
+/// Batched matmul will do the matrix multiply with the reduction on channel.
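+///
+/// As a concrete example (shapes taken from the F(4, 3) conv2d_4x4_3x3 test
+/// added in this patch): a transformed filter of shape (6, 6, 5, 2) is
+/// collapsed to (36, 5, 2), a transformed input of shape (6, 6, 1, 1, 2, 5)
+/// is collapsed to (36, 2, 5), linalg.batch_matmul then produces (36, 2, 2),
+/// and the result is expanded back to (6, 6, 1, 1, 2, 2).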
+///
+/// We get
+///
+///   %collapsed_input = tensor.collapse_shape %input
+///   %collapsed_filter = tensor.collapse_shape %filter
+///   %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
+///   %expanded_ret = tensor.expand_shape %ret
+///
+/// After this function, we get return value with data layout
+/// (tileH, tileW, H, W, N, F).
+static Value matrixMultiply(RewriterBase &rewriter, Location loc,
+                            Value transformedFilter, Value transformedInput,
+                            Type outputElementType) {
+  // Convert (alphaH, alphaW, C, F) to (alphaH x alphaW, C, F) for filter.
+  auto filterType = cast<ShapedType>(transformedFilter.getType());
+  assert(filterType.hasStaticShape() && "only support static shapes.");
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  Type filterElementType = filterType.getElementType();
+  auto filterReassocType = RankedTensorType::get(
+      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
+      filterElementType);
+  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
+  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
+      loc, filterReassocType, transformedFilter, filterReassoc);
+
+  // Convert (alphaH, alphaW, tileH, tileW, N, C) to
+  // (alphaH x alphaW, tileH x tileW x N, C) for input.
+  auto inputType = cast<ShapedType>(transformedInput.getType());
+  assert(inputType.hasStaticShape() && "only support static shapes.");
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  Type inputElementType = inputType.getElementType();
+  auto inputReassocType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1],
+       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
+      inputElementType);
+  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
+      loc, inputReassocType, transformedInput, inputReassoc);
+
+  // Batched matrix multiply.
+  auto matmulType = RankedTensorType::get(
+      {inputShape[0] * inputShape[1],
+       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
+      outputElementType);
+  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                outputElementType);
+
+  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
+      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
+      ValueRange{init});
+
+  // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F).
+  // Expand matmul result to (alphaH, alphaW, tileH, tileW, N, F).
+  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
+  auto outputReassocType =
+      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
+                             inputShape[3], inputShape[4], filterShape[3]},
+                            outputElementType);
+  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
+      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
+  return expandOutput;
+}
+
+/// Pad the value with zeros to create a tensor of the given aligned shape.
+static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
+                                Value value, ArrayRef<int64_t> alignedShape) {
+  auto valueType = cast<ShapedType>(value.getType());
+  Type elementType = valueType.getElementType();
+  auto alignedType = RankedTensorType::get(alignedShape, elementType);
+  Value padValue = rewriter.create<arith::ConstantOp>(
+      loc, elementType, rewriter.getZeroAttr(elementType));
+
+  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
+                                       padValue, false);
+}
+
+/// Extract sub-tensor with extractedType from value.
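+/// This undoes the padding introduced by padToAlignedTensor: after the output
+/// transform runs on the padded buffer, only the original output region is
+/// extracted.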
+static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
+                                      Value value,
+                                      RankedTensorType extractedType) {
+  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> offsets(4, zeroIndex);
+  SmallVector<OpFoldResult> strides(4, oneIndex);
+
+  ArrayRef<int64_t> extractedShape = extractedType.getShape();
+  SmallVector<OpFoldResult> sizes =
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));
+
+  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
+                                                 offsets, sizes, strides);
+}
+
+/// Utility function to check all values in the attribute are 1.
+static bool hasAllOneValues(DenseIntElementsAttr attr) {
+  return llvm::all_of(
+      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
+}
+
+/// A helper function to convert linalg.conv_2d_nhwc_fhwc to
+/// linalg.winograd_*_transform ops.
+static FailureOr<Operation *>
+winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
+                     int64_t m, int64_t r) {
+  Value input = convOp.getInputs()[0];
+  Value filter = convOp.getInputs()[1];
+  Value output = convOp.getOutputs()[0];
+  auto inputType = cast<ShapedType>(input.getType());
+  auto filterType = cast<ShapedType>(filter.getType());
+  auto outputType = cast<ShapedType>(output.getType());
+
+  // TODO: Should we support dynamic shapes?
+  if (!inputType.hasStaticShape())
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected a static shape for the input");
+
+  if (!filterType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        convOp, "expected a static shape for the filter");
+
+  if (!hasAllOneValues(convOp.getDilations()))
+    return rewriter.notifyMatchFailure(convOp,
+                                       "expected all ones for dilations");
+
+  if (!hasAllOneValues(convOp.getStrides()))
+    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");
+
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputN = outputShape[0];
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int64_t outputF = outputShape[3];
+
+  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
+  bool isSupportedFilter = false;
+  if (filterH == filterW && filterH == r)
+    isSupportedFilter = true;
+  if (filterH == r && filterW == 1)
+    isSupportedFilter = true;
+  if (filterH == 1 && filterW == r)
+    isSupportedFilter = true;
+
+  if (!isSupportedFilter)
+    return rewriter.notifyMatchFailure(
+        convOp, "only support filter (r x r), (r x 1) or (1 x r)");
+
+  // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5).
+  static const llvm::SmallVector<TransformMapKeyTy> validConfigs = {
+      F_2_3, F_4_3, F_2_5};
+
+  TransformMapKeyTy key = {m, r};
+  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
+  // If we cannot find the constant transformation matrix, it means we do
+  // not support this configuration yet.
+  if (it == validConfigs.end())
+    return failure();
+
+  // All the criteria are satisfied. We can do Winograd Conv2D.
+  Location loc = convOp.getLoc();
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = filterH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = filterW != 1;
+  int64_t heightM = leftTransform ? m : 1;
+  int64_t widthM = rightTransform ? m : 1;
+  int64_t heightR = leftTransform ? r : 1;
+  int64_t widthR = rightTransform ? r : 1;
+
+  // --- Create operation for filter transform ---
+  Type filterElementType = filterType.getElementType();
+  int64_t alphaH = heightM + heightR - 1;
+  int64_t alphaW = widthM + widthR - 1;
+  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
+  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
+  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
+                                       filterElementType);
+  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
+                                                    filterElementType);
+  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
+      loc, retType, filter, retValue, m, r);
+
+  // --- Create operation for input transform ---
+
+  // When input size - (r - 1) is not aligned with output tile size, we need to
+  // pad the input data to create the full tiles as tiling.
+  Type inputElementType = inputType.getElementType();
+  int64_t alignedInputH = tileH * heightM + (heightR - 1);
+  int64_t alignedInputW = tileW * widthM + (widthR - 1);
+  if (alignedInputH != inputH || alignedInputW != inputW) {
+    input = padToAlignedTensor(rewriter, loc, input,
+                               {inputN, alignedInputH, alignedInputW, inputC});
+  }
+
+  retType = RankedTensorType::get(
+      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
+  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
+                                              inputElementType);
+  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
+      loc, retType, input, retValue, m, r);
+
+  Type outputElementType = outputType.getElementType();
+  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
+                                   transformedInput, outputElementType);
+
+  // --- Create operation for output transform ---
+
+  // When output size is not aligned with output tile size, we need to pad the
+  // output buffer to insert the full tiles after tiling.
+  int64_t alignedOutputH = tileH * heightM;
+  int64_t alignedOutputW = tileW * widthM;
+  bool isOutputUnaligned =
+      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
+  if (isOutputUnaligned) {
+    auto alignedOutputType = RankedTensorType::get(
+        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
+    output =
+        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
+    outputType = alignedOutputType;
+  }
+
+  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
+      loc, outputType, matmulRet, output, m, r);
+
+  // When output size is not aligned with output tile size, extract the
+  // value from the padded buffer.
+  if (isOutputUnaligned) {
+    transformedOutput = extractFromAlignedTensor(
+        rewriter, loc, transformedOutput,
+        RankedTensorType::get({outputN, outputH, outputW, outputF},
+                              outputElementType));
+  }
+
+  rewriter.replaceOp(convOp, transformedOutput);
+
+  return transformedOutput.getDefiningOp();
+}
+
+/// A rewrite pattern for Winograd Conv2D algorithm.
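+/// The pattern matches a statically shaped linalg.conv_2d_nhwc_fhwc with unit
+/// strides and dilations and, via winogradConv2DHelper above, rewrites it into
+/// winograd_filter_transform, winograd_input_transform, linalg.batch_matmul,
+/// and winograd_output_transform ops.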
+class WinogradConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  using OpRewritePattern<linalg::Conv2DNhwcFhwcOp>::OpRewritePattern;
+  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
+      : OpRewritePattern<linalg::Conv2DNhwcFhwcOp>(context), m(m), r(r) {}
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
+      return failure();
+
+    return success();
+  }
+
+private:
+  int64_t m;
+  int64_t r;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r) {
+  MLIRContext *context = patterns.getContext();
+  // TODO: Support more Conv2D data layouts, e.g., conv_2d_nchw_fchw.
+  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
+}
+
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 213ef6c7b2616..c481a723c5623 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -855,3 +855,122 @@ func.func @mixed_semantics(%a: tensor, %b: tensor, %c: memref<
   return
 }
+// -----
+
+func.func @winograd_filter_transform_height(%arg0: tensor<2x4x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+  // expected-error @+1 {{expect filter height either equals to r or 1}}
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x4x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  return %0 : tensor<6x6x5x2xf32>
+}
+
+// -----
+
+func.func @winograd_filter_transform_width(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+  // expected-error @+1 {{expect filter width either equals to r or 1}}
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x4x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  return %0 : tensor<6x6x5x2xf32>
+}
+
+// -----
+
+func.func @winograd_filter_transform(%arg0: tensor<2x1x1x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> {
+  // expected-error @+1 {{expect either filter height or width equals to r}}
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x1x1x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  return %0 : tensor<6x6x5x2xf32>
+}
+
+// -----
+
+func.func @winograd_filter_dyn(%arg0: tensor<?x3x3x?xf32>, %arg1: tensor<6x5x?x?xf32>) -> tensor<6x5x?x?xf32> {
+  // expected-error @+1 {{the output shape is not expected}}
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x5x?x?xf32>) -> tensor<6x5x?x?xf32>
+  return %0 : tensor<6x5x?x?xf32>
+}
+
+// -----
+
+func.func @winograd_input_transform_height(%arg0: tensor<2x13x14x5xf32>, %arg1: tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> {
+  // expected-error @+1 {{the output shape is not expected}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x13x14x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+  return %0 : tensor<6x6x3x3x2x5xf32>
+}
+
+// -----
+
+func.func @winograd_input_transform_width(%arg0: tensor<2x14x13x5xf32>, %arg1: tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> {
+  // expected-error @+1 {{the output shape is not expected}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x13x5xf32>) outs(%arg1 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32>
+  return %0 : tensor<6x6x3x3x2x5xf32>
+}
+
+// -----
+
+func.func @winograd_input_transform_output_tileH(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32> {
+  // expected-error @+1 {{the output shape is not expected}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x2x3x2x5xf32>) -> tensor<6x6x2x3x2x5xf32>
+  return %0 : tensor<6x6x2x3x2x5xf32>
+}
+
+// -----
+
+func.func @winograd_input_transform_output_tileW(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32> {
+  // expected-error @+1 {{the output shape is not expected}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x6x3x2x2x5xf32>) -> tensor<6x6x3x2x2x5xf32>
+  return %0 : tensor<6x6x3x2x2x5xf32>
+}
+
+// -----
+
+func.func @winograd_input_transform_output_height(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32> {
+  // expected-error @+1 {{the output shape is not expected}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<5x6x3x3x2x5xf32>) -> tensor<5x6x3x3x2x5xf32>
+  return %0 : tensor<5x6x3x3x2x5xf32>
+}
+
+// -----
+
+func.func @winograd_input_transform_output_width(%arg0: tensor<2x14x14x5xf32>, %arg1: tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32> {
+  // expected-error @+1 {{the output shape is not expected}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x14x14x5xf32>) outs(%arg1 : tensor<6x5x3x3x2x5xf32>) -> tensor<6x5x3x3x2x5xf32>
+  return %0 : tensor<6x5x3x3x2x5xf32>
+}
+
+// -----
+
+func.func @winograd_input_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<6x5x?x?x?x?xf32>) -> tensor<6x5x?x?x?x?xf32> {
+  // expected-error @+1 {{the output shape is not expected}}
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x5x?x?x?x?xf32>) -> tensor<6x5x?x?x?x?xf32>
+  return %0 : tensor<6x5x?x?x?x?xf32>
+}
+
+// -----
+
+func.func @winograd_output_transform_input_height(%arg0: tensor<5x6x3x3x2x2xf32>, %arg1: tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> {
+  // expected-error @+1 {{expect input height equals to input tile size}}
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<5x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  return %0 : tensor<2x12x12x2xf32>
+}
+
+// -----
+
+func.func @winograd_output_transform_input_width(%arg0: tensor<6x5x3x3x2x2xf32>, %arg1: tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> {
+  // expected-error @+1 {{expect input width equals to input tile size}}
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x5x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+  return %0 : tensor<2x12x12x2xf32>
+}
+
+// -----
+
+func.func @winograd_output_transform_output_height(%arg0: tensor<6x6x3x3x2x2xf32>, %arg1: tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32> {
+  // expected-error @+1 {{the output shape is not expected}}
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x11x12x2xf32>) -> tensor<2x11x12x2xf32>
+  return %0 : tensor<2x11x12x2xf32>
+}
+
+// -----
+
+func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>, %arg1: tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32> {
+  // expected-error @+1 {{the output shape is not expected}}
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
+  return %0 : tensor<2x12x11x2xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index b422066aade64..146e9780b8ebb 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -613,3 +613,54 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK-SAME:      tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
 // CHECK:           return %[[D1]] : tensor<2x16x32xf32>
 // CHECK:         }
+
+// -----
+
+func.func @winograd(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = tensor.empty() : tensor<6x6x5x2xf32>
+  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
+  %2 = tensor.empty() : tensor<6x6x1x1x2x5xf32>
+  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%2 : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32>
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32>
+  %4 = tensor.empty() : tensor<36x2x2xf32>
+  %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+  %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+  %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x1x1x2x2xf32>) outs(%arg3 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %6 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func @winograd
+// CHECK: linalg.winograd_filter_transform m(4) r(3)
+// CHECK: linalg.winograd_input_transform m(4) r(3)
+// CHECK: linalg.winograd_output_transform m(4) r(3)
+
+// -----
+
+func.func @winograd_filter_dyn(%arg0: tensor<?x3x3x?xf32>, %arg1: tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32> {
+  %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32>
+  return %0 : tensor<6x6x?x?xf32>
+}
+
+// CHECK-LABEL: func @winograd_filter_dyn
+// CHECK: linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<?x3x3x?xf32>) outs(%arg1 : tensor<6x6x?x?xf32>) -> tensor<6x6x?x?xf32>
+
+// -----
+
+func.func @winograd_input_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32> {
+  %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32>
+  return %0 : tensor<6x6x?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func @winograd_input_dyn
+// CHECK: linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<?x?x?x?xf32>) outs(%arg1 : tensor<6x6x?x?x?x?xf32>) -> tensor<6x6x?x?x?x?xf32>
+
+// -----
+
+func.func @winograd_output_dyn(%arg0: tensor<6x6x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: func @winograd_output_dyn
+// CHECK: linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
new file mode 100644
index 0000000000000..ec11a6ef8fbee
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir
@@ -0,0 +1,193 @@
+// 
RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s + +func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> + return %0 : tensor<2x4x4x2xf32> +} + +// CHECK-LABEL: func.func @conv2d_4x4_3x3 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> { +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x4x4x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%out : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> + return %0 : tensor<2x2x2x2xf32> +} + +// CHECK-LABEL: func.func @conv2d_2x2_5x5 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : 
tensor<6x6x1x1x2x5xf32> into tensor<36x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x2x2x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%out : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> + return %0 : tensor<2x1x4x2xf32> +} + +// CHECK-LABEL: func.func @conv2d_1x4_1x3 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> { +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x6x5x2xf32>) -> tensor<1x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x6x1x1x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x6x1x1x2x5xf32>) -> tensor<1x6x1x1x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<1x6x5x2xf32> into tensor<6x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<1x6x1x1x2x5xf32> into tensor<6x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [1, 6, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x6x1x1x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x1x4x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%out : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> + return %0 : tensor<2x4x1x2xf32> +} + +// CHECK-LABEL: func.func @conv2d_4x1_3x1 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> { +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x1x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = 
linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x1x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x4x1x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%out : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %0 : tensor<2x8x8x2xf32> +} + +// CHECK-LABEL: func.func @conv2d_aligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> { +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x8x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x8x2xf32>) -> tensor<36x8x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x2x2x2x2xf32>) outs(%[[ARG3]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, 
%arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> + return %0 : tensor<2x9x9x2xf32> +} + +// CHECK-LABEL: func.func @conv2d_unaligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> { +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S0]] : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] { +// CHECK-NEXT: ^bb0 +// CHECK-NEXT: tensor.yield %[[CST]] : f32 +// CHECK-NEXT: } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[PADDED]] : tensor<2x14x14x5xf32>) outs(%[[S2]] : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %3 {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32> +// CHECK-NEXT: %[[PADDED_1:.*]] = tensor.pad %arg3 low[0, 0, 0, 0] high[0, 3, 3, 0] { +// CHECK-NEXT: ^bb0 +// CHECK-NEXT: tensor.yield %[[CST]] : f32 +// CHECK-NEXT: } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x3x3x2x2xf32>) outs(%[[PADDED_1]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> +// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_type_promotion(%arg0: tensor<2x6x6x5xf16>, %arg1: tensor<2x3x3x5xf16>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf16>, tensor<2x3x3x5xf16>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> + return %0 : tensor<2x4x4x2xf32> +} + +// CHECK-LABEL: func.func @conv2d_type_promotion +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf16>, %[[ARG1:.*]]: tensor<2x3x3x5xf16>, %[[ARG2:.*]]: tensor<1xf32>, %[[ARG3:.*]]: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf16> +// CHECK-NEXT: %[[S1:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf16>) outs(%[[S0]] : tensor<6x6x5x2xf16>) -> tensor<6x6x5x2xf16> 
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf16>) outs(%[[S2]] : tensor<6x6x1x1x2x5xf16>) -> tensor<6x6x1x1x2x5xf16>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf16> into tensor<36x5x2xf16>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x1x1x2x5xf16> into tensor<36x2x5xf16>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf16>, tensor<36x5x2xf16>) outs(%[[S4]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 1, 1, 2, 2] : tensor<36x2x2xf32> into tensor<6x6x1x1x2x2xf32>
+// CHECK-NEXT:  %[[S6:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<6x6x1x1x2x2xf32>) outs(%[[ARG3]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+// CHECK-NEXT:  return %[[S6]] : tensor<2x4x4x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_1
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+
+func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>, %out: tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%out : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
+  return %0 : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_2
+// CHECK: linalg.conv_2d_nhwc_fhwc
+
+// -----
+
+func.func @conv2d_unsupported_3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<2x3x3x5xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-LABEL: conv2d_unsupported_3
+// CHECK: linalg.conv_2d_nhwc_fhwc
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4892fa2f99a7c..12cb46a5968f1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -123,6 +123,10 @@ struct TestLinalgTransforms
       *this, "test-erase-unnecessary-inputs",
       llvm::cl::desc("Test patterns to erase unnecessary inputs"),
       llvm::cl::init(false)};
+  Option<bool> testWinogradConv2D{
+      *this, "test-winograd-conv2d",
+      llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -207,6 +211,13 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyWinogradConv2D(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
+  populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -231,6 +242,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
   if (testEraseUnnecessaryInputs)
     return applyEraseUnnecessaryInputs(getOperation());
+  if (testWinogradConv2D)
+    return applyWinogradConv2D(getOperation());
 }
 
 namespace mlir {
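Usage sketch (not part of the patch above): the new populateWinogradConv2DPatterns entry point declared in Transforms.h can be driven from any greedy rewrite, mirroring applyWinogradConv2D in the test pass. The helper name below is illustrative only.

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Hypothetical driver: rewrite all supported conv_2d_nhwc_fhwc ops in a
  // function using the F(4x4, 3x3) and F(2x2, 5x5) configurations.
  static mlir::LogicalResult applyWinograd(mlir::func::FuncOp funcOp) {
    mlir::RewritePatternSet patterns(funcOp.getContext());
    mlir::linalg::populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
    mlir::linalg::populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
    return mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  }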