diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 866275cedf68b..ecc86999006db 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2646,4 +2646,55 @@ def MapCopyToThreadsOp : }]; } +//===----------------------------------------------------------------------===// +// Winograd Conv2D +//===----------------------------------------------------------------------===// + +def WinogradConv2DOp : Op { + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operation into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + #### Return modes: + + This operation produces a silenceable failure if `target` is unsupported. + Otherwise, the operation succeeds and returns a handle of the sequence that + replaces the original convolution. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + I64Attr:$m, + I64Attr:$r); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = + "$target attr-dict `:` functional-type($target, results)"; + + let builders = [ + OpBuilder<(ins "Value":$target)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::linalg::LinalgOp target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // LINALG_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 80b1f2ec363eb..0c7a8edff222f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1332,6 +1332,13 @@ FailureOr transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS = true); +/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm +/// F(m x m, r x r). m is the dimension size of output and r is the dimension +/// size of filter. +FailureOr winogradConv2D(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp op, int64_t m, + int64_t r); + //===----------------------------------------------------------------------===// // Rewrite patterns wrapping transformations. // TODO: every single such pattern should be a close to noop wrapper around a @@ -1739,6 +1746,9 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, int64_t r); +/// Patterns to decompose Winograd operators. +void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns); + /// Adds patterns that reduce the rank of named contraction ops that have /// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`, /// ``, `expand_shape` (if on tensors). 
For example a diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 4eb334f8bbbfa..bffe7a4e7d62c 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3711,6 +3711,37 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne( return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// WinogradConv2DOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne( + transform::TransformRewriter &rewriter, linalg::LinalgOp target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + rewriter.setInsertionPoint(target); + FailureOr maybeTransformed = failure(); + bool supported = TypeSwitch(target) + .Case([&](linalg::Conv2DNhwcFhwcOp op) { + maybeTransformed = + winogradConv2D(rewriter, op, getM(), getR()); + return true; + }) + .Default([&](Operation *op) { return false; }); + + if (!supported) { + return emitSilenceableError() + << "this operation is not supported to convert to Winograd Conv2D"; + } + + if (supported && failed(maybeTransformed)) { + return emitSilenceableError() << "apply Winograd Conv2D failed"; + } + + results.push_back(*maybeTransformed); + return DiagnosedSilenceableFailure::success(); +} + #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp index 76742f2a824e7..754f832e98eea 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -12,10 +12,14 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/MathExtras.h" namespace mlir { @@ -23,6 +27,156 @@ namespace linalg { namespace { +// clang-format off +/// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its +/// result. The formula of minimal 2D filtering algorithm F(m x m, r x r), +/// m is the output dimension and r is the filter dimension, is +/// +/// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A +/// +/// g is filter and d is input data. We need to prepare 6 constant +/// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula. 
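As a quick sanity check of the formula above (illustrative only, not part of the patch): for F(2 x 2, 3 x 3) and a single input/output channel, the product between the two bracketed terms is element-wise, and the result equals a direct 2 x 2 cross-correlation of a 4 x 4 input tile with a 3 x 3 filter. The G, B^T, and A^T values are copied from the tables that follow; the sample filter and input values are arbitrary.

```c++
// Standalone check: Y = A^T x [ (G x g x G^T) .* (B^T x d x B) ] x A equals a
// direct "valid" cross-correlation (no filter flip), matching linalg.conv_2d
// semantics. ".*" denotes element-wise multiplication.
#include <cstdio>

int main() {
  // F(2 x 2, 3 x 3) transform tables (same values as G_2x2_3x3, BT_2x2_3x3,
  // and AT_2x2_3x3 below).
  const float G[4][3] = {
      {-1, 0, 0}, {0.5f, -0.5f, 0.5f}, {0.5f, 0.5f, 0.5f}, {0, 0, 1}};
  const float BT[4][4] = {
      {-1, 0, 1, 0}, {0, -1, 1, 0}, {0, 1, 1, 0}, {0, -1, 0, 1}};
  const float AT[2][4] = {{1, 1, 1, 0}, {0, -1, 1, 1}};
  // Arbitrary 3x3 filter g and 4x4 input tile d (single channel).
  const float g[3][3] = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
  const float d[4][4] = {
      {1, 2, 0, 1}, {3, 1, 2, 2}, {0, 4, 1, 3}, {2, 1, 0, 5}};

  // U = G x g x G^T and V = B^T x d x B, both 4x4 (alpha = m + r - 1 = 4).
  float U[4][4] = {}, V[4][4] = {};
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 4; ++j)
      for (int a = 0; a < 3; ++a)
        for (int b = 0; b < 3; ++b)
          U[i][j] += G[i][a] * g[a][b] * G[j][b]; // G[j][b] == (G^T)[b][j]
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 4; ++j)
      for (int a = 0; a < 4; ++a)
        for (int b = 0; b < 4; ++b)
          V[i][j] += BT[i][a] * d[a][b] * BT[j][b]; // BT[j][b] == B[b][j]

  // Y = A^T x (U .* V) x A versus the direct 2x2 cross-correlation.
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j) {
      float y = 0, ref = 0;
      for (int a = 0; a < 4; ++a)
        for (int b = 0; b < 4; ++b)
          y += AT[i][a] * U[a][b] * V[a][b] * AT[j][b];
      for (int a = 0; a < 3; ++a)
        for (int b = 0; b < 3; ++b)
          ref += d[i + a][j + b] * g[a][b];
      std::printf("Y[%d][%d] = %g  direct = %g\n", i, j, y, ref);
    }
  return 0;
}
```

With multiple channels, the element-wise products are additionally summed over the channel dimension; flattening the transformed-tile dimensions into a batch dimension is what turns that per-position channel sum into the linalg.batch_matmul generated later in this file.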
+/// +/// The following tables define these constant transformation matrices for +/// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5) +constexpr float G_2x2_3x3[] = { + -1, 0, 0, + 1./2, -1./2, 1./2, + 1./2, 1./2, 1./2, + 0, 0, 1 +}; + +constexpr float GT_2x2_3x3[] = { + -1, 1./2, 1./2, 0, + 0, -1./2, 1./2, 0, + 0, 1./2, 1./2, 1 +}; + +constexpr float BT_2x2_3x3[] = { + -1, 0, 1, 0, + 0, -1, 1, 0, + 0, 1, 1, 0, + 0, -1, 0, 1 +}; + +constexpr float B_2x2_3x3[] = { + -1, 0, 0, 0, + 0, -1, 1, -1, + 1, 1, 1, 0, + 0, 0, 0, 1 +}; + +constexpr float AT_2x2_3x3[] = { + 1, 1, 1, 0, + 0, -1, 1, 1 +}; + +constexpr float A_2x2_3x3[] = { + 1, 0, + 1, -1, + 1, 1, + 0, 1 +}; + +constexpr float G_4x4_3x3[] = { + 1, 0, 0, + -1./3, 1./3, -1./3, + -1./3, -1./3, -1./3, + 1./12, -1./6, 1./3, + 1./12, 1./6, 1./3, + 0, 0, 1 +}; + +constexpr float GT_4x4_3x3[] = { + 1, -1./3, -1./3, 1./12, 1./12, 0, + 0, 1./3, -1./3, -1./6, 1./6, 0, + 0, -1./3, -1./3, 1./3, 1./3, 1 +}; + +constexpr float BT_4x4_3x3[] = { + 1./4, 0, -5./16, 0, 1./16, 0, + 0, 1./4, -1./4, -1./16, 1./16, 0, + 0, -1./4, -1./4, 1./16, 1./16, 0, + 0, 1./4, -1./8, -1./4, 1./8, 0, + 0, -1./4, -1./8, 1./4, 1./8, 0, + 0, 1./4, 0, -5./16, 0, 1./16 +}; + +constexpr float B_4x4_3x3[] = { + 1./4, 0, 0, 0, 0, 0, + 0, 1./4, -1./4, 1./4, -1./4, 1./4, + -5./16, -1./4, -1./4, -1./8, -1./8, 0, + 0, -1./16, 1./16, -1./4, 1./4, -5./16, + 1./16, 1./16, 1./16, 1./8, 1./8, 0, + 0, 0, 0, 0, 0, 1./16 +}; + +constexpr float AT_4x4_3x3[] = { + 1./8, 1./4, 1./4, 1./8, 1./8, 0, + 0, -1./4, 1./4, -1./4, 1./4, 0, + 0, 1./4, 1./4, 1./2, 1./2, 0, + 0, -1./4, 1./4, -1, 1, 1./2 +}; + +constexpr float A_4x4_3x3[] = { + 1./8, 0, 0, 0, + 1./4, -1./4, 1./4, -1./4, + 1./4, 1./4, 1./4, 1./4, + 1./8, -1./4, 1./2, -1, + 1./8, 1./4, 1./2, 1, + 0, 0, 0, 1./2 +}; + +constexpr float G_2x2_5x5[] = { + 1, 0, 0, 0, 0, + 1./6, -1./6, 1./6, -1./6, 1./6, + -1./6, -1./6, -1./6, -1./6, -1./6, +-4./15, 2./15, -1./15, 1./30, -1./60, + 1./60, 1./30, 1./15, 2./15, 4./15, + 0, 0, 0, 0, 1 +}; + +constexpr float GT_2x2_5x5[] = { + 1, 1./6, -1./6, -4./15, 1./60, 0, + 0, -1./6, -1./6, 2./15, 1./30, 0, + 0, 1./6, -1./6, -1./15, 1./15, 0, + 0, -1./6, -1./6, 1./30, 2./15, 0, + 0, 1./6, -1./6, -1./60, 4./15, 1 +}; + +constexpr float BT_2x2_5x5[] = { + 1./8, 3./16, -1./4, -3./16, 1./8, 0, + 0, 1./8, 1./16, -5./16, 1./8, 0, + 0, -1./8, -5./16, -1./16, 1./8, 0, + 0, 1./4, -1./8, -1./4, 1./8, 0, + 0, -1./8, -1./4, 1./8, 1./4, 0, + 0, 1./8, 3./16, -1./4, -3./16, 1./8 +}; + +constexpr float B_2x2_5x5[] = { + 1./8, 0, 0, 0, 0, 0, + 3./16, 1./8, -1./8, 1./4, -1./8, 1./8, + -1./4, 1./16, -5./16, -1./8, -1./4, 3./16, + -3./16, -5./16, -1./16, -1./4, 1./8, -1./4, + 1./8, 1./8, 1./8, 1./8, 1./4, -3./16, + 0, 0, 0, 0, 0, 1./8 +}; + +constexpr float AT_2x2_5x5[] = { + 1./2, 1, 1, 2, 1, 0, + 0, -1, 1, -1, 2, 1./2 +}; + +constexpr float A_2x2_5x5[] = { + 1./2, 0, + 1, -1, + 1, 1, + 2, -1, + 1, 2, + 0, 1./2 +}; +// clang-format on + using TransformMapKeyTy = std::pair; /// We use F(m, r) to define the size of minimal filtering algorithms. @@ -36,6 +190,408 @@ constexpr TransformMapKeyTy F_2_3{2, 3}; constexpr TransformMapKeyTy F_4_3{4, 3}; constexpr TransformMapKeyTy F_2_5{2, 5}; +/// Structure to keep information of constant transform matrices. 
+struct TransformMatrix { + TransformMatrix(const float *table, int64_t rows, int64_t cols, + int64_t scalarFactor = 1) + : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {} + + const float *table; + int64_t rows; + int64_t cols; + int64_t scalarFactor; +}; + +/// Utility function to convert constant array to arith.constant Value. +Value create2DTransformMatrix(OpBuilder &builder, Location loc, + TransformMatrix transform, Type type) { + ArrayRef constVec(transform.table, transform.rows * transform.cols); + + return builder.create( + loc, DenseFPElementsAttr::get( + RankedTensorType::get( + SmallVector{transform.rows, transform.cols}, type), + constVec)); +} + +/// Extract height x width data from 4D tensors. +Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source, + Value loopNorFIndex, Value loopCorFIndex, + Value heightOffset, Value widthOffset, + int64_t extractHeight, int64_t extractWidth, + int64_t loopNorFIdx, int64_t loopCorFIdx, + int64_t heightIdx, int64_t widthIdx) { + auto sourceType = cast(source.getType()); + Type elementType = sourceType.getElementType(); + int64_t srcSize = sourceType.getRank(); + + auto oneIndex = builder.getIndexAttr(1); + SmallVector offsets; + offsets.resize(srcSize); + offsets[loopNorFIdx] = loopNorFIndex; + offsets[loopCorFIdx] = loopCorFIndex; + offsets[heightIdx] = heightOffset; + offsets[widthIdx] = widthOffset; + SmallVector sizes(srcSize, oneIndex); + sizes[heightIdx] = builder.getIndexAttr(extractHeight); + sizes[widthIdx] = builder.getIndexAttr(extractWidth); + SmallVector strides(srcSize, oneIndex); + + auto extractFilterType = + RankedTensorType::get({extractHeight, extractWidth}, elementType); + auto extractFilterOp = builder.create( + loc, extractFilterType, source, offsets, sizes, strides); + + return extractFilterOp; +} + +/// Extract height x width data from 6D tensors. +Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source, + Value tileHIndex, Value tileWIndex, + Value loopNorFIndex, Value loopCorFIndex, + int64_t tileHIdx, int64_t tileWIdx, + int64_t loopNorFIdx, int64_t loopCorFIdx, + int64_t heightIdx, int64_t widthIdx) { + auto sourceType = cast(source.getType()); + Type elementType = sourceType.getElementType(); + auto sourceShape = sourceType.getShape(); + int64_t srcSize = sourceType.getRank(); + int64_t height = sourceShape[heightIdx]; + int64_t width = sourceShape[widthIdx]; + + auto zeroIndex = builder.getIndexAttr(0); + auto oneIndex = builder.getIndexAttr(1); + SmallVector offsets(srcSize, zeroIndex); + offsets.resize(srcSize); + offsets[tileHIdx] = tileHIndex; + offsets[tileWIdx] = tileWIndex; + offsets[loopNorFIdx] = loopNorFIndex; + offsets[loopCorFIdx] = loopCorFIndex; + SmallVector sizes(srcSize, oneIndex); + sizes[heightIdx] = builder.getIndexAttr(height); + sizes[widthIdx] = builder.getIndexAttr(width); + SmallVector strides(srcSize, oneIndex); + + auto extractFilterType = RankedTensorType::get({height, width}, elementType); + auto extractFilterOp = builder.create( + loc, extractFilterType, source, offsets, sizes, strides); + + return extractFilterOp; +} + +/// Insert transformed height x width data to 4D tensors which it is +/// extracted from. 
+Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source, + Value dest, Value loopNorFIndex, Value loopCorFIndex, + Value heightOffset, Value widthOffset, int64_t height, + int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx, + int64_t heightIdx, int64_t widthIdx) { + int64_t destSize = cast(dest.getType()).getRank(); + auto oneIndex = builder.getIndexAttr(1); + SmallVector retOffsets; + retOffsets.resize(destSize); + retOffsets[loopNorFIdx] = loopNorFIndex; + retOffsets[loopCorFIdx] = loopCorFIndex; + retOffsets[heightIdx] = heightOffset; + retOffsets[widthIdx] = widthOffset; + SmallVector retSizes(destSize, oneIndex); + retSizes[heightIdx] = builder.getIndexAttr(height); + retSizes[widthIdx] = builder.getIndexAttr(width); + SmallVector strides(destSize, oneIndex); + + auto insertSliceOp = builder.create( + loc, source, dest, retOffsets, retSizes, strides); + + return insertSliceOp; +} + +/// Insert transformed height x width data to 6D tensors which it is +/// extracted from. +Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source, + Value dest, Value tileHIndex, Value tileWIndex, + Value loopNorFIndex, Value loopCorFIndex, int64_t height, + int64_t width, int64_t tileHIdx, int64_t tileWIdx, + int64_t loopNorFIdx, int64_t loopCorFIdx, + int64_t heightIdx, int64_t widthIdx) { + int64_t destSize = cast(dest.getType()).getRank(); + auto zeroIndex = builder.getIndexAttr(0); + auto oneIndex = builder.getIndexAttr(1); + SmallVector retOffsets(destSize, zeroIndex); + retOffsets.resize(destSize); + retOffsets[tileHIdx] = tileHIndex; + retOffsets[tileWIdx] = tileWIndex; + retOffsets[loopNorFIdx] = loopNorFIndex; + retOffsets[loopCorFIdx] = loopCorFIndex; + SmallVector retSizes(destSize, oneIndex); + retSizes[heightIdx] = builder.getIndexAttr(height); + retSizes[widthIdx] = builder.getIndexAttr(width); + SmallVector strides(destSize, oneIndex); + + auto insertSliceOp = builder.create( + loc, source, dest, retOffsets, retSizes, strides); + + return insertSliceOp; +} + +/// This function transforms the filter. The data layout of the filter is FHWC. +/// The transformation matrix is 2-dimension. We need to extract H x W from +/// FHWC first. We need to generate 2 levels of loops to iterate on F and C. +/// After the transformation, we get +/// +/// scf.for %f = lo_f to hi_f step 1 +/// scf.for %c = lo_c to hi_c step 1 +/// %extracted = extract filter from filter +/// %ret = linalg.matmul G, %extracted +/// %ret = linalg.matmul %ret, GT +/// %inserted = insert %ret into filter +Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, + Value retValue, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to G transform matrix. + static const llvm::SmallDenseMap + GMatrices = { + {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)}, + {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)}, + {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)}, + }; + + // Map from (m, r) to GT transform matrix. 
+ static const llvm::SmallDenseMap + GTMatrices = { + {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)}, + {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)}, + {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)}, + }; + + auto filterType = cast(filter.getType()); + Type elementType = filterType.getElementType(); + auto filterShape = filterType.getShape(); // F, H, W, C + int64_t filterF = filterShape[0]; + int64_t filterH = filterShape[1]; + int64_t filterW = filterShape[2]; + int64_t filterC = filterShape[3]; + + if (filterH != r && filterH != 1) + return Value(); + if (filterW != r && filterW != 1) + return Value(); + + Value zeroIdx = rewriter.create(loc, 0); + auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange args) -> scf::ValueVector { + Value FIter = ivs[0]; + Value CIter = ivs[1]; + + // Extract (H, W) from (F, H, W, C). + auto extractFilter = + extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx, + zeroIdx, filterH, filterW, /*loopNorFIdx=*/0, + /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2); + + TransformMapKeyTy key = {m, r}; + int64_t retRows = 1; + Value matmulRetValue = extractFilter; + if (leftTransform) { + // Get constant transform matrix G. + auto it = GMatrices.find(key); + if (it == GMatrices.end()) + return {}; + const TransformMatrix &GMatrix = it->second; + + retRows = GMatrix.rows; + auto matmulType = RankedTensorType::get({retRows, filterW}, elementType); + auto init = builder.create(loc, matmulType.getShape(), + elementType); + + Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType); + // Multiply G x g. + auto matmulOp = builder.create( + loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + if (rightTransform) { + // Get constant transform matrix GT. + auto it = GTMatrices.find(key); + if (it == GTMatrices.end()) + return {}; + const TransformMatrix >Matrix = it->second; + + auto matmulType = + RankedTensorType::get({retRows, GTMatrix.cols}, elementType); + auto init = builder.create(loc, matmulType.getShape(), + elementType); + + Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType); + // Multiply u = (G x g) x GT. + auto matmulOp = builder.create( + loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + // Insert (H, W) to (H, W, C, F). + int64_t retHeight = leftTransform ? m + r - 1 : 1; + int64_t retWidth = rightTransform ? m + r - 1 : 1; + + auto insertSliceOp = + insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter, + zeroIdx, zeroIdx, retHeight, retWidth, + /*loopNorFIdx=*/3, /*loopCorFIdx=*/2, + /*heightIdx=*/0, /*widthIdx=*/1); + + return {insertSliceOp}; + }; + + auto fUpperBound = rewriter.create(loc, filterF); + auto cUpperBound = rewriter.create(loc, filterC); + auto oneStep = rewriter.create(loc, 1); + scf::LoopNest loops = scf::buildLoopNest( + rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound}, + {oneStep, oneStep}, {retValue}, buildBody); + return loops.results[0]; +} + +/// This function transforms the input. The data layout of the input is NHWC. +/// The transformation matrix is 2-dimension. We need to extract H x W from +/// NHWC first. We need to generate 2 levels of loops to iterate on N and C. 
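A quick worked example of the tile bookkeeping used here (illustrative only; the shapes come from the conv2d_unaligned test added later in this patch, and the ceil-division for the tile count is an inference consistent with the padded shapes in that test). The loop nest built below walks the tile grid (tileH x tileW) as well as N and C, and each iteration reads an (m + r - 1) x (m + r - 1) input window starting at row h * m and column w * m, so neighbouring windows overlap by r - 1 rows/columns.

```c++
// Tile arithmetic for F(4 x 4, 3 x 3) on the 2x11x11x5 -> 2x9x9x2 test case.
#include <cstdio>

int main() {
  const int m = 4, r = 3;
  const int outputH = 9;                         // tensor<2x9x9x2xf32>
  const int alpha = m + r - 1;                   // 6: transformed tile size
  const int tileH = (outputH + m - 1) / m;       // 3 tiles per spatial dim
  const int alignedOutputH = tileH * m;          // 12: output padded 9 -> 12
  const int alignedInputH = tileH * m + (r - 1); // 14: input padded 11 -> 14
  std::printf("alpha=%d tiles=%d alignedInputH=%d alignedOutputH=%d\n", alpha,
              tileH, alignedInputH, alignedOutputH);
  // Each tile row h reads input rows [h*m, h*m + alpha); overlap is r - 1 = 2.
  for (int h = 0; h < tileH; ++h)
    std::printf("tile row %d reads padded input rows [%d, %d)\n", h, h * m,
                h * m + alpha);
  return 0;
}
```

These numbers match the transformed-input type tensor<6x6x3x3x2x5xf32> (alpha, alpha, tileH, tileW, N, C) and the tensor.pad ops in the winograd-conv2d-rewrite.mlir test below.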
+/// After the transformation, we get +/// +/// scf.for %h = 0 to tileH step 1 +/// scf.for %w = 0 to tileW step 1 +/// scf.for %n = 0 to N step 1 +/// scf.for %c = 0 to C step 1 +/// %extracted = extract %extracted from +/// %input +/// at [%n, (%h x m), (%w x m), %c] +/// %ret = linalg.matmul BT, %extracted +/// %ret = linalg.matmul %ret, B +/// %inserted = insert %ret into +/// %output +/// at [0, 0, %h, %w, %n, %c] +Value inputTransform(RewriterBase &rewriter, Location loc, Value input, + Value retValue, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to BT transform matrix. + static const llvm::SmallDenseMap + BTMatrices = { + {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)}, + {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)}, + {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)}, + }; + + // Map from (m, r) to B transform matrix. + static const llvm::SmallDenseMap + BMatrices = { + {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)}, + {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)}, + {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)}, + }; + + auto inputType = cast(input.getType()); + Type elementType = inputType.getElementType(); + auto inputShape = inputType.getShape(); // N, H, W, C + int64_t inputN = inputShape[0]; + int64_t inputH = inputShape[1]; + int64_t inputW = inputShape[2]; + int64_t inputC = inputShape[3]; + auto valueType = cast(retValue.getType()); + auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C + int64_t tileH = valueShape[2]; + int64_t tileW = valueShape[3]; + int64_t alphaH = leftTransform ? m + r - 1 : 1; + int64_t alphaW = rightTransform ? m + r - 1 : 1; + + if ((inputH != (tileH * m) + (r - 1)) && inputH != 1) + return Value(); + if ((inputW != (tileW * m) + (r - 1)) && inputW != 1) + return Value(); + + auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange args) -> scf::ValueVector { + Value tileHIter = ivs[0]; + Value tileWIter = ivs[1]; + Value NIter = ivs[2]; + Value CIter = ivs[3]; + + auto context = builder.getContext(); + auto affineMap = + AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); + Value heightOffset = + builder.create(loc, affineMap, tileHIter); + Value widthOffset = + builder.create(loc, affineMap, tileWIter); + + // Extract (H, W) from (N, H, W, C). + auto extractInput = + extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset, + widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0, + /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2); + + TransformMapKeyTy key = {m, r}; + int64_t retRows = 1; + int64_t retCols = 1; + Value matmulRetValue = extractInput; + if (leftTransform) { + // Get constant transform matrix BT. + auto it = BTMatrices.find(key); + if (it == BTMatrices.end()) + return {}; + const TransformMatrix &BTMatrix = it->second; + + retRows = BTMatrix.rows; + auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType); + auto init = builder.create(loc, matmulType.getShape(), + elementType); + + Value BT = + create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type()); + // Multiply BT x d. + auto matmulOp = builder.create( + loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + if (rightTransform) { + // Get constant transform matrix B. 
+ auto it = BMatrices.find(key); + if (it == BMatrices.end()) + return {}; + const TransformMatrix &BMatrix = it->second; + + retCols = BMatrix.cols; + auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); + auto init = builder.create(loc, matmulType.getShape(), + elementType); + Value B = + create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type()); + // Multiply v = (BT x d) x B. + auto matmulOp = builder.create( + loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + // Insert (H, W) to (H, W, tileH, tileW, N, C). + auto combinedVal = insert2DDataTo6D( + builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter, + CIter, retRows, retCols, 2, 3, /*loopNorFIdx=*/4, /*loopCorFIdx=*/5, + /*heightIdx=*/0, /*widthIdx=*/1); + + return {combinedVal}; + }; + + auto zeroIdx = rewriter.create(loc, 0); + auto tileHBound = rewriter.create(loc, tileH); + auto tileWBound = rewriter.create(loc, tileW); + auto nUpperBound = rewriter.create(loc, inputN); + auto cUpperBound = rewriter.create(loc, inputC); + auto oneStep = rewriter.create(loc, 1); + scf::LoopNest loops = scf::buildLoopNest( + rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx}, + {tileHBound, tileWBound, nUpperBound, cUpperBound}, + {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody); + return loops.results[0]; +} + /// This function generates linalg.batch_matmul to multiply input with filter. /// linalg.batch_matmul only supports 3-dimensional inputs. We can treat /// tileH x tileW x H x W data as the 1-dimensional data array. That is to @@ -107,6 +663,185 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc, return expandOutput; } +/// This function transforms the output. The data layout of the output is HWNF. +/// The transformation matrix is 2-dimension. We need to extract H x W from +/// HWNF first. We need to generate 2 levels of loops to iterate on N and F. +/// After the transformation, we get +/// +/// scf.for %h = 0 to tileH step 1 +/// scf.for %w = 0 to tileW step 1 +/// scf.for %n = 0 to N step 1 +/// scf.for %f = 0 to F step 1 +/// %extracted = extract %extracted from +/// %input +/// at [0, 0, %h, %w, %n, %f] +/// %ret = linalg.matmul AT, %extracted +/// %ret = linalg.matmul %ret, A +/// %inserted = insert %ret into +/// output +/// at [%n, (%h x m), (%w x m), %f] +Value outputTransform(RewriterBase &rewriter, Location loc, Value value, + Value output, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to AT transform matrix. + static const llvm::SmallDenseMap + ATMatrices = { + {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)}, + {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)}, + {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)}, + }; + + // Map from (m, r) to A transform matrix. + static const llvm::SmallDenseMap + AMatrices = { + {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)}, + {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)}, + {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)}, + }; + + auto valueType = cast(value.getType()); + Type elementType = valueType.getElementType(); + auto valueShape = valueType.getShape(); // H, W, TileH, TileW, N, F + int64_t valueH = valueShape[0]; + int64_t valueW = valueShape[1]; + int64_t valueN = valueShape[4]; + int64_t valueF = valueShape[5]; + int64_t alphaH = leftTransform ? m + r - 1 : 1; + int64_t alphaW = rightTransform ? 
m + r - 1 : 1; + + if (valueH != alphaH && valueH != 1) + return Value(); + if (valueW != alphaW && valueW != 1) + return Value(); + + auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange args) -> scf::ValueVector { + Value tileHIter = ivs[0]; + Value tileWIter = ivs[1]; + Value NIter = ivs[2]; + Value FIter = ivs[3]; + + // Extract (H, W) from (H, W, tileH, tileW, N, F). + auto extractValue = + extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter, + FIter, 2, 3, /*loopNorFIdx=*/4, + /*loopCorFIdx=*/5, /*heightIdx=*/0, /*widthIdx=*/1); + + TransformMapKeyTy key = {m, r}; + int64_t retRows = 1; + int64_t retCols = 1; + int64_t leftScalarFactor = 1; + int64_t rightScalarFactor = 1; + Value matmulRetValue = extractValue; + if (leftTransform) { + // Get constant transform matrix AT. + auto it = ATMatrices.find(key); + if (it == ATMatrices.end()) + return {}; + const TransformMatrix &ATMatrix = it->second; + + leftScalarFactor = ATMatrix.scalarFactor; + retRows = ATMatrix.rows; + auto matmulType = RankedTensorType::get({retRows, valueW}, elementType); + auto init = builder.create(loc, matmulType.getShape(), + elementType); + + Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType); + // Multiply AT x m. + auto matmulOp = builder.create( + loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + if (rightTransform) { + // Get constant transform matrix T. + auto it = AMatrices.find(key); + if (it == AMatrices.end()) + return {}; + const TransformMatrix &AMatrix = it->second; + + rightScalarFactor = AMatrix.scalarFactor; + auto matmulType = + RankedTensorType::get({retRows, AMatrix.cols}, elementType); + retCols = AMatrix.cols; + auto init = builder.create(loc, matmulType.getShape(), + elementType); + + Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType); + // Multiply y = (AT x m) x A. + auto matmulOp = builder.create( + loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + if (leftScalarFactor * rightScalarFactor != 1) { + // Multiply scalar factor. + Value scalarFactor = builder.create( + loc, + FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor)); + auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); + auto init = builder.create(loc, matmulType.getShape(), + elementType); + + auto identityAffineMap = rewriter.getMultiDimIdentityMap(2); + SmallVector affineMaps = { + AffineMap::get(2, 0, init.getContext()), identityAffineMap}; + auto broadcastedScalar = + rewriter + .create( + loc, matmulType, ValueRange{scalarFactor}, ValueRange{init}, + affineMaps, + llvm::ArrayRef{ + utils::IteratorType::parallel, + utils::IteratorType::parallel}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange args) { + nestedBuilder.create(nestedLoc, args[0]); + }) + .getResult(0); + + matmulRetValue = builder + .create( + loc, matmulType, + ValueRange{broadcastedScalar, matmulRetValue}, + ValueRange{init}) + .getResult(0); + } + + auto context = builder.getContext(); + auto affineMap = + AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); + Value heightOffset = + builder.create(loc, affineMap, tileHIter); + Value widthOffset = + builder.create(loc, affineMap, tileWIter); + + // Insert (H, W) to (N, H, W, F). 
+ Value combinedVal = + insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter, + heightOffset, widthOffset, retRows, retCols, + /*loopNorFIdx=*/0, + /*loopCorFIdx=*/3, /*heightIdx=*/1, + /*widthIdx=*/2); + + return {combinedVal}; + }; + + int64_t tilwH = valueShape[2]; + int64_t tileW = valueShape[3]; + auto zeroIdx = rewriter.create(loc, 0); + auto tileHBound = rewriter.create(loc, tilwH); + auto tileWBound = rewriter.create(loc, tileW); + auto nUpperBound = rewriter.create(loc, valueN); + auto fUpperBound = rewriter.create(loc, valueF); + auto oneStep = rewriter.create(loc, 1); + scf::LoopNest loops = scf::buildLoopNest( + rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx}, + {tileHBound, tileWBound, nUpperBound, fUpperBound}, + {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody); + return loops.results[0]; +} + /// Create an empty tensor with alignedType and insert the value into the /// created empty tensor with aligned size. static Value padToAlignedTensor(RewriterBase &rewriter, Location loc, @@ -156,7 +891,6 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp, auto filterType = cast(filter.getType()); auto outputType = cast(output.getType()); - // TODO: Should we support dynamic shapes? if (!inputType.hasStaticShape()) return rewriter.notifyMatchFailure(convOp, "expected a static shape for the input"); @@ -293,6 +1027,120 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp, return transformedOutput.getDefiningOp(); } +/// A helper function to decompose linalg.winograd_filter_transform. +FailureOr +decomposeWinogradFilterTransformHelper(RewriterBase &rewriter, + linalg::WinogradFilterTransformOp op) { + Location loc = op.getLoc(); + Value filter = op.getFilter(); + auto filterType = cast(filter.getType()); + auto filterShape = filterType.getShape(); + int64_t filterH = filterShape[1]; + int64_t filterW = filterShape[2]; + + // For F(m x 1, r x 1), we only need to do left side transform. + bool leftTransform = filterH != 1; + // For F(1 x m, 1 x r), we only need to do right side transform. + bool rightTransform = filterW != 1; + Value transformedFilter = + filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(), + op.getR(), leftTransform, rightTransform); + if (!transformedFilter) + return failure(); + + rewriter.replaceOp(op, transformedFilter); + + return transformedFilter.getDefiningOp(); +} + +/// A helper function to decompose linalg.winograd_input_transform. +FailureOr +decomposeWinogradInputTransformHelper(RewriterBase &rewriter, + linalg::WinogradInputTransformOp op) { + Location loc = op.getLoc(); + Value input = op.getInput(); + auto inputType = cast(input.getType()); + auto inputShape = inputType.getShape(); + int64_t inputH = inputShape[1]; + int64_t inputW = inputShape[2]; + + // For F(m x 1, r x 1), we only need to do left side transform. + bool leftTransform = inputH != 1; + // For F(1 x m, 1 x r), we only need to do right side transform. + bool rightTransform = inputW != 1; + Value transformedInput = + inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(), + op.getR(), leftTransform, rightTransform); + if (!transformedInput) + return failure(); + + rewriter.replaceOp(op, transformedInput); + + return transformedInput.getDefiningOp(); +} + +/// A helper function to decompose linalg.winograd_output_transform. 
+FailureOr +decomposeWinogradOutputTransformHelper(RewriterBase &rewriter, + linalg::WinogradOutputTransformOp op) { + Location loc = op.getLoc(); + Value value = op.getValue(); + auto valueType = cast(value.getType()); + auto valueShape = valueType.getShape(); + int64_t valueH = valueShape[0]; + int64_t valueW = valueShape[1]; + + // For F(m x 1, r x 1), we only need to do left side transform. + bool leftTransform = valueH != 1; + // For F(1 x m, 1 x r), we only need to do right side transform. + bool rightTransform = valueW != 1; + Value transformedOutput = + outputTransform(rewriter, loc, value, op.getOutput(), op.getM(), + op.getR(), leftTransform, rightTransform); + if (!transformedOutput) + return failure(); + + rewriter.replaceOp(op, transformedOutput); + + return transformedOutput.getDefiningOp(); +} + +/// A rewrite pattern to decompose linalg.winograd_filter_transform operations. +class DecomposeWinogradFilterTransform final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op, + PatternRewriter &rewriter) const override { + return decomposeWinogradFilterTransformHelper(rewriter, op); + } +}; + +/// A rewrite pattern to decompose linalg.winograd_input_transform operations. +class DecomposeWinogradInputTransform final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op, + PatternRewriter &rewriter) const override { + return decomposeWinogradInputTransformHelper(rewriter, op); + } +}; + +/// A rewrite pattern to decompose linalg.winograd_output_transform operations. +class DecomposeWinogradOutputTransform final + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op, + PatternRewriter &rewriter) const override { + return decomposeWinogradOutputTransformHelper(rewriter, op); + } +}; + /// A rewrite pattern for Winograd Conv2D algorithm. 
class WinogradConv2DNhwcFhwc final : public OpRewritePattern { @@ -316,6 +1164,12 @@ class WinogradConv2DNhwcFhwc final } // end anonymous namespace //===----------------------------------------------------------------------===// +FailureOr winogradConv2D(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp op, int64_t m, + int64_t r) { + return winogradConv2DHelper(rewriter, op, m, r); +} + void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, int64_t r) { MLIRContext *context = patterns.getContext(); @@ -323,5 +1177,12 @@ void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, patterns.insert(context, m, r); } +void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns + .insert(context); +} + } // end namespace linalg } // end namespace mlir diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir new file mode 100644 index 0000000000000..c10e0ccebfd7c --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir @@ -0,0 +1,76 @@ +// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file -verify-diagnostics| FileCheck %s + +func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %0 : tensor<2x8x8x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK-LABEL: func.func @conv2d +// CHECK: linalg.winograd_filter_transform m(4) r(3) +// CHECK: linalg.winograd_input_transform m(4) r(3) +// CHECK: linalg.batch_matmul +// CHECK: linalg.winograd_output_transform m(4) r(3) + +// ----- + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> + return %0 : tensor<2x9x9x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK-LABEL: func.func @conv2d_unaligned +// CHECK: linalg.winograd_filter_transform m(4) r(3) +// CHECK: tensor.pad +// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0] +// CHECK: linalg.winograd_input_transform m(4) r(3) +// CHECK: tensor.pad +// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0] +// CHECK: linalg.winograd_output_transform m(4) r(3) + +// ----- + +func.func @conv2d_unsupported(%arg0: 
tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x2xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> { + %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<3x3x5x2xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %0 : tensor<2x8x8x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @+1 {{this operation is not supported to convert to Winograd Conv2D}} + %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// ----- + +func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> + return %0 : tensor<2x?x?x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @+1 {{apply Winograd Conv2D failed}} + %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir new file mode 100644 index 0000000000000..095a6636b68dc --- /dev/null +++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir @@ -0,0 +1,120 @@ +// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-decompose-winograd-ops | FileCheck %s + +func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %2 = tensor.empty() : tensor<6x6x5x2xf32> + %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> + %padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] { + ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): + tensor.yield %cst : f32 + } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32> + %4 = tensor.empty() : tensor<6x6x3x3x2x5xf32> + %5 = linalg.winograd_input_transform m(4) r(3) ins(%padded : tensor<2x14x14x5xf32>) outs(%4 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> + %collapsed = tensor.collapse_shape %3 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32> + %collapsed_0 = tensor.collapse_shape %5 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32> + %6 = tensor.empty() : tensor<36x18x2xf32> + %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%6 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32> + %expanded = tensor.expand_shape %7 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32> + %padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] { + 
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index): + tensor.yield %cst : f32 + } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32> + %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> + %extracted_slice = tensor.extract_slice %8[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> + return %extracted_slice : tensor<2x9x9x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()> +// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @conv2d +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> { +// CHECK-DAG: %[[CST:.*]] = arith.constant 1.024000e+03 : f32 +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32> +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32> +// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32> +// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32> +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32> +// CHECK-DAG: %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, 
-0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32> +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[CST_6:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x6x5x2xf32> +// CHECK-NEXT: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]]) -> (tensor<6x6x5x2xf32>) { +// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x5x2xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], %[[C0]], %[[C0]], %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<3x3xf32> +// CHECK-NEXT: %[[S8:.*]] = tensor.empty() : tensor<6x3xf32> +// CHECK-NEXT: %[[S9:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_9]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S8]] : tensor<6x3xf32>) -> tensor<6x3xf32> +// CHECK-NEXT: %[[S10:.*]] = tensor.empty() : tensor<6x6xf32> +// CHECK-NEXT: %[[S11:.*]] = linalg.matmul ins(%[[S9]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S11]] into %[[ARG6]][%[[C0]], %[[C0]], %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x5x2xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S7]] : tensor<6x6x5x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] { +// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index): +// CHECK-NEXT: tensor.yield %[[CST_6]] : f32 +// CHECK-NEXT: } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32> +// CHECK-NEXT: %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<6x6x3x3x2x5xf32>) { +// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<6x6x3x3x2x5xf32>) { +// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<6x6x3x3x2x5xf32>) { +// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<6x6x3x3x2x5xf32>) { +// CHECK-NEXT: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]]) +// CHECK-NEXT: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]]) +// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][%[[ARG7]], %[[S10]], %[[S11]], %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<6x6xf32> +// CHECK-NEXT: %[[S12:.*]] = tensor.empty() : tensor<6x6xf32> +// CHECK-NEXT: %[[S13:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32> +// CHECK-NEXT: %[[S14:.*]] = tensor.empty() : tensor<6x6xf32> +// CHECK-NEXT: %[[S15:.*]] = linalg.matmul ins(%[[S13]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S14]] : tensor<6x6xf32>) -> tensor<6x6xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = 
tensor.insert_slice %[[S15]] into %[[ARG10]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<6x6x3x3x2x5xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x3x3x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S9]] : tensor<6x6x3x3x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S8]] : tensor<6x6x3x3x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S7]] : tensor<6x6x3x3x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<36x18x2xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_7]], %[[COLLAPSED]] : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%[[S4]] : tensor<36x18x2xf32>) -> tensor<36x18x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32> +// CHECK-NEXT: %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0] { +// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index): +// CHECK-NEXT: tensor.yield %[[CST_6]] : f32 +// CHECK-NEXT: } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[PADDED_8]]) -> (tensor<2x12x12x2xf32>) { +// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) { +// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]]) -> (tensor<2x12x12x2xf32>) { +// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x12x12x2xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x3x3x2x2xf32> to tensor<6x6xf32> +// CHECK-NEXT: %[[S10:.*]] = tensor.empty() : tensor<4x6xf32> +// CHECK-NEXT: %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32> +// CHECK-NEXT: %[[S12:.*]] = tensor.empty() : tensor<4x4xf32> +// CHECK-NEXT: %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %[[S14:.*]] = tensor.empty() : tensor<4x4xf32> +// CHECK-NEXT: %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S14]] : tensor<4x4xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<4x4xf32> +// CHECK-NEXT: %[[S16:.*]] = linalg.mul ins(%[[S15]], %[[S13]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %[[S17:.*]] = affine.apply #[[$MAP0]](%[[ARG3]]) +// CHECK-NEXT: %[[S18:.*]] = affine.apply #[[$MAP0]](%[[ARG5]]) +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S16]] into 
%[[ARG10]][%[[ARG7]], %[[S17]], %[[S18]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S9]] : tensor<2x12x12x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S8]] : tensor<2x12x12x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S7]] : tensor<2x12x12x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S6]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> +// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32> +// CHECK-NEXT: } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index 12cb46a5968f1..5899f56da7345 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -127,6 +127,9 @@ struct TestLinalgTransforms *this, "test-winograd-conv2d", llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"), llvm::cl::init(false)}; + Option testDecomposeWinogradOps{ + *this, "test-decompose-winograd-ops", + llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)}; }; } // namespace @@ -218,6 +221,12 @@ static void applyWinogradConv2D(func::FuncOp funcOp) { (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } +static void applyDecomposeWinogradOps(func::FuncOp funcOp) { + RewritePatternSet patterns(funcOp.getContext()); + populateDecomposeWinogradOpsPatterns(patterns); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); +} + /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnOperation() { if (testPatterns) @@ -244,6 +253,8 @@ void TestLinalgTransforms::runOnOperation() { return applyEraseUnnecessaryInputs(getOperation()); if (testWinogradConv2D) return applyWinogradConv2D(getOperation()); + if (testDecomposeWinogradOps) + return applyDecomposeWinogradOps(getOperation()); } namespace mlir {
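For reference, a sketch of how a downstream pass might chain the two pattern sets added by this patch: first rewrite linalg.conv_2d_nhwc_fhwc into the Winograd transform ops plus linalg.batch_matmul, then decompose those transform ops into loops and small matmuls, mirroring what the -test-winograd-conv2d and -test-decompose-winograd-ops options above exercise separately. This is not part of the patch; the function name and the fixed m = 4, r = 3 choice are illustrative assumptions.

```c++
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Lower every linalg.conv_2d_nhwc_fhwc in `funcOp` through the Winograd path.
static LogicalResult lowerConv2DViaWinograd(func::FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();

  // Step 1: conv_2d_nhwc_fhwc -> winograd_filter/input/output_transform +
  // batch_matmul, using F(4 x 4, 3 x 3).
  {
    RewritePatternSet patterns(ctx);
    linalg::populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      return failure();
  }

  // Step 2: decompose the winograd_*_transform ops into scf.for loop nests
  // over tiles and channels with small linalg.matmul ops.
  {
    RewritePatternSet patterns(ctx);
    linalg::populateDecomposeWinogradOpsPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      return failure();
  }
  return success();
}
```

The transform dialect route (transform.structured.winograd_conv2d, as in the transform-winograd-conv2d.mlir test above) covers the first step; the decompose patterns handle the second.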