Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.

[MLIR] Numpy empty and sum axis #187

Merged
merged 15 commits into from
Feb 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,12 @@ bool is_int(mlir::Type type)
return type.isa<mlir::IntegerType>();
}

mlir::LogicalResult lower_prange(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, mlir::PatternRewriter& rewriter)
mlir::LogicalResult lower_prange(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
{
if (!kwargs.empty())
{
return mlir::failure();
}
if ((operands.size() < 1 || operands.size() > 3) ||
!llvm::all_of(operands, [](mlir::Value val) { return is_int(val.getType());}))
{
Expand Down Expand Up @@ -177,21 +181,21 @@ struct CallLowerer
{
mlir::LogicalResult operator()(
plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef<mlir::Value> args,
mlir::PatternRewriter& rewriter)
llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
{
using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef<mlir::Value>, mlir::PatternRewriter&);
using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef<mlir::Value>, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>>, mlir::PatternRewriter&);
std::pair<llvm::StringRef, func_t> handlers[] = {
{"numba.prange", lower_prange},
};
for (auto& handler : handlers)
{
if (handler.first == name)
{
return handler.second(op, args, rewriter);
return handler.second(op, args, kwargs, rewriter);
}
}

if (auto result = linalg_resolver.rewrite(name, op.getLoc(), rewriter, args))
if (auto result = linalg_resolver.rewrite(name, op.getLoc(), rewriter, args, kwargs))
{
assert(result->size() == op->getNumResults());
rerun_std_pipeline(op);
Expand All @@ -206,7 +210,7 @@ struct CallLowerer
return mlir::success();
}

if (name == "len" && check_numpy_args(args, 1))
if (name == "len" && check_numpy_args(args, 1) && kwargs.empty())
{
auto loc = op.getLoc();
mlir::Value dim = rewriter.create<mlir::DimOp>(loc, args[0], 0);
Expand All @@ -219,7 +223,6 @@ struct CallLowerer
}

private:

PyLinalgResolver linalg_resolver;
};

Expand Down Expand Up @@ -436,7 +439,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern<T>
mlir::OpBuilder::InsertionGuard g(rewriter);
if (auto parent_op = target.getDefiningOp())
{
rewriter.setInsertionPoint(parent_op);
rewriter.setInsertionPointAfter(parent_op);
}
else
{
Expand All @@ -456,6 +459,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern<T>
}
else
{
mlir::OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(use_op);
auto new_val = rewriter.create<mlir::TensorLoadOp>(use_op->getLoc(), memref);
rewriter.updateRootInPlace(use_op, [&]()
Expand Down Expand Up @@ -602,6 +606,7 @@ struct LowerLinalgPass :
mlir::DialectRegistry &registry) const override
{
registry.insert<mlir::StandardOpsDialect>();
registry.insert<mlir::tensor::TensorDialect>();
registry.insert<mlir::linalg::LinalgDialect>();
registry.insert<mlir::scf::SCFDialect>();
registry.insert<mlir::AffineDialect>();
Expand Down Expand Up @@ -686,6 +691,7 @@ void PostLinalgOptPass::runOnOperation()
void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm)
{
pm.addPass(std::make_unique<PlierToLinalgPass>());
pm.addPass(mlir::createSymbolDCEPass());
}

void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
Expand All @@ -708,6 +714,7 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)

pm.addPass(std::make_unique<LowerLinalgPass>());
pm.addPass(std::make_unique<PostLinalgOptPass>());
pm.addPass(mlir::createSymbolDCEPass());
}
}

Expand Down
32 changes: 24 additions & 8 deletions mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1156,8 +1156,12 @@ struct FoldTupleGetitem : public mlir::OpRewritePattern<Op>
}
};

mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, mlir::PatternRewriter& rewriter)
mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
{
if (!kwargs.empty())
{
return mlir::failure();
}
if ((operands.size() < 1 || operands.size() > 3) ||
!llvm::all_of(operands, [](mlir::Value val) { return is_int(val.getType());}))
{
Expand Down Expand Up @@ -1191,8 +1195,12 @@ mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef<mlir::Value>
return mlir::success();
}

mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, mlir::PatternRewriter& rewriter)
mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
{
if (!kwargs.empty())
{
return mlir::failure();
}
if (operands.size() != 1)
{
return mlir::failure();
Expand All @@ -1210,8 +1218,12 @@ mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> op
return mlir::success();
}

mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, mlir::PatternRewriter& rewriter)
mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef<mlir::Value> operands, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
{
if (!kwargs.empty())
{
return mlir::failure();
}
if (operands.size() != 1)
{
return mlir::failure();
Expand Down Expand Up @@ -1250,8 +1262,12 @@ mlir::FuncOp get_lib_symbol(

mlir::LogicalResult lower_math_func(
plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef<mlir::Value> args,
mlir::PatternRewriter& rewriter)
llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
{
if (!kwargs.empty())
{
return mlir::failure();
}
auto ret_type = map_plier_type(op.getType());
auto valid_type = [&](mlir::Type type)
{
Expand Down Expand Up @@ -1285,14 +1301,14 @@ mlir::LogicalResult lower_math_func(
struct CallLowerer
{
mlir::LogicalResult operator()(plier::PyCallOp op, llvm::StringRef name,
llvm::ArrayRef<mlir::Value> args, mlir::PatternRewriter& rewriter)
llvm::ArrayRef<mlir::Value> args, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
{
if (mlir::succeeded(lower_math_func(op, name, args, rewriter)))
if (mlir::succeeded(lower_math_func(op, name, args, kwargs, rewriter)))
{
return mlir::success();
}

using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef<mlir::Value>, mlir::PatternRewriter&);
using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef<mlir::Value>, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>>, mlir::PatternRewriter&);
std::pair<llvm::StringRef, func_t> handlers[] = {
{"bool", lower_bool_cast},
{"range", lower_range},
Expand All @@ -1302,7 +1318,7 @@ struct CallLowerer
{
if (handler.first == name)
{
return handler.second(op, args, rewriter);
return handler.second(op, args, kwargs, rewriter);
}
}

Expand Down
Loading