This repository was archived by the owner on Jan 25, 2023. It is now read-only.

[MLIR] Numpy dot, sum with axis, array.size, array.T #188

Merged 9 commits on Feb 22, 2021
109 changes: 97 additions & 12 deletions mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp
@@ -43,9 +43,25 @@

namespace
{
-bool parse_layout(llvm::StringRef& name)
+enum class ArrayLayout
+{
+    C,
+    F
+};
+
+bool parse_layout(llvm::StringRef& name, ArrayLayout& layout)
 {
-    return name.consume_back("C"); // TODO
+    if (name.consume_back("C"))
+    {
+        layout = ArrayLayout::C;
+        return true;
+    }
+    if (name.consume_back("F"))
+    {
+        layout = ArrayLayout::F;
+        return true;
+    }
+    return false;
 }

template<typename T>
@@ -67,24 +83,43 @@ bool consume_int_back(llvm::StringRef& name, T& result)
return false;
}

-mlir::Type map_array_type(mlir::MLIRContext& ctx, mlir::TypeConverter& conveter,
-                          llvm::StringRef& name)
+struct ArrayDesc
+{
+    unsigned dims = 0;
+    ArrayLayout layout = {};
+    llvm::StringRef name;
+};
+
+llvm::Optional<ArrayDesc> parse_array_desc(llvm::StringRef& name)
 {
     unsigned num_dims = 0;
+    ArrayLayout layout = {};
     if (name.consume_front("array(") &&
         name.consume_back(")") &&
-        parse_layout(name) &&
+        parse_layout(name, layout) &&
         name.consume_back(", ") &&
         name.consume_back("d") &&
         consume_int_back(name, num_dims) &&
         name.consume_back(", ") &&
         !name.empty())
     {
-        if (auto type = conveter.convertType(plier::PyType::get(&ctx, name)))
-        {
-            llvm::SmallVector<int64_t, 8> shape(num_dims, -1);
-            // return mlir::MemRefType::get(shape, type);
-            return mlir::RankedTensorType::get(shape, type);
-        }
+        return ArrayDesc{num_dims, layout, name};
     }
-    return nullptr;
+    return {};
+}
+
+mlir::Type map_array_type(mlir::MLIRContext& ctx, mlir::TypeConverter& conveter,
+                          llvm::StringRef& name)
+{
+    if (auto desc = parse_array_desc(name))
+    {
+        if (desc->layout == ArrayLayout::C)
+        {
+            if (auto type = conveter.convertType(plier::PyType::get(&ctx, desc->name)))
+            {
+                llvm::SmallVector<int64_t, 8> shape(desc->dims, -1);
+                return mlir::RankedTensorType::get(shape, type);
+            }
+        }
+    }
+    return nullptr;
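For orientation, a short sketch of what the new helpers accept and produce. The input string follows Numba's array-type spelling; the sketch assumes it lives in the same translation unit as the helpers above, and the header names match MLIR of this period:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

// Illustrative walk-through, not part of the PR.
mlir::Type map_2d_c_array_sketch(mlir::MLIRContext& ctx)
{
    llvm::StringRef name = "array(float64, 2d, C)";
    auto desc = parse_array_desc(name); // consumes the string as it parses
    // Now: desc->dims == 2, desc->layout == ArrayLayout::C,
    // and desc->name is the remaining element-type spelling, "float64".
    // map_array_type fills every dimension with -1 (dynamic), so the result
    // prints as tensor<?x?xf64>.
    llvm::SmallVector<int64_t, 8> shape(desc->dims, -1);
    return mlir::RankedTensorType::get(shape, mlir::FloatType::getF64(&ctx));
}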
@@ -181,7 +216,8 @@ struct CallLowerer
{
mlir::LogicalResult operator()(
plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef<mlir::Value> args,
-llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs, mlir::PatternRewriter& rewriter)
+llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs,
+mlir::PatternRewriter& rewriter)
{
using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef<mlir::Value>, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>>, mlir::PatternRewriter&);
std::pair<llvm::StringRef, func_t> handlers[] = {
@@ -195,7 +231,7 @@ struct CallLowerer
}
}

-if (auto result = linalg_resolver.rewrite(name, op.getLoc(), rewriter, args, kwargs))
+if (auto result = linalg_resolver.rewrite_func(name, op.getLoc(), rewriter, args, kwargs))
{
assert(result->size() == op->getNumResults());
rerun_std_pipeline(op);
@@ -222,6 +258,32 @@ struct CallLowerer
return mlir::failure();
}

mlir::LogicalResult operator()(
plier::GetattrOp op, llvm::StringRef name, mlir::Value arg,
mlir::PatternRewriter& rewriter)
{
if (!arg.getType().isa<mlir::ShapedType>())
{
return mlir::failure();
}
auto full_name = (llvm::Twine("array.") + name).str();
if (auto result = linalg_resolver.rewrite_attr(full_name, op.getLoc(), rewriter, arg))
{
assert(result->size() == op->getNumResults());
rerun_std_pipeline(op);
if (result->empty())
{
rewriter.eraseOp(op);
}
else
{
rewriter.replaceOp(op, *result);
}
return mlir::success();
}
return mlir::failure();
}

private:
PyLinalgResolver linalg_resolver;
};
@@ -550,6 +612,28 @@ struct ArrayShape : public mlir::OpRewritePattern<plier::GetattrOp>
mlir::TypeConverter& converter;
};

struct GetattrRewriter : public mlir::OpRewritePattern<plier::GetattrOp>
{
using resolver_t = llvm::function_ref<mlir::LogicalResult(plier::GetattrOp, llvm::StringRef, mlir::Value,
mlir::PatternRewriter&)>;

GetattrRewriter(mlir::TypeConverter &/*typeConverter*/,
mlir::MLIRContext *context,
resolver_t resolver):
OpRewritePattern(context),
resolver(resolver)
{}

mlir::LogicalResult matchAndRewrite(
plier::GetattrOp op, mlir::PatternRewriter &rewriter) const override
{
return resolver(op, op.name(), op.value(), rewriter);
}

private:
resolver_t resolver;
};


void PlierToLinalgPass::runOnOperation()
{
@@ -579,7 +663,8 @@ void PlierToLinalgPass::runOnOperation()
CallLowerer callLowerer;

patterns.insert<
-plier::CallOpLowering
+plier::CallOpLowering,
+GetattrRewriter
>(type_converter, context, callLowerer);

patterns.insert<
101 changes: 90 additions & 11 deletions mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp
@@ -59,6 +59,11 @@ auto unwrap_ssa_val(py::handle obj)
return unwrap_mlir<mlir::Value>(obj.attr("_ssa_val").cast<py::capsule>());
}

auto unwrap_type(py::handle obj)
{
return unwrap_mlir<mlir::Type>(obj.attr("_mlir_type").cast<py::capsule>());
}

size_t container_size(py::handle obj)
{
if (py::isinstance<py::tuple>(obj))
@@ -118,12 +123,18 @@ mlir::Value do_cast(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value val
return val;
}

bool cmp_capsule(py::capsule a1, py::capsule a2)
{
return static_cast<void*>(a1) == static_cast<void*>(a2);
}
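A note on why raw pointer equality suffices here: MLIR types are uniqued within a context, so two equal types share one storage object, and the capsules wrap that storage pointer. This is also what lets Python-side builder code compare a value's dtype against builder.int64 and friends with plain ==. A minimal sketch of the invariant (illustrative, assuming MLIR headers of this period):

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>

void type_uniquing_sketch()
{
    mlir::MLIRContext ctx;
    mlir::OpBuilder b(&ctx);
    // mlir::Type compares by its uniqued storage pointer, so comparing the
    // capsule payloads (as cmp_capsule does) matches type equality exactly.
    assert(b.getIntegerType(32) == b.getIntegerType(32));
    assert(b.getIntegerType(32) != b.getIntegerType(64));
}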

void setup_py_var(py::handle var);
}

struct PyLinalgResolver::Context
{
py::handle var;
py::handle type;
py::handle builder;
py::handle inspect;
py::handle types_mod;
@@ -141,6 +152,11 @@ struct PyLinalgResolver::Context
return ret;
}

py::object create_type(mlir::Type t)
{
return type(wrap_mlir(t), py::cpp_function(&cmp_capsule));
}

mlir::FuncOp compile_body(py::handle body, py::list arg_types)
{
auto func = compile_func(body, arg_types).cast<py::capsule>();
@@ -346,12 +362,12 @@ py::object broadcast_impl(py::capsule /*context*/, py::tuple args)
}
}

-py::object init_tensor_impl(py::capsule context, py::handle shape, py::capsule dtype, py::handle init_val)
+py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dtype, py::handle init_val)
{
auto& ctx = get_py_context(context);
auto loc = ctx.loc;
auto& builder = ctx.builder;
-auto elem_type = unwrap_mlir<mlir::Type>(dtype);
+auto elem_type = unwrap_type(dtype);
mlir::Value init;
auto count = py::len(shape);
if (count == 0)
@@ -500,12 +516,12 @@ py::object generic_impl(py::capsule context, py::handle inputs, py::handle outpu
}
}

-py::object from_elements_impl(py::capsule context, py::handle values, py::capsule dtype)
+py::object from_elements_impl(py::capsule context, py::handle values, py::handle dtype)
{
auto& ctx = get_py_context(context);
auto& builder = ctx.builder;
auto loc = ctx.loc;
-auto type = unwrap_mlir<mlir::Type>(dtype);
+auto type = unwrap_type(dtype);

llvm::SmallVector<mlir::Value, 8> vals(container_size(values));
container_iterate(values, [&](auto index, py::handle obj)
@@ -566,7 +582,7 @@ py::object extract_impl(py::capsule context, py::handle value, py::handle indice
return ctx.context.create_var(context, res);
}

-void setup_py_builder(py::handle builder, mlir::OpBuilder& b)
+void setup_py_builder(py::handle builder, mlir::OpBuilder& b, llvm::function_ref<py::object(mlir::Type)> create_type)
{
py::setattr(builder, "_broadcast", py::cpp_function(&broadcast_impl));
py::setattr(builder, "_init_tensor", py::cpp_function(&init_tensor_impl));
@@ -577,13 +593,14 @@ void setup_py_builder(py::handle builder, mlir::OpBuilder& b)

auto add_type = [&](const char* name, mlir::Type type)
{
-py::setattr(builder, name, wrap_mlir(type));
+py::setattr(builder, name, create_type(type));
};

add_type("int8", b.getIntegerType(8));
add_type("int16", b.getIntegerType(16));
add_type("int32", b.getIntegerType(32));
add_type("int64", b.getIntegerType(64));
add_type("index", b.getIndexType());

add_type("float16", b.getF16Type());
add_type("float32", b.getF32Type());
@@ -615,15 +632,16 @@ py::object shape_impl(py::capsule context, py::capsule ssa_val)
return py::list();
}

-py::object dtype_impl(py::capsule /*context*/, py::capsule ssa_val)
+py::object dtype_impl(py::capsule context, py::capsule ssa_val)
{
auto& ctx = get_py_context(context);
auto value = unwrap_mlir<mlir::Value>(ssa_val);
auto type = value.getType();
if (auto tensor_type = type.dyn_cast<mlir::RankedTensorType>())
{
-return wrap_mlir(tensor_type.getElementType());
+return ctx.context.create_type(tensor_type.getElementType());
}
-return wrap_mlir(type);
+return ctx.context.create_type(type);
}

py::object len_impl(py::capsule /*context*/, py::capsule ssa_val)
@@ -666,12 +684,61 @@ py::object getitem_impl(py::capsule context, py::capsule ssa_val, py::handle ind
}
}

template<typename Op>
mlir::Value binop_func(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value lhs, mlir::Value rhs)
{
return builder.create<Op>(loc, lhs, rhs);
}

py::object binop_impl(py::capsule context, py::capsule ssa_val, py::handle rhs, py::str op)
{
auto& ctx = get_py_context(context);
auto& builder = ctx.builder;
auto loc = ctx.loc;
auto lhs = unwrap_mlir<mlir::Value>(ssa_val);

auto type = lhs.getType();
if (!type.isa<mlir::IntegerType, mlir::IndexType, mlir::FloatType, mlir::ShapedType>())
{
plier::report_error("Invalid binop arg type");
}

auto is_float = [&]()->bool
{
if (auto shaped_type = type.dyn_cast<mlir::ShapedType>())
{
return shaped_type.getElementType().isa<mlir::FloatType>();
}
return type.isa<mlir::FloatType>();
}();

using binop_func_t = mlir::Value(*)(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value lhs, mlir::Value rhs);
const std::tuple<llvm::StringRef, binop_func_t, binop_func_t> funcs[] = {
{"*", &binop_func<mlir::MulIOp>, &binop_func<mlir::MulFOp>},
};

auto op_name = static_cast<std::string>(op);
for (auto f : funcs)
{
auto name = std::get<0>(f);
auto func = (is_float ? std::get<2>(f) : std::get<1>(f));
if (name == op_name)
{
auto rhs_var = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, rhs), type);
auto res = func(loc, builder, lhs, rhs_var);
return ctx.context.create_var(context, res);
}
}
plier::report_error("Unhandled binop type");
}
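Only "*" is wired up so far; is_float picks the integer or floating-point column of the table. Extending it is mechanical. As a hedged illustration (AddIOp/AddFOp were the standard-dialect addition ops at the time, but this extra row is hypothetical, not part of the PR):

// Hypothetical extra row for binop_impl's dispatch table:
const std::tuple<llvm::StringRef, binop_func_t, binop_func_t> funcs[] = {
    {"*", &binop_func<mlir::MulIOp>, &binop_func<mlir::MulFOp>},
    {"+", &binop_func<mlir::AddIOp>, &binop_func<mlir::AddFOp>},
};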

void setup_py_var(pybind11::handle var)
{
py::setattr(var, "_shape", py::cpp_function(&shape_impl));
py::setattr(var, "_dtype", py::cpp_function(&dtype_impl));
py::setattr(var, "_len", py::cpp_function(&len_impl));
py::setattr(var, "_getitem", py::cpp_function(&getitem_impl));
py::setattr(var, "_binop", py::cpp_function(&binop_impl));
}

PyLinalgResolver::Values unpack_results(py::handle object)
Expand Down Expand Up @@ -701,6 +768,7 @@ PyLinalgResolver::PyLinalgResolver():
{
auto builder_mod = py::module::import("numba.mlir.linalg_builder");
context->var = builder_mod.attr("Var");
context->type = builder_mod.attr("Type");
context->builder = builder_mod.attr("Builder");
context->inspect = py::module::import("inspect");
context->types_mod = py::module::import("numba.core.types");
@@ -713,7 +781,18 @@ PyLinalgResolver::~PyLinalgResolver()

}

-llvm::Optional<PyLinalgResolver::Values> PyLinalgResolver::rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, llvm::ArrayRef<std::pair<llvm::StringRef, mlir::Value>> kwargs)
+llvm::Optional<PyLinalgResolver::Values> PyLinalgResolver::rewrite_func(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, KWArgs kwargs)
{
auto mangled_name = (llvm::Twine(name) + "()").str();
return rewrite(mangled_name, loc, builder, args, kwargs);
}

llvm::Optional<PyLinalgResolver::Values> PyLinalgResolver::rewrite_attr(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::Value arg)
{
return rewrite(name, loc, builder, arg, {});
}
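The only difference between the two public entry points is name mangling: calls get a trailing "()", so call handlers and attribute handlers occupy disjoint keys in the Python-side registry (e.g. a np.dot call would resolve via "numpy.dot()" while a.T resolves via "array.T"; the exact registry keys are an assumption here). A sketch of the mangling:

#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include <string>

// Calls are suffixed before lookup; attribute names pass through verbatim.
std::string mangle_call_name(llvm::StringRef name)
{
    return (llvm::Twine(name) + "()").str(); // "numpy.dot" -> "numpy.dot()"
}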

llvm::Optional<PyLinalgResolver::Values> PyLinalgResolver::rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, KWArgs kwargs)
{
assert(!name.empty());
if (!is_compatible_types(args) ||
@@ -741,7 +820,7 @@ llvm::Optional<PyLinalgResolver::Values> PyLinalgResolver::rewrite(llvm::StringR
return {};
}
auto py_builder = context->builder(py_context);
-setup_py_builder(py_builder, builder);
+setup_py_builder(py_builder, builder, [&](auto type){ return context->create_type(type);});

auto result = builder_func(py_builder, *py_args);
if (result.is_none())