From 5981ec01d7f765dd040932526e724618c1c55c0e Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 12:55:14 +0300 Subject: [PATCH 01/15] store context in Var --- .../mlir-compiler/src/py_linalg_resolver.cpp | 27 ++++++++++--------- numba/mlir/linalg_builder.py | 3 ++- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 181c19a9166..20a979a7d92 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -125,7 +125,7 @@ struct PyLinalgResolver::Context py::handle compile_func; py::handle lookup_func; - py::object create_var(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value value) + py::object create_var(mlir::Location loc, mlir::OpBuilder& builder, py::capsule context, mlir::Value value) { if (value.getType().isa()) { @@ -142,9 +142,9 @@ struct PyLinalgResolver::Context mlir::Value mlir_dim = builder.create(loc, value, it2.index()); py_shape[it2.index()] = make_dim_val(it2.value(), mlir_dim); } - return var(wrap_mlir(value), py_shape, wrap_mlir(elem_type)); + return var(context, wrap_mlir(value), py_shape, wrap_mlir(elem_type)); } - return var(wrap_mlir(value), py::list(), wrap_mlir(value.getType())); + return var(context, wrap_mlir(value), py::list(), wrap_mlir(value.getType())); } mlir::FuncOp compile_body(py::handle body, py::list arg_types) @@ -156,7 +156,7 @@ struct PyLinalgResolver::Context return mlir_func; } - py::object wrap_result(mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange values) + py::object wrap_result(mlir::Location loc, mlir::OpBuilder& builder, py::capsule context, mlir::ValueRange values) { if (values.empty()) { @@ -164,12 +164,12 @@ struct PyLinalgResolver::Context } if (values.size() == 1) { - return create_var(loc, builder, values.front()); + return create_var(loc, builder, context, values.front()); } py::tuple ret(values.size()); for (auto it : llvm::enumerate(values)) { - ret[it.index()] = create_var(loc, builder, it.value()); + ret[it.index()] = create_var(loc, builder, context, it.value()); } return std::move(ret); } @@ -296,7 +296,7 @@ py::object init_tensor_impl(py::capsule context, py::list shape, py::capsule dty { init = ctx.builder.create(ctx.loc, unwrap_shape(shape), elem_type); } - return ctx.context.create_var(ctx.loc, ctx.builder, init); + return ctx.context.create_var(ctx.loc, ctx.builder, context, init); } py::object generic_impl(py::capsule context, py::handle inputs, py::handle outputs, py::list iterators, py::list maps, py::handle body) @@ -337,7 +337,7 @@ py::object generic_impl(py::capsule context, py::handle inputs, py::handle outpu { inputs_args.append(output_args.begin(), output_args.end()); auto res = builder.create(loc, body_func, inputs_args); - return ctx.context.wrap_result(loc, builder, cast_values(res.getResults(), ret_types)); + return ctx.context.wrap_result(loc, builder, context, cast_values(res.getResults(), ret_types)); } else { @@ -359,7 +359,7 @@ py::object generic_impl(py::capsule context, py::handle inputs, py::handle outpu affine_maps, mlir_iterators, body_builder); - return ctx.context.wrap_result(loc, builder, generic_op.getResults()); + return ctx.context.wrap_result(loc, builder, context, generic_op.getResults()); } } @@ -400,7 +400,7 @@ py::object from_elements_impl(py::capsule context, py::handle values, py::capsul } }); auto res = builder.create(loc, vals); - return ctx.context.create_var(ctx.loc, ctx.builder, res); + return ctx.context.create_var(ctx.loc, ctx.builder, context, res); } py::object extract_impl(py::capsule context, py::handle value, py::handle indices) @@ -426,7 +426,7 @@ py::object extract_impl(py::capsule context, py::handle value, py::handle indice } }); auto res = builder.create(loc, get_var_value(value), ind); - return ctx.context.create_var(ctx.loc, ctx.builder, res); + return ctx.context.create_var(ctx.loc, ctx.builder, context, res); } void setup_py_builder(py::handle builder, mlir::OpBuilder& b) @@ -507,7 +507,8 @@ llvm::Optional PyLinalgResolver::rewrite(llvm::StringR } PyBuilderContext py_builder_context{loc, builder, *context}; - auto py_builder = context->builder(py::capsule(&py_builder_context)); + auto py_context = py::capsule(&py_builder_context); + auto py_builder = context->builder(py_context); setup_py_builder(py_builder, builder); assert(!args.empty()); @@ -519,7 +520,7 @@ llvm::Optional PyLinalgResolver::rewrite(llvm::StringR { auto index = static_cast(it.index()); auto mlir_arg = it.value(); - py_args[index] = context->create_var(loc, builder, mlir_arg); + py_args[index] = context->create_var(loc, builder, py_context, mlir_arg); } auto result = builder_func(py_builder, *py_args); diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py index 0d3c0bb8bfa..0e840dd39e1 100644 --- a/numba/mlir/linalg_builder.py +++ b/numba/mlir/linalg_builder.py @@ -1,7 +1,8 @@ from .func_registry import add_func class Var: - def __init__(self, ssa_val, shape, dtype): + def __init__(self, context, ssa_val, shape, dtype): + self._context = context self._ssa_val = ssa_val self._shape = shape self._dtype = dtype From d53b90fb20f2a6342cefda2b6da2508bebfd6520 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 13:22:33 +0300 Subject: [PATCH 02/15] rework shape accessor --- .../mlir-compiler/src/py_linalg_resolver.cpp | 37 ++++++++++++++++++- numba/mlir/linalg_builder.py | 2 +- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 20a979a7d92..17a1bb79d6a 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -113,6 +113,8 @@ void container_iterate(py::handle obj, F&& func) func(std::size_t(0), obj); } } + +py::object setup_py_var(py::object var); } struct PyLinalgResolver::Context @@ -142,9 +144,9 @@ struct PyLinalgResolver::Context mlir::Value mlir_dim = builder.create(loc, value, it2.index()); py_shape[it2.index()] = make_dim_val(it2.value(), mlir_dim); } - return var(context, wrap_mlir(value), py_shape, wrap_mlir(elem_type)); + return setup_py_var(var(context, wrap_mlir(value), py_shape, wrap_mlir(elem_type))); } - return var(context, wrap_mlir(value), py::list(), wrap_mlir(value.getType())); + return setup_py_var(var(context, wrap_mlir(value), py::list(), wrap_mlir(value.getType()))); } mlir::FuncOp compile_body(py::handle body, py::list arg_types) @@ -452,6 +454,37 @@ void setup_py_builder(py::handle builder, mlir::OpBuilder& b) add_type("float64", b.getF64Type()); } +py::object shape_impl(py::capsule context, py::capsule ssa_val) +{ + auto& ctx = get_py_context(context); + auto value = unwrap_mlir(ssa_val); + if (value.getType().isa()) + { + auto& builder = ctx.builder; + auto loc = ctx.loc; + auto make_dim_val = [&](auto dim, auto ssa_val) + { + return ctx.context.val(get_dim(dim), wrap_mlir(ssa_val)); + }; + auto mlir_type = value.getType().cast(); + auto shape = mlir_type.getShape(); + py::list py_shape(shape.size()); + for (auto it2 : llvm::enumerate(shape)) + { + mlir::Value mlir_dim = builder.create(loc, value, it2.index()); + py_shape[it2.index()] = make_dim_val(it2.value(), mlir_dim); + } + return std::move(py_shape); + } + return py::list(); +} + +py::object setup_py_var(py::object var) +{ + py::setattr(var, "_shape", py::cpp_function(&shape_impl)); + return var; +} + PyLinalgResolver::Values unpack_results(py::handle object) { PyLinalgResolver::Values ret; diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py index 0e840dd39e1..cbfd1a3e737 100644 --- a/numba/mlir/linalg_builder.py +++ b/numba/mlir/linalg_builder.py @@ -9,7 +9,7 @@ def __init__(self, context, ssa_val, shape, dtype): @property def shape(self): - return self._shape + return self._shape(self._context, self._ssa_val) @property def dtype(self): From 71b35c474c85b48fcf59ec7cf6183355598982e7 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 13:27:03 +0300 Subject: [PATCH 03/15] dtype accessor rework --- .../mlir-compiler/src/py_linalg_resolver.cpp | 12 ++++++++++++ numba/mlir/linalg_builder.py | 6 +++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 17a1bb79d6a..b3aed3a0f21 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -479,9 +479,21 @@ py::object shape_impl(py::capsule context, py::capsule ssa_val) return py::list(); } +py::object dtype_impl(py::capsule /*context*/, py::capsule ssa_val) +{ + auto value = unwrap_mlir(ssa_val); + auto type = value.getType(); + if (auto tensor_type = type.dyn_cast()) + { + return wrap_mlir(tensor_type.getElementType()); + } + return wrap_mlir(type); +} + py::object setup_py_var(py::object var) { py::setattr(var, "_shape", py::cpp_function(&shape_impl)); + py::setattr(var, "_dtype", py::cpp_function(&dtype_impl)); return var; } diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py index cbfd1a3e737..9ef08aa7f4c 100644 --- a/numba/mlir/linalg_builder.py +++ b/numba/mlir/linalg_builder.py @@ -4,8 +4,8 @@ class Var: def __init__(self, context, ssa_val, shape, dtype): self._context = context self._ssa_val = ssa_val - self._shape = shape - self._dtype = dtype + # self._shape = shape + # self._dtype = dtype @property def shape(self): @@ -13,7 +13,7 @@ def shape(self): @property def dtype(self): - return self._dtype + return self._dtype(self._context, self._ssa_val) From 437ac110c7ac7cbf10d867280d94a61071035efa Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 13:31:50 +0300 Subject: [PATCH 04/15] remove unused code --- .../mlir-compiler/src/py_linalg_resolver.cpp | 46 ++++++------------- numba/mlir/linalg_builder.py | 4 +- 2 files changed, 16 insertions(+), 34 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index b3aed3a0f21..af93baf3e92 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -114,7 +114,7 @@ void container_iterate(py::handle obj, F&& func) } } -py::object setup_py_var(py::object var); +void setup_py_var(py::handle var); } struct PyLinalgResolver::Context @@ -127,26 +127,11 @@ struct PyLinalgResolver::Context py::handle compile_func; py::handle lookup_func; - py::object create_var(mlir::Location loc, mlir::OpBuilder& builder, py::capsule context, mlir::Value value) + py::object create_var(py::capsule context, mlir::Value value) { - if (value.getType().isa()) - { - auto make_dim_val = [&](auto dim, auto ssa_val) - { - return val(get_dim(dim), wrap_mlir(ssa_val)); - }; - auto mlir_type = value.getType().cast(); - auto shape = mlir_type.getShape(); - auto elem_type = mlir_type.getElementType(); - py::list py_shape(shape.size()); - for (auto it2 : llvm::enumerate(shape)) - { - mlir::Value mlir_dim = builder.create(loc, value, it2.index()); - py_shape[it2.index()] = make_dim_val(it2.value(), mlir_dim); - } - return setup_py_var(var(context, wrap_mlir(value), py_shape, wrap_mlir(elem_type))); - } - return setup_py_var(var(context, wrap_mlir(value), py::list(), wrap_mlir(value.getType()))); + auto ret = var(context, wrap_mlir(value)); + setup_py_var(ret); + return ret; } mlir::FuncOp compile_body(py::handle body, py::list arg_types) @@ -158,7 +143,7 @@ struct PyLinalgResolver::Context return mlir_func; } - py::object wrap_result(mlir::Location loc, mlir::OpBuilder& builder, py::capsule context, mlir::ValueRange values) + py::object wrap_result(py::capsule context, mlir::ValueRange values) { if (values.empty()) { @@ -166,12 +151,12 @@ struct PyLinalgResolver::Context } if (values.size() == 1) { - return create_var(loc, builder, context, values.front()); + return create_var(context, values.front()); } py::tuple ret(values.size()); for (auto it : llvm::enumerate(values)) { - ret[it.index()] = create_var(loc, builder, context, it.value()); + ret[it.index()] = create_var(context, it.value()); } return std::move(ret); } @@ -298,7 +283,7 @@ py::object init_tensor_impl(py::capsule context, py::list shape, py::capsule dty { init = ctx.builder.create(ctx.loc, unwrap_shape(shape), elem_type); } - return ctx.context.create_var(ctx.loc, ctx.builder, context, init); + return ctx.context.create_var(context, init); } py::object generic_impl(py::capsule context, py::handle inputs, py::handle outputs, py::list iterators, py::list maps, py::handle body) @@ -339,7 +324,7 @@ py::object generic_impl(py::capsule context, py::handle inputs, py::handle outpu { inputs_args.append(output_args.begin(), output_args.end()); auto res = builder.create(loc, body_func, inputs_args); - return ctx.context.wrap_result(loc, builder, context, cast_values(res.getResults(), ret_types)); + return ctx.context.wrap_result(context, cast_values(res.getResults(), ret_types)); } else { @@ -361,7 +346,7 @@ py::object generic_impl(py::capsule context, py::handle inputs, py::handle outpu affine_maps, mlir_iterators, body_builder); - return ctx.context.wrap_result(loc, builder, context, generic_op.getResults()); + return ctx.context.wrap_result(context, generic_op.getResults()); } } @@ -402,7 +387,7 @@ py::object from_elements_impl(py::capsule context, py::handle values, py::capsul } }); auto res = builder.create(loc, vals); - return ctx.context.create_var(ctx.loc, ctx.builder, context, res); + return ctx.context.create_var(context, res); } py::object extract_impl(py::capsule context, py::handle value, py::handle indices) @@ -428,7 +413,7 @@ py::object extract_impl(py::capsule context, py::handle value, py::handle indice } }); auto res = builder.create(loc, get_var_value(value), ind); - return ctx.context.create_var(ctx.loc, ctx.builder, context, res); + return ctx.context.create_var(context, res); } void setup_py_builder(py::handle builder, mlir::OpBuilder& b) @@ -490,11 +475,10 @@ py::object dtype_impl(py::capsule /*context*/, py::capsule ssa_val) return wrap_mlir(type); } -py::object setup_py_var(py::object var) +void setup_py_var(pybind11::handle var) { py::setattr(var, "_shape", py::cpp_function(&shape_impl)); py::setattr(var, "_dtype", py::cpp_function(&dtype_impl)); - return var; } PyLinalgResolver::Values unpack_results(py::handle object) @@ -565,7 +549,7 @@ llvm::Optional PyLinalgResolver::rewrite(llvm::StringR { auto index = static_cast(it.index()); auto mlir_arg = it.value(); - py_args[index] = context->create_var(loc, builder, py_context, mlir_arg); + py_args[index] = context->create_var(py_context, mlir_arg); } auto result = builder_func(py_builder, *py_args); diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py index 9ef08aa7f4c..2a46e3812b6 100644 --- a/numba/mlir/linalg_builder.py +++ b/numba/mlir/linalg_builder.py @@ -1,11 +1,9 @@ from .func_registry import add_func class Var: - def __init__(self, context, ssa_val, shape, dtype): + def __init__(self, context, ssa_val): self._context = context self._ssa_val = ssa_val - # self._shape = shape - # self._dtype = dtype @property def shape(self): From 24fb1b5d411419c92bfa9bc6d57aaa0f2a7f9afc Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 14:24:33 +0300 Subject: [PATCH 05/15] refactor shape --- .../mlir-compiler/src/py_linalg_resolver.cpp | 52 +++++++++++++------ numba/mlir/linalg_builder.py | 3 ++ 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index af93baf3e92..8337f75c88b 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -66,13 +66,16 @@ auto unwrap_ssa_val(py::handle obj) return unwrap_mlir(obj.attr("_ssa_val").cast()); } -auto unwrap_shape(py::list shape) +auto unwrap_shape(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value shape) { - llvm::SmallVector ret; - ret.reserve(shape.size()); - for (auto elem : shape) + auto type = shape.getType().cast(); + llvm::SmallVector ret(type.size()); + for (auto it : llvm::enumerate(ret)) { - ret.push_back(unwrap_ssa_val(elem)); + auto elem_type = type.getType(it.index()); + auto index = builder.create(loc, it.index()); + auto dim = builder.create(loc, elem_type, shape, index); + it.value() = dim; } return ret; } @@ -267,12 +270,12 @@ py::object broadcast_impl(py::capsule /*context*/, py::tuple args) } } -py::object init_tensor_impl(py::capsule context, py::list shape, py::capsule dtype) +py::object init_tensor_impl(py::capsule context, py::handle shape, py::capsule dtype) { auto& ctx = get_py_context(context); auto elem_type = unwrap_mlir(dtype); mlir::Value init; - if (shape.empty()) + if (py::len(shape) == 0) { // TODO: undef auto zero_val = plier::getZeroVal(elem_type); @@ -281,7 +284,10 @@ py::object init_tensor_impl(py::capsule context, py::list shape, py::capsule dty } else { - init = ctx.builder.create(ctx.loc, unwrap_shape(shape), elem_type); + auto& builder = ctx.builder; + auto loc = ctx.loc; + auto shape_val = unwrap_shape(loc, builder, unwrap_ssa_val(shape)) ; + init = ctx.builder.create(ctx.loc, shape_val, elem_type); } return ctx.context.create_var(context, init); } @@ -447,19 +453,19 @@ py::object shape_impl(py::capsule context, py::capsule ssa_val) { auto& builder = ctx.builder; auto loc = ctx.loc; - auto make_dim_val = [&](auto dim, auto ssa_val) - { - return ctx.context.val(get_dim(dim), wrap_mlir(ssa_val)); - }; auto mlir_type = value.getType().cast(); auto shape = mlir_type.getShape(); - py::list py_shape(shape.size()); - for (auto it2 : llvm::enumerate(shape)) + llvm::SmallVector shape_vals(shape.size()); + for (auto it : llvm::enumerate(shape)) { - mlir::Value mlir_dim = builder.create(loc, value, it2.index()); - py_shape[it2.index()] = make_dim_val(it2.value(), mlir_dim); + auto i = it.index(); + mlir::Value mlir_dim = builder.create(loc, value, i); + shape_vals[i] = mlir_dim; } - return std::move(py_shape); + llvm::SmallVector shape_types(shape.size(), builder.getIndexType()); + auto shape_type = mlir::TupleType::get(builder.getContext(), shape_types); + auto shape_var = builder.create(loc, shape_type, shape_vals); + return ctx.context.create_var(context, shape_var.getResult()); } return py::list(); } @@ -475,10 +481,22 @@ py::object dtype_impl(py::capsule /*context*/, py::capsule ssa_val) return wrap_mlir(type); } +py::object len_impl(py::capsule /*context*/, py::capsule ssa_val) +{ + auto value = unwrap_mlir(ssa_val); + auto type = value.getType(); + if (auto tuple_type = type.dyn_cast()) + { + return py::int_(tuple_type.size()); + } + return py::int_(0); +} + void setup_py_var(pybind11::handle var) { py::setattr(var, "_shape", py::cpp_function(&shape_impl)); py::setattr(var, "_dtype", py::cpp_function(&dtype_impl)); + py::setattr(var, "_len", py::cpp_function(&len_impl)); } PyLinalgResolver::Values unpack_results(py::handle object) diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py index 2a46e3812b6..088c23aac44 100644 --- a/numba/mlir/linalg_builder.py +++ b/numba/mlir/linalg_builder.py @@ -13,6 +13,9 @@ def shape(self): def dtype(self): return self._dtype(self._context, self._ssa_val) + def __len__(self): + return self._len(self._context, self._ssa_val) + class Val: From ba50774373ad20024228dfdaf1678679d9e25f8f Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 14:26:41 +0300 Subject: [PATCH 06/15] remove unused code --- .../mlir-compiler/src/py_linalg_resolver.cpp | 11 ----------- numba/mlir/linalg_builder.py | 10 ---------- 2 files changed, 21 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 8337f75c88b..190d62c851e 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -35,15 +35,6 @@ bool is_compatible_types(mlir::TypeRange types) }); } -py::handle get_dim(int64_t val) -{ - if (val == -1) - { - return py::none(); - } - return py::int_(val); -} - size_t py_func_arg_count(py::handle signature, py::handle func) { return py::len(signature(func).attr("parameters")); @@ -123,7 +114,6 @@ void setup_py_var(py::handle var); struct PyLinalgResolver::Context { py::handle var; - py::handle val; py::handle builder; py::handle signature; py::handle types_mod; @@ -526,7 +516,6 @@ PyLinalgResolver::PyLinalgResolver(): { auto builder_mod = py::module::import("numba.mlir.linalg_builder"); context->var = builder_mod.attr("Var"); - context->val = builder_mod.attr("Val"); context->builder = builder_mod.attr("Builder"); context->signature = py::module::import("inspect").attr("signature"); context->types_mod = py::module::import("numba.core.types"); diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py index 088c23aac44..1dbcb295607 100644 --- a/numba/mlir/linalg_builder.py +++ b/numba/mlir/linalg_builder.py @@ -16,16 +16,6 @@ def dtype(self): def __len__(self): return self._len(self._context, self._ssa_val) - - -class Val: - def __init__(self, const_val, ssa_val): - self._const_val = const_val - self._ssa_val = ssa_val - - def is_const(self): - return not _const_val is None - class Builder: def __init__(self, context): self._context = context From 42a0964733c2cf99a761cfa1e75bde8561926898 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 17:01:23 +0300 Subject: [PATCH 07/15] fix setitem lowering --- mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 38d62bd412e..4fc62e58b5b 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -436,7 +436,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern mlir::OpBuilder::InsertionGuard g(rewriter); if (auto parent_op = target.getDefiningOp()) { - rewriter.setInsertionPoint(parent_op); + rewriter.setInsertionPointAfter(parent_op); } else { @@ -456,6 +456,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern } else { + mlir::OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(use_op); auto new_val = rewriter.create(use_op->getLoc(), memref); rewriter.updateRootInPlace(use_op, [&]() From 48b275870ae6ae7b17d675f932e1cea4ef1e0494 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 17:01:41 +0300 Subject: [PATCH 08/15] numpy empty --- .../mlir-compiler/src/py_linalg_resolver.cpp | 45 ++++++++++++++----- numba/mlir/numpy/funcs.py | 7 ++- numba/mlir/tests/test_numpy.py | 21 +++++++++ 3 files changed, 60 insertions(+), 13 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 190d62c851e..17f4890d6fc 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -27,12 +27,17 @@ struct PyBuilderContext namespace { -bool is_compatible_types(mlir::TypeRange types) +bool is_compatible_type(mlir::Type type) { - return !types.empty() && llvm::all_of(types, [](mlir::Type t) + if (auto tuple_type = type.dyn_cast()) { - return t.isIntOrFloat() || t.isa(); - }); + return llvm::all_of(tuple_type, &is_compatible_type); + } + return type.isIntOrFloat() || type.isa(); +} +bool is_compatible_types(mlir::TypeRange types) +{ + return !types.empty() && llvm::all_of(types, &is_compatible_type); } size_t py_func_arg_count(py::handle signature, py::handle func) @@ -59,14 +64,30 @@ auto unwrap_ssa_val(py::handle obj) auto unwrap_shape(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value shape) { - auto type = shape.getType().cast(); - llvm::SmallVector ret(type.size()); - for (auto it : llvm::enumerate(ret)) + llvm::SmallVector ret; + auto type = shape.getType(); + auto index_cast = [&](mlir::Value val)->mlir::Value + { + if (!val.getType().isa()) + { + return builder.create(loc, builder.getIndexType(), val); + } + return val; + }; + if (auto tuple_type = type.dyn_cast()) + { + ret.resize(tuple_type.size()); + for (auto it : llvm::enumerate(ret)) + { + auto elem_type = tuple_type.getType(it.index()); + auto index = builder.create(loc, it.index()); + auto dim = builder.create(loc, elem_type, shape, index); + it.value() = index_cast(dim); + } + } + else { - auto elem_type = type.getType(it.index()); - auto index = builder.create(loc, it.index()); - auto dim = builder.create(loc, elem_type, shape, index); - it.value() = dim; + ret.emplace_back(index_cast(shape)); } return ret; } @@ -479,7 +500,7 @@ py::object len_impl(py::capsule /*context*/, py::capsule ssa_val) { return py::int_(tuple_type.size()); } - return py::int_(0); + return py::int_(1); } void setup_py_var(pybind11::handle var) diff --git a/numba/mlir/numpy/funcs.py b/numba/mlir/numpy/funcs.py index f6ef37d15e4..0167cbc04e5 100644 --- a/numba/mlir/numpy/funcs.py +++ b/numba/mlir/numpy/funcs.py @@ -57,9 +57,14 @@ def body(a, b): return eltwise(builder, arg, body, builder.float64) @register_func('numpy.square', numpy.square) -def quare_impl(builder, arg): +def square_impl(builder, arg): def body(a, b): return a * a return eltwise(builder, arg, body) + +@register_func('numpy.empty', numpy.empty) +def empty_impl(builder, shape): + # TODO: dtype + return builder.init_tensor(shape, builder.float64) diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 1dbcb977317..bce281c07ae 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -209,5 +209,26 @@ def py_func(a, b): arr = np.array([0.0]) assert_equal(py_func(arr, 5), jit_func(arr, 5)) + def test_empty1(self): + def py_func(d): + a = np.empty(d) + for i in range(d): + a[i] = i + return a + + jit_func = njit(py_func) + assert_equal(py_func(5), jit_func(5)) + + def test_empty2(self): + def py_func(d1, d2): + a = np.empty((d1, d2)) + for i in range(d1): + for j in range(d2): + a[i, j] = i + j * 10 + return a + + jit_func = njit(py_func) + assert_equal(py_func(5, 7), jit_func(5, 7)) + if __name__ == '__main__': unittest.main() From b3a2105352d1a2e087ba80690d6e3e2dcdaf941a Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 17:01:41 +0300 Subject: [PATCH 09/15] numpy.sum --- numba/mlir/numpy/funcs.py | 1 + numba/mlir/tests/test_numpy.py | 1 + 2 files changed, 2 insertions(+) diff --git a/numba/mlir/numpy/funcs.py b/numba/mlir/numpy/funcs.py index 0167cbc04e5..42e682643be 100644 --- a/numba/mlir/numpy/funcs.py +++ b/numba/mlir/numpy/funcs.py @@ -31,6 +31,7 @@ def body(a, b, c): return eltwise(builder, (arg1, arg2), body) @register_func('array.sum') +@register_func('numpy.sum', numpy.sum) def sum_impl(builder, arg): shape = arg.shape diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index bce281c07ae..a286e68c72d 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -40,6 +40,7 @@ def py_func(a): def test_unary(self): funcs = [ lambda a: a.sum(), + lambda a: np.sum(a), lambda a: np.sqrt(a), lambda a: np.square(a), ] From 1bec4ad8b43a9855060d914ab9acde8d0a9ca7de Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 17:01:41 +0300 Subject: [PATCH 10/15] some kwargs support --- .../src/pipelines/plier_to_linalg.cpp | 17 ++++++---- .../src/pipelines/plier_to_std.cpp | 32 ++++++++++++++----- .../mlir-compiler/src/py_linalg_resolver.cpp | 2 +- .../mlir-compiler/src/py_linalg_resolver.hpp | 4 ++- .../include/plier/rewrites/call_lowering.hpp | 2 +- .../plier/src/rewrites/call_lowering.cpp | 23 +++++++------ 6 files changed, 50 insertions(+), 30 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 4fc62e58b5b..3ff37c6b90d 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -133,8 +133,12 @@ bool is_int(mlir::Type type) return type.isa(); } -mlir::LogicalResult lower_prange(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) +mlir::LogicalResult lower_prange(plier::PyCallOp op, llvm::ArrayRef operands, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { + if (!kwargs.empty()) + { + return mlir::failure(); + } if ((operands.size() < 1 || operands.size() > 3) || !llvm::all_of(operands, [](mlir::Value val) { return is_int(val.getType());})) { @@ -177,9 +181,9 @@ struct CallLowerer { mlir::LogicalResult operator()( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, - mlir::PatternRewriter& rewriter) + llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { - using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); + using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, llvm::ArrayRef>, mlir::PatternRewriter&); std::pair handlers[] = { {"numba.prange", lower_prange}, }; @@ -187,11 +191,11 @@ struct CallLowerer { if (handler.first == name) { - return handler.second(op, args, rewriter); + return handler.second(op, args, kwargs, rewriter); } } - if (auto result = linalg_resolver.rewrite(name, op.getLoc(), rewriter, args)) + if (auto result = linalg_resolver.rewrite(name, op.getLoc(), rewriter, args, kwargs)) { assert(result->size() == op->getNumResults()); rerun_std_pipeline(op); @@ -206,7 +210,7 @@ struct CallLowerer return mlir::success(); } - if (name == "len" && check_numpy_args(args, 1)) + if (name == "len" && check_numpy_args(args, 1) && kwargs.empty()) { auto loc = op.getLoc(); mlir::Value dim = rewriter.create(loc, args[0], 0); @@ -219,7 +223,6 @@ struct CallLowerer } private: - PyLinalgResolver linalg_resolver; }; diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp index 1718da5fc94..6f22320bd18 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1156,8 +1156,12 @@ struct FoldTupleGetitem : public mlir::OpRewritePattern } }; -mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) +mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef operands, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { + if (!kwargs.empty()) + { + return mlir::failure(); + } if ((operands.size() < 1 || operands.size() > 3) || !llvm::all_of(operands, [](mlir::Value val) { return is_int(val.getType());})) { @@ -1191,8 +1195,12 @@ mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef return mlir::success(); } -mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) +mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef operands, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { + if (!kwargs.empty()) + { + return mlir::failure(); + } if (operands.size() != 1) { return mlir::failure(); @@ -1210,8 +1218,12 @@ mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef op return mlir::success(); } -mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) +mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef operands, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { + if (!kwargs.empty()) + { + return mlir::failure(); + } if (operands.size() != 1) { return mlir::failure(); @@ -1250,8 +1262,12 @@ mlir::FuncOp get_lib_symbol( mlir::LogicalResult lower_math_func( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, - mlir::PatternRewriter& rewriter) + llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { + if (!kwargs.empty()) + { + return mlir::failure(); + } auto ret_type = map_plier_type(op.getType()); auto valid_type = [&](mlir::Type type) { @@ -1285,14 +1301,14 @@ mlir::LogicalResult lower_math_func( struct CallLowerer { mlir::LogicalResult operator()(plier::PyCallOp op, llvm::StringRef name, - llvm::ArrayRef args, mlir::PatternRewriter& rewriter) + llvm::ArrayRef args, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { - if (mlir::succeeded(lower_math_func(op, name, args, rewriter))) + if (mlir::succeeded(lower_math_func(op, name, args, kwargs, rewriter))) { return mlir::success(); } - using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); + using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, llvm::ArrayRef>, mlir::PatternRewriter&); std::pair handlers[] = { {"bool", lower_bool_cast}, {"range", lower_range}, @@ -1302,7 +1318,7 @@ struct CallLowerer { if (handler.first == name) { - return handler.second(op, args, rewriter); + return handler.second(op, args, kwargs, rewriter); } } diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 17f4890d6fc..090c6b5c154 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -549,7 +549,7 @@ PyLinalgResolver::~PyLinalgResolver() } -llvm::Optional PyLinalgResolver::rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args) +llvm::Optional PyLinalgResolver::rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, llvm::ArrayRef > kwargs) { assert(!name.empty()); if (!is_compatible_types(args.getTypes())) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp index 80ca93da3d0..66769084b50 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp @@ -4,6 +4,7 @@ #include #include +#include namespace llvm { @@ -27,7 +28,8 @@ class PyLinalgResolver using Values = llvm::SmallVector; - llvm::Optional rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args); + llvm::Optional rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, + llvm::ArrayRef> kwargs); private: friend struct PyBuilderContext; diff --git a/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp b/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp index c686004d5e1..48d58f9958d 100644 --- a/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp @@ -15,7 +15,7 @@ namespace plier { struct CallOpLowering : public mlir::OpRewritePattern { - using resolver_t = llvm::function_ref, mlir::PatternRewriter&)>; + using resolver_t = llvm::function_ref, llvm::ArrayRef> , mlir::PatternRewriter&)>; CallOpLowering(mlir::TypeConverter &typeConverter, mlir::MLIRContext *context, diff --git a/mlir-compiler/plier/src/rewrites/call_lowering.cpp b/mlir-compiler/plier/src/rewrites/call_lowering.cpp index 53e93c30d60..6de918fa317 100644 --- a/mlir-compiler/plier/src/rewrites/call_lowering.cpp +++ b/mlir-compiler/plier/src/rewrites/call_lowering.cpp @@ -18,23 +18,22 @@ mlir::LogicalResult plier::CallOpLowering::matchAndRewrite(plier::PyCallOp op, m return mlir::failure(); } - llvm::SmallVector arg_types; llvm::SmallVector args; + llvm::SmallVector, 8> kwargs; auto getattr = mlir::dyn_cast_or_null(operands[0].getDefiningOp()); - if (!getattr) + if (getattr) { - llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); - llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); - // TODO kwargs + args.push_back(getattr.getOperand()); } - else + auto kw_start = op.kw_start(); + operands = operands.drop_front(); + llvm::copy(operands.take_front(kw_start), std::back_inserter(args)); + for (auto it : llvm::zip(operands.drop_front(kw_start), op.kw_names())) { - arg_types.push_back(getattr.getOperand().getType()); - args.push_back(getattr.getOperand()); - llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); - llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); - // TODO kwargs + auto arg = std::get<0>(it); + auto name = std::get<1>(it).cast(); + kwargs.emplace_back(name.getValue(), arg); } - return resolver(op, op.func_name(), args, rewriter); + return resolver(op, op.func_name(), args, kwargs, rewriter); } From 571227846a5415b66361ed8d2692894e8ae77472 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 17:01:41 +0300 Subject: [PATCH 11/15] linlag resolver kwargs support --- .../mlir-compiler/src/py_linalg_resolver.cpp | 96 ++++++++++++++----- 1 file changed, 73 insertions(+), 23 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 090c6b5c154..0b94bea1acd 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -35,14 +35,11 @@ bool is_compatible_type(mlir::Type type) } return type.isIntOrFloat() || type.isa(); } -bool is_compatible_types(mlir::TypeRange types) -{ - return !types.empty() && llvm::all_of(types, &is_compatible_type); -} -size_t py_func_arg_count(py::handle signature, py::handle func) +template +bool is_compatible_types(R&& vals) { - return py::len(signature(func).attr("parameters")); + return llvm::all_of(vals, [](auto val) { return is_compatible_type(val.getType()); }); } template @@ -136,7 +133,7 @@ struct PyLinalgResolver::Context { py::handle var; py::handle builder; - py::handle signature; + py::handle inspect; py::handle types_mod; py::handle compile_func; py::handle lookup_func; @@ -178,6 +175,60 @@ struct PyLinalgResolver::Context namespace { +py::list get_args(py::handle inspect, py::handle func, llvm::function_ref create_var, + mlir::ValueRange args, llvm::ArrayRef> kwargs) +{ + auto sig_func = inspect.attr("signature"); + auto sig = sig_func(func); + auto params = sig.attr("parameters"); + auto params_list = py::list(params); + params_list = params_list[py::slice(1, static_cast(params_list.size()), 1)]; // skip builder param + auto empty = inspect.attr("Parameter").attr("empty"); + + py::list ret(py::len(params_list)); + for (auto it : llvm::enumerate(params_list)) + { + auto index = it.index(); + auto param_name = it.value(); + auto param = params[param_name]; + if (!args.empty()) + { + ret[index] = create_var(args.front()); + args = args.drop_front(); + continue; + } + if (!kwargs.empty()) + { + auto name = param_name.cast(); + auto val = [&]()->mlir::Value + { + for (auto kwarg : kwargs) + { + if (kwarg.first == name) + { + return kwarg.second; + } + } + return {}; + }(); + if (val) + { + ret[index] = create_var(val); + continue; + } + } + auto def_val = param.attr("default"); + if (!def_val.is(empty)) + { + ret[index] = def_val; + } + else + { + return py::none(); + } + } + return ret; +} PyBuilderContext& get_py_context(py::capsule& ctx) { @@ -538,7 +589,7 @@ PyLinalgResolver::PyLinalgResolver(): auto builder_mod = py::module::import("numba.mlir.linalg_builder"); context->var = builder_mod.attr("Var"); context->builder = builder_mod.attr("Builder"); - context->signature = py::module::import("inspect").attr("signature"); + context->inspect = py::module::import("inspect"); context->types_mod = py::module::import("numba.core.types"); context->compile_func = builder_mod.attr("compile_func"); context->lookup_func = builder_mod.attr("lookup_func"); @@ -549,36 +600,35 @@ PyLinalgResolver::~PyLinalgResolver() } -llvm::Optional PyLinalgResolver::rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, llvm::ArrayRef > kwargs) +llvm::Optional PyLinalgResolver::rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, llvm::ArrayRef> kwargs) { assert(!name.empty()); - if (!is_compatible_types(args.getTypes())) + if (!is_compatible_types(args) || + !is_compatible_types(llvm::make_second_range(kwargs))) { return {}; } auto builder_func = context->lookup_func(py::str(name.data(), name.size())); - if (builder_func.is_none() || py_func_arg_count(context->signature, builder_func) != (args.size() + 1)) + if (builder_func.is_none()) { return {}; } PyBuilderContext py_builder_context{loc, builder, *context}; auto py_context = py::capsule(&py_builder_context); - auto py_builder = context->builder(py_context); - setup_py_builder(py_builder, builder); - - assert(!args.empty()); - auto module = args.front().getParentRegion()->getParentOfType(); - assert(module); - - py::list py_args(args.size()); - for (auto it : llvm::enumerate(args)) + auto py_args = get_args( + context->inspect, + builder_func, + [&](auto val){ return context->create_var(py_context, val);}, + args, + kwargs); + if (py_args.is_none()) { - auto index = static_cast(it.index()); - auto mlir_arg = it.value(); - py_args[index] = context->create_var(py_context, mlir_arg); + return {}; } + auto py_builder = context->builder(py_context); + setup_py_builder(py_builder, builder); auto result = builder_func(py_builder, *py_args); return unpack_results(result); From 75dfc8b0cfad174c5ee9247072ce342e81edd3a9 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 17:01:41 +0300 Subject: [PATCH 12/15] linalg resolver some literal support --- .../mlir-compiler/src/py_linalg_resolver.cpp | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 0b94bea1acd..71995dc1d73 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -126,6 +126,19 @@ void container_iterate(py::handle obj, F&& func) } } +llvm::Optional make_py_literal(mlir::Value val) +{ + if (auto int_val = plier::getConstVal(val)) + { + return py::int_(int_val.getInt()); + } + if (auto float_val = plier::getConstVal(val)) + { + return py::float_(float_val.getValueAsDouble()); + } + return {}; +} + void setup_py_var(py::handle var); } @@ -140,6 +153,10 @@ struct PyLinalgResolver::Context py::object create_var(py::capsule context, mlir::Value value) { + if (auto literal = make_py_literal(value)) + { + return *literal; + } auto ret = var(context, wrap_mlir(value)); setup_py_var(ret); return ret; From 8557f0fb74de07ea3df87bc6251f47e6d6976048 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 17:01:41 +0300 Subject: [PATCH 13/15] work on linalg resolver --- .../mlir-compiler/src/py_linalg_resolver.cpp | 45 +++++++++++++++++-- numba/mlir/linalg_builder.py | 6 +++ 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 71995dc1d73..2bcd2d9f514 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -354,7 +354,8 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::capsule d auto& ctx = get_py_context(context); auto elem_type = unwrap_mlir(dtype); mlir::Value init; - if (py::len(shape) == 0) + auto count = py::len(shape); + if (count == 0) { // TODO: undef auto zero_val = plier::getZeroVal(elem_type); @@ -363,9 +364,11 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::capsule d } else { - auto& builder = ctx.builder; - auto loc = ctx.loc; - auto shape_val = unwrap_shape(loc, builder, unwrap_ssa_val(shape)) ; + llvm::SmallVector shape_val(count); + for (size_t i = 0; i < count; ++i) + { + shape_val[i] = unwrap_ssa_val(shape[py::int_(i)]); + } init = ctx.builder.create(ctx.loc, shape_val, elem_type); } return ctx.context.create_var(context, init); @@ -571,11 +574,41 @@ py::object len_impl(py::capsule /*context*/, py::capsule ssa_val) return py::int_(1); } +py::object getitem_impl(py::capsule context, py::capsule ssa_val, py::handle index) +{ + auto& ctx = get_py_context(context); + auto value = unwrap_mlir(ssa_val); + auto& builder = ctx.builder; + auto loc = ctx.loc; + auto index_val = index.cast(); + auto type = value.getType(); + if (auto tuple_type = type.dyn_cast()) + { + if (index_val < 0 || index_val >= static_cast(tuple_type.size())) + { + plier::report_error("Invelid getitem index"); + } + auto elem_type = tuple_type.getType(static_cast(index_val)); + auto ind = builder.create(loc, index_val); + auto item = builder.create(loc, elem_type, value, ind); + return ctx.context.create_var(context, item.getResult()); + } + else + { + if (0 != index_val) + { + plier::report_error("Invelid getitem index"); + } + return ctx.context.create_var(context, value); + } +} + void setup_py_var(pybind11::handle var) { py::setattr(var, "_shape", py::cpp_function(&shape_impl)); py::setattr(var, "_dtype", py::cpp_function(&dtype_impl)); py::setattr(var, "_len", py::cpp_function(&len_impl)); + py::setattr(var, "_getitem", py::cpp_function(&getitem_impl)); } PyLinalgResolver::Values unpack_results(py::handle object) @@ -648,5 +681,9 @@ llvm::Optional PyLinalgResolver::rewrite(llvm::StringR setup_py_builder(py_builder, builder); auto result = builder_func(py_builder, *py_args); + if (result.is_none()) + { + return {}; + } return unpack_results(result); } diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py index 1dbcb295607..98a861000f3 100644 --- a/numba/mlir/linalg_builder.py +++ b/numba/mlir/linalg_builder.py @@ -16,6 +16,12 @@ def dtype(self): def __len__(self): return self._len(self._context, self._ssa_val) + def __getitem__(self, index): + return self._getitem(self._context, self._ssa_val, index) + +def is_literal(val): + return not isinstance(val, Var) + class Builder: def __init__(self, context): self._context = context From a1bb4704e700e2394fb95650210188986d4b9d8d Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 17:01:41 +0300 Subject: [PATCH 14/15] add symbolDCE pass --- mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 3ff37c6b90d..6571258dd72 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -690,6 +690,7 @@ void PostLinalgOptPass::runOnOperation() void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); + pm.addPass(mlir::createSymbolDCEPass()); } void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) @@ -712,6 +713,7 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) pm.addPass(std::make_unique()); pm.addPass(std::make_unique()); + pm.addPass(mlir::createSymbolDCEPass()); } } From 1c25bf6d25b3b07212985a5b15908dca5532d05d Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 20 Feb 2021 17:01:41 +0300 Subject: [PATCH 15/15] numpy sum axis support --- .../src/pipelines/plier_to_linalg.cpp | 1 + .../mlir-compiler/src/py_linalg_resolver.cpp | 159 ++++++++++++------ numba/mlir/linalg_builder.py | 7 +- numba/mlir/numpy/funcs.py | 51 ++++-- numba/mlir/tests/test_numpy.py | 11 ++ 5 files changed, 163 insertions(+), 66 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 6571258dd72..9fec83eef42 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -606,6 +606,7 @@ struct LowerLinalgPass : mlir::DialectRegistry ®istry) const override { registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 2bcd2d9f514..ec2ed44bd9a 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -59,36 +59,6 @@ auto unwrap_ssa_val(py::handle obj) return unwrap_mlir(obj.attr("_ssa_val").cast()); } -auto unwrap_shape(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value shape) -{ - llvm::SmallVector ret; - auto type = shape.getType(); - auto index_cast = [&](mlir::Value val)->mlir::Value - { - if (!val.getType().isa()) - { - return builder.create(loc, builder.getIndexType(), val); - } - return val; - }; - if (auto tuple_type = type.dyn_cast()) - { - ret.resize(tuple_type.size()); - for (auto it : llvm::enumerate(ret)) - { - auto elem_type = tuple_type.getType(it.index()); - auto index = builder.create(loc, it.index()); - auto dim = builder.create(loc, elem_type, shape, index); - it.value() = index_cast(dim); - } - } - else - { - ret.emplace_back(index_cast(shape)); - } - return ret; -} - size_t container_size(py::handle obj) { if (py::isinstance(obj)) @@ -139,6 +109,15 @@ llvm::Optional make_py_literal(mlir::Value val) return {}; } +mlir::Value do_cast(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value val, mlir::Type type) +{ + if (val.getType() != type) + { + return builder.create(loc, type, val); + } + return val; +} + void setup_py_var(py::handle var); } @@ -188,6 +167,25 @@ struct PyLinalgResolver::Context } return std::move(ret); } + + mlir::Value unwrap_val(mlir::Location loc, mlir::OpBuilder& builder, py::handle obj) + { + if (py::isinstance(obj, var)) + { + return unwrap_ssa_val(obj); + } + if (py::isinstance(obj)) + { + auto attr = builder.getI64IntegerAttr(obj.cast()); + return builder.create(loc, attr); + } + if (py::isinstance(obj)) + { + auto attr = builder.getF64FloatAttr(obj.cast()); + return builder.create(loc, attr); + } + plier::report_error("Invalid element type"); + } }; namespace @@ -252,31 +250,30 @@ PyBuilderContext& get_py_context(py::capsule& ctx) return *static_cast(ctx); } -mlir::Value get_var_value(py::handle var) -{ - return unwrap_mlir(var.attr("_ssa_val").cast()); -} - auto get_types(mlir::ValueRange values) { return values.getTypes(); } -auto get_agrs_from_tuple(py::handle args) +auto get_agrs_from_tuple(py::handle args, llvm::function_ref unpack) { llvm::SmallVector ret; + if (args.is_none()) + { + return ret; + } if (py::isinstance(args)) { auto tuple = args.cast(); ret.resize(tuple.size()); for (auto it : llvm::enumerate(tuple)) { - ret[it.index()] = get_var_value(it.value()); + ret[it.index()] = unpack(it.value()); } } else { - ret.emplace_back(get_var_value(args)); + ret.emplace_back(unpack(args)); } return ret; } @@ -349,31 +346,91 @@ py::object broadcast_impl(py::capsule /*context*/, py::tuple args) } } -py::object init_tensor_impl(py::capsule context, py::handle shape, py::capsule dtype) +py::object init_tensor_impl(py::capsule context, py::handle shape, py::capsule dtype, py::handle init_val) { auto& ctx = get_py_context(context); + auto loc = ctx.loc; + auto& builder = ctx.builder; auto elem_type = unwrap_mlir(dtype); mlir::Value init; auto count = py::len(shape); if (count == 0) { - // TODO: undef - auto zero_val = plier::getZeroVal(elem_type); - assert(zero_val); - init = ctx.builder.create(ctx.loc, zero_val); + if (init_val.is_none()) + { + // TODO: undef + auto zero_val = plier::getZeroVal(elem_type); + assert(zero_val); + init = builder.create(loc, zero_val); + } + else + { + init = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, init_val), elem_type); + } } else { + auto index_type = builder.getIndexType(); llvm::SmallVector shape_val(count); for (size_t i = 0; i < count; ++i) { - shape_val[i] = unwrap_ssa_val(shape[py::int_(i)]); + shape_val[i] = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, shape[py::int_(i)]), index_type); + } + + if (init_val.is_none()) + { + init = builder.create(loc, shape_val, elem_type); + } + else + { + auto val = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, init_val), elem_type); + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange /*indices*/) + { + builder.create(loc, val); + }; + llvm::SmallVector shape(count, -1); + auto type = mlir::RankedTensorType::get(shape, elem_type); + init = builder.create(loc, type, shape_val, body); } - init = ctx.builder.create(ctx.loc, shape_val, elem_type); } return ctx.context.create_var(context, init); } +py::object fill_tensor_impl(py::capsule context, py::handle tensor, py::handle value) +{ + auto& ctx = get_py_context(context); + auto loc = ctx.loc; + auto& builder = ctx.builder; + auto tensor_val = ctx.context.unwrap_val(loc, builder, tensor); + auto tensor_type = tensor_val.getType().cast(); + auto init_val = ctx.context.unwrap_val(loc, builder, value); + if (init_val.getType() != tensor_type.getElementType()) + { + init_val = builder.create(loc, tensor_type.getElementType(), init_val); + } + +// auto val = builder.create(loc, tensor_type, tensor_val, init_val); + auto rank = static_cast(tensor_type.getRank()); + mlir::AffineMap affine_maps[] = { + mlir::AffineMap::getMultiDimIdentityMap(rank, builder.getContext()), + }; + llvm::SmallVector iterators(rank, "parallel"); + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) + { + assert(values.size() == 1); + builder.create(loc, init_val); + }; + auto val = builder.create( + loc, + tensor_type, + llvm::None, + tensor_val, + affine_maps, + iterators, + body); + return ctx.context.create_var(context, val.getResult(0)); +} + py::object generic_impl(py::capsule context, py::handle inputs, py::handle outputs, py::list iterators, py::list maps, py::handle body) { auto& ctx = get_py_context(context); @@ -381,8 +438,13 @@ py::object generic_impl(py::capsule context, py::handle inputs, py::handle outpu auto& builder = ctx.builder; auto& mlir_context = *builder.getContext(); - auto inputs_args = get_agrs_from_tuple(inputs); - auto output_args = get_agrs_from_tuple(outputs); + auto unpack = [&](py::handle obj)->mlir::Value + { + return ctx.context.unwrap_val(loc, builder, obj); + }; + + auto inputs_args = get_agrs_from_tuple(inputs, unpack); + auto output_args = get_agrs_from_tuple(outputs, unpack); auto ret_types = get_types(output_args); auto mlir_iterators = get_iterators(iterators, mlir_context); @@ -500,7 +562,7 @@ py::object extract_impl(py::capsule context, py::handle value, py::handle indice plier::report_error("Invalid element type"); } }); - auto res = builder.create(loc, get_var_value(value), ind); + auto res = builder.create(loc, ctx.context.unwrap_val(loc, builder, value), ind); return ctx.context.create_var(context, res); } @@ -508,6 +570,7 @@ void setup_py_builder(py::handle builder, mlir::OpBuilder& b) { py::setattr(builder, "_broadcast", py::cpp_function(&broadcast_impl)); py::setattr(builder, "_init_tensor", py::cpp_function(&init_tensor_impl)); + py::setattr(builder, "_fill_tensor", py::cpp_function(&fill_tensor_impl)); py::setattr(builder, "_generic", py::cpp_function(&generic_impl)); py::setattr(builder, "_from_elements", py::cpp_function(&from_elements_impl)); py::setattr(builder, "_extract", py::cpp_function(&extract_impl)); diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py index 98a861000f3..7167a7211de 100644 --- a/numba/mlir/linalg_builder.py +++ b/numba/mlir/linalg_builder.py @@ -29,8 +29,11 @@ def __init__(self, context): def broadcast(self, *args): return self._broadcast(self._context, args) - def init_tensor(self, shape, dtype): - return self._init_tensor(self._context, shape, dtype) + def init_tensor(self, shape, dtype, init_val=None): + return self._init_tensor(self._context, shape, dtype, init_val) + + def fill_tensor(self, tensor, value): + return self._fill_tensor(self._context, tensor, value) def generic(self, inputs, outputs, iterators, maps, body): return self._generic(self._context, inputs, outputs, iterators, maps, body) diff --git a/numba/mlir/numpy/funcs.py b/numba/mlir/numpy/funcs.py index 42e682643be..24050d43437 100644 --- a/numba/mlir/numpy/funcs.py +++ b/numba/mlir/numpy/funcs.py @@ -1,4 +1,4 @@ -from ..linalg_builder import register_func +from ..linalg_builder import register_func, is_literal import numpy import math @@ -32,22 +32,41 @@ def body(a, b, c): @register_func('array.sum') @register_func('numpy.sum', numpy.sum) -def sum_impl(builder, arg): - shape = arg.shape +def sum_impl(builder, arg, axis=None): + if axis is None: + shape = arg.shape + num_dims = len(shape) + iterators = ['reduction' for _ in range(num_dims)] + dims = ','.join(['d%s' % i for i in range(num_dims)]) + expr1 = f'({dims}) -> ({dims})' + expr2 = f'({dims}) -> (0)' + maps = [expr1,expr2] + init = builder.from_elements(0, arg.dtype) + + def body(a, b): + return a + b + + res = builder.generic(arg, init, iterators, maps, body) + return builder.extract(res, 0) + elif isinstance(axis, int): + shape = arg.shape + num_dims = len(shape) + iterators = [('reduction' if i == axis else 'parallel') for i in range(num_dims)] + dims1 = ','.join(['d%s' % i for i in range(num_dims)]) + dims2 = ','.join(['d%s' % i for i in range(num_dims) if i != axis]) + expr1 = f'({dims1}) -> ({dims1})' + expr2 = f'({dims1}) -> ({dims2})' + maps = [expr1,expr2] + res_shape = tuple(shape[i] for i in range(len(shape)) if i != axis) + + init = builder.init_tensor(res_shape, builder.int64, 0) #TODO: type + # val = builder.fill_tensor(init, 0) + + def body(a, b): + return a + b + + return builder.generic(arg, init, iterators, maps, body) - num_dims = len(shape) - iterators = ['reduction' for _ in range(num_dims)] - dims = ','.join(['d%s' % i for i in range(num_dims)]) - expr1 = f'({dims}) -> ({dims})' - expr2 = f'({dims}) -> (0)' - maps = [expr1,expr2] - init = builder.from_elements(0, arg.dtype) - - def body(a, b): - return a + b - - res = builder.generic(arg, init, iterators, maps, body) - return builder.extract(res, 0) @register_func('numpy.sqrt', numpy.sqrt) def sqrt_impl(builder, arg): diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index a286e68c72d..405d5cdc019 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -51,6 +51,17 @@ def test_unary(self): arr = np.array(a) assert_equal(py_func(arr), jit_func(arr)) + def test_sum_axis(self): + funcs = [ + lambda a: np.sum(a, axis=0), + lambda a: np.sum(a, axis=1), + ] + + for py_func in funcs: + jit_func = njit(py_func) + arr = np.array([[1,2,3],[4,5,6]]) + assert_equal(py_func(arr), jit_func(arr)) + def test_add(self): def py_func(a, b): return np.add(a, b)