This repository was archived by the owner on Jan 25, 2023. It is now read-only.

[MLIR] Refactor linalg optimizations flow #210

Merged
merged 1 commit on Mar 20, 2021

76 changes: 41 additions & 35 deletions mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp
@@ -34,6 +34,7 @@
#include "plier/rewrites/force_inline.hpp"
#include "plier/rewrites/index_type_propagation.hpp"
#include "plier/rewrites/loop_rewrites.hpp"
#include "plier/rewrites/memory_rewrites.hpp"
#include "plier/transforms/loop_utils.hpp"

#include "base_pipeline.hpp"
@@ -44,6 +45,29 @@

 namespace
 {
+void applyOptimizations(mlir::FuncOp op, const mlir::FrozenRewritePatternList& patterns, llvm::function_ref<mlir::LogicalResult(mlir::FuncOp)> additionalOpts = nullptr)
+{
+    bool repeat = false;
+    do
+    {
+        repeat = false;
+        (void)mlir::applyPatternsAndFoldGreedily(op, patterns);
+        if (mlir::succeeded(plier::applyCSE(op.getRegion(), false)))
+        {
+            repeat = true;
+        }
+        if (mlir::succeeded(plier::promoteLoads(op.getRegion())))
+        {
+            repeat = true;
+        }
+        if (additionalOpts && mlir::succeeded(additionalOpts(op)))
+        {
+            repeat = true;
+        }
+    }
+    while(repeat);
+}
+
 enum class ArrayLayout
 {
     C,
@@ -900,12 +924,12 @@ void LowerLinalgPass::runOnOperation()
 }
 
 struct PostPlierToLinalgPass :
-    public mlir::PassWrapper<PostPlierToLinalgPass, mlir::OperationPass<mlir::ModuleOp>>
+    public mlir::PassWrapper<PostPlierToLinalgPass, mlir::FunctionPass>
 {
-    void runOnOperation() override;
+    void runOnFunction() override;
 };
 
-void PostPlierToLinalgPass::runOnOperation()
+void PostPlierToLinalgPass::runOnFunction()
 {
     mlir::OwningRewritePatternList patterns;
 
@@ -916,7 +940,7 @@ void PostPlierToLinalgPass::runOnOperation()
         SimplifyExpandDims
         >(&getContext());
 
-    (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    applyOptimizations(getFunction(), std::move(patterns));
 }
 
 struct TensorFusionPass :
@@ -1016,12 +1040,12 @@ void RetainArgsPass::runOnFunction()
 }
 
 struct PostLinalgOptPass :
-    public mlir::PassWrapper<PostLinalgOptPass, mlir::OperationPass<mlir::ModuleOp>>
+    public mlir::PassWrapper<PostLinalgOptPass, mlir::FunctionPass>
 {
-    void runOnOperation() override;
+    void runOnFunction() override;
 };
 
-void PostLinalgOptPass::runOnOperation()
+void PostLinalgOptPass::runOnFunction()
 {
     mlir::OwningRewritePatternList patterns;
 
@@ -1032,37 +1056,19 @@ void PostLinalgOptPass::runOnOperation()
         plier::CanonicalizeReduction
         >(&context);
 
-    mlir::FrozenRewritePatternList frozenPatterns(std::move(patterns));
-
-    while (true)
+    applyOptimizations(getFunction(), std::move(patterns), [](mlir::FuncOp op)
     {
-        (void)mlir::applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
-        bool rerun = false;
-        for (auto& op : getOperation().getRegion().front())
-        {
-            if (auto func = mlir::dyn_cast<mlir::FuncOp>(op))
-            {
-                if (mlir::succeeded(plier::naivelyFuseParallelOps(func.getRegion())))
-                {
-                    rerun = true;
-                }
-            }
-        }
-        if (!rerun)
-        {
-            break;
-        }
-    }
-
+        return plier::naivelyFuseParallelOps(op.getRegion());
+    });
 }
 
 struct PromoteParallelPass :
-    public mlir::PassWrapper<PromoteParallelPass, mlir::OperationPass<mlir::ModuleOp>>
+    public mlir::PassWrapper<PromoteParallelPass, mlir::FunctionPass>
 {
-    void runOnOperation() override;
+    void runOnFunction() override;
 };
 
-void PromoteParallelPass::runOnOperation()
+void PromoteParallelPass::runOnFunction()
 {
     mlir::OwningRewritePatternList patterns;
 
@@ -1074,13 +1080,13 @@ void PromoteParallelPass::runOnOperation()
         plier::PromoteToParallel // TODO
         >(&context);
 
-    (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    applyOptimizations(getFunction(), std::move(patterns));
 }
 
 void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm)
 {
     pm.addPass(std::make_unique<PlierToLinalgPass>());
-    pm.addPass(std::make_unique<PostPlierToLinalgPass>());
+    pm.addNestedPass<mlir::FuncOp>(std::make_unique<PostPlierToLinalgPass>());
     pm.addPass(mlir::createSymbolDCEPass());
 }

@@ -1105,9 +1111,9 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
     pm.addPass(mlir::createCopyRemovalPass());
 
     pm.addPass(std::make_unique<LowerLinalgPass>());
-    pm.addPass(std::make_unique<PostLinalgOptPass>());
+    pm.addNestedPass<mlir::FuncOp>(std::make_unique<PostLinalgOptPass>());
     pm.addPass(mlir::createSymbolDCEPass());
-    pm.addPass(std::make_unique<PromoteParallelPass>());
+    pm.addNestedPass<mlir::FuncOp>(std::make_unique<PromoteParallelPass>());
 }
 }

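The shape of the refactor, in one place: each post-lowering pass becomes a function pass that hands its pattern set to the new applyOptimizations helper, which reruns greedy pattern application, CSE and load promotion until nothing changes, optionally together with an extra per-function transform, and the pipeline schedules these passes with addNestedPass<mlir::FuncOp>. A minimal sketch assembled from the hunks above; MyPostOptPass is a hypothetical name, everything else appears in the diff:

// Hypothetical pass illustrating the refactored flow (not part of this change).
struct MyPostOptPass :
    public mlir::PassWrapper<MyPostOptPass, mlir::FunctionPass>
{
    void runOnFunction() override
    {
        mlir::OwningRewritePatternList patterns;
        // Populate `patterns` here, as PostLinalgOptPass does above.

        // Drive pattern rewriting, CSE and load promotion to a fixed point,
        // with parallel-loop fusion as the optional extra per-function step.
        applyOptimizations(getFunction(), std::move(patterns), [](mlir::FuncOp op)
        {
            return plier::naivelyFuseParallelOps(op.getRegion());
        });
    }
};

// Registered per function in the pipeline rather than at module scope:
// pm.addNestedPass<mlir::FuncOp>(std::make_unique<MyPostOptPass>());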