From e08d394084d6e8955deadbbd8f4e4157b2ea6dbc Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Mar 2021 16:16:52 +0300 Subject: [PATCH] Refactor linalg optimizations flow --- .../src/pipelines/plier_to_linalg.cpp | 76 ++++++++++--------- 1 file changed, 41 insertions(+), 35 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index c3559085f71..476f95f722b 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -34,6 +34,7 @@ #include "plier/rewrites/force_inline.hpp" #include "plier/rewrites/index_type_propagation.hpp" #include "plier/rewrites/loop_rewrites.hpp" +#include "plier/rewrites/memory_rewrites.hpp" #include "plier/transforms/loop_utils.hpp" #include "base_pipeline.hpp" @@ -44,6 +45,29 @@ namespace { +void applyOptimizations(mlir::FuncOp op, const mlir::FrozenRewritePatternList& patterns, llvm::function_ref additionalOpts = nullptr) +{ + bool repeat = false; + do + { + repeat = false; + (void)mlir::applyPatternsAndFoldGreedily(op, patterns); + if (mlir::succeeded(plier::applyCSE(op.getRegion(), false))) + { + repeat = true; + } + if (mlir::succeeded(plier::promoteLoads(op.getRegion()))) + { + repeat = true; + } + if (additionalOpts && mlir::succeeded(additionalOpts(op))) + { + repeat = true; + } + } + while(repeat); +} + enum class ArrayLayout { C, @@ -900,12 +924,12 @@ void LowerLinalgPass::runOnOperation() } struct PostPlierToLinalgPass : - public mlir::PassWrapper> + public mlir::PassWrapper { - void runOnOperation() override; + void runOnFunction() override; }; -void PostPlierToLinalgPass::runOnOperation() +void PostPlierToLinalgPass::runOnFunction() { mlir::OwningRewritePatternList patterns; @@ -916,7 +940,7 @@ void PostPlierToLinalgPass::runOnOperation() SimplifyExpandDims >(&getContext()); - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + applyOptimizations(getFunction(), std::move(patterns)); } struct TensorFusionPass : @@ -1016,12 +1040,12 @@ void RetainArgsPass::runOnFunction() } struct PostLinalgOptPass : - public mlir::PassWrapper> + public mlir::PassWrapper { - void runOnOperation() override; + void runOnFunction() override; }; -void PostLinalgOptPass::runOnOperation() +void PostLinalgOptPass::runOnFunction() { mlir::OwningRewritePatternList patterns; @@ -1032,37 +1056,19 @@ void PostLinalgOptPass::runOnOperation() plier::CanonicalizeReduction >(&context); - mlir::FrozenRewritePatternList frozenPatterns(std::move(patterns)); - - while (true) + applyOptimizations(getFunction(), std::move(patterns), [](mlir::FuncOp op) { - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), frozenPatterns); - bool rerun = false; - for (auto& op : getOperation().getRegion().front()) - { - if (auto func = mlir::dyn_cast(op)) - { - if (mlir::succeeded(plier::naivelyFuseParallelOps(func.getRegion()))) - { - rerun = true; - } - } - } - if (!rerun) - { - break; - } - } - + return plier::naivelyFuseParallelOps(op.getRegion()); + }); } struct PromoteParallelPass : - public mlir::PassWrapper> + public mlir::PassWrapper { - void runOnOperation() override; + void runOnFunction() override; }; -void PromoteParallelPass::runOnOperation() +void PromoteParallelPass::runOnFunction() { mlir::OwningRewritePatternList patterns; @@ -1074,13 +1080,13 @@ void PromoteParallelPass::runOnOperation() plier::PromoteToParallel // TODO >(&context); - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + applyOptimizations(getFunction(), std::move(patterns)); } void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); - pm.addPass(std::make_unique()); + pm.addNestedPass(std::make_unique()); pm.addPass(mlir::createSymbolDCEPass()); } @@ -1105,9 +1111,9 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) pm.addPass(mlir::createCopyRemovalPass()); pm.addPass(std::make_unique()); - pm.addPass(std::make_unique()); + pm.addNestedPass(std::make_unique()); pm.addPass(mlir::createSymbolDCEPass()); - pm.addPass(std::make_unique()); + pm.addNestedPass(std::make_unique()); } }