This repository was archived by the owner on Jan 25, 2023. It is now read-only.

[MLIR] Some linalg fixes and proper increfs/decrefs for array arguments #199

Merged: 5 commits, Mar 3, 2021
mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp (66 changes: 46 additions & 20 deletions)

@@ -610,12 +610,6 @@ struct ApplyFastmathFlags : public mlir::OpRewritePattern<Op>
 };
 
 // Copypaste from StandardToLLVM
-mlir::Value createIndexAttrConstant(mlir::OpBuilder &builder, mlir::Location loc,
-                                    mlir::Type resultType, int64_t value) {
-  return builder.create<mlir::LLVM::ConstantOp>(
-      loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
-}
-
 struct AllocLikeOpLowering : public mlir::ConvertToLLVMPattern {
   using ConvertToLLVMPattern::createIndexConstant;
   using ConvertToLLVMPattern::getIndexType;
@@ -625,19 +619,6 @@ struct AllocLikeOpLowering : public mlir::ConvertToLLVMPattern {
     : ConvertToLLVMPattern(opName, &converter.getContext(), converter, /*benefit*/99) {}
 
 protected:
-  // Returns 'input' aligned up to 'alignment'. Computes
-  //   bumped = input + alignement - 1
-  //   aligned = bumped - bumped % alignment
-  // static mlir::Value createAligned(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc,
-  //                                  mlir::Value input, mlir::Value alignment) {
-  //   using namespace mlir;
-  //   Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
-  //   Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
-  //   Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
-  //   Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
-  //   return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
-  // }
-
   // Creates a call to an allocation function with params and casts the
   // resulting void pointer to ptrType.
   mlir::Value createAllocCall(mlir::Location loc, mlir::StringRef name, mlir::Type ptrType,
@@ -1227,6 +1208,51 @@ struct PostLLVMLowering :
   }
 };
 
+struct LowerRetain : public mlir::OpConversionPattern<plier::RetainOp>
+{
+  using mlir::OpConversionPattern<plier::RetainOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(plier::RetainOp op, llvm::ArrayRef<mlir::Value> operands,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    assert(operands.size() == 1);
+    auto arg = operands[0];
+    if (!arg.getType().isa<mlir::LLVM::LLVMStructType>())
+    {
+      return mlir::failure();
+    }
+
+    auto llvmVoidPointerType =
+        mlir::LLVM::LLVMPointerType::get(rewriter.getIntegerType(8));
+    auto incref_func = [&]()
+    {
+      auto mod = op->getParentOfType<mlir::ModuleOp>();
+      assert(mod);
+      auto func = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>("NRT_incref");
+      if (!func)
+      {
+        mlir::OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(mod.getBody());
+        auto llvmVoidType = mlir::LLVM::LLVMVoidType::get(rewriter.getContext());
+        func = rewriter.create<mlir::LLVM::LLVMFuncOp>(
+            rewriter.getUnknownLoc(), "NRT_incref",
+            mlir::LLVM::LLVMFunctionType::get(llvmVoidType, llvmVoidPointerType));
+      }
+      return func;
+    }();
+
+    auto loc = op.getLoc();
+    auto index = rewriter.getI64ArrayAttr(0);
+    auto elemType = arg.getType().cast<mlir::LLVM::LLVMStructType>().getBody()[0];
+    mlir::Value ptr = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, elemType, arg, index);
+    ptr = rewriter.create<mlir::LLVM::BitcastOp>(loc, llvmVoidPointerType, ptr);
+    rewriter.create<mlir::LLVM::CallOp>(loc, incref_func, ptr);
+    rewriter.replaceOp(op, arg);
+
+    return mlir::success();
+  }
+};
+
 struct LowerCasts : public mlir::OpConversionPattern<plier::CastOp>
 {
   using mlir::OpConversionPattern<plier::CastOp>::OpConversionPattern;
@@ -1277,7 +1303,7 @@ struct LLVMLoweringPass : public mlir::PassWrapper<LLVMLoweringPass, mlir::Opera
 
   OwningRewritePatternList patterns;
   populateStdToLLVMConversionPatterns(typeConverter, patterns);
-  patterns.insert<LowerCasts>(typeConverter, &getContext());
+  patterns.insert<LowerCasts, LowerRetain>(typeConverter, &getContext());
   patterns.insert<AllocOpLowering, DeallocOpLowering>(typeConverter);
 
   LLVMConversionTarget target(getContext());
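For context: NRT_incref is the Numba runtime (NRT) refcount hook. The pattern above declares it on demand as a void(i8*) function and passes it the pointer extracted from field 0 of the lowered memref descriptor. A minimal sketch of the C-level contract this lowering assumes (hedged from Numba's NRT, not defined in this repo):

extern "C" {
  // Assumed NRT entry points: both take the meminfo pointer that the
  // lowered descriptor stores in its first field.
  void NRT_incref(void *meminfo); // refcount += 1
  void NRT_decref(void *meminfo); // refcount -= 1; frees the payload at zero
}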
mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp (49 changes: 43 additions & 6 deletions)

@@ -759,8 +759,7 @@ void PlierToLinalgPass::runOnOperation()
   patterns.insert<
     GetitemOpLowering<plier::GetItemOp>,
     GetitemOpLowering<plier::StaticGetItemOp>,
-    SetitemOpLowering<plier::SetItemOp>,
-    plier::ForceInline
+    SetitemOpLowering<plier::SetItemOp>
     >(&getContext());
 
   // range/prange lowering need dead branch pruning to properly
@@ -802,8 +801,8 @@ void LowerLinalgPass::runOnOperation()
   (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
-struct PostFusionOptPass :
-  public mlir::PassWrapper<PostFusionOptPass, mlir::OperationPass<mlir::ModuleOp>>
+struct CommonOptPass :
+  public mlir::PassWrapper<CommonOptPass, mlir::OperationPass<mlir::ModuleOp>>
 {
   virtual void getDependentDialects(
     mlir::DialectRegistry &registry) const override
@@ -817,7 +816,7 @@ struct PostFusionOptPass :
   void runOnOperation() override;
 };
 
-void PostFusionOptPass::runOnOperation()
+void CommonOptPass::runOnOperation()
 {
   mlir::OwningRewritePatternList patterns;
 
@@ -829,6 +828,7 @@ void PostFusionOptPass::runOnOperation()
 
   patterns.insert<
     // LoopInvariantCodeMotion, TODO
+    plier::ForceInline,
     plier::CSERewrite<mlir::FuncOp>
     >(&context);
 
@@ -859,6 +859,41 @@ struct LoopInvariantCodeMotion : public mlir::OpRewritePattern<mlir::scf::ForOp>
   }
 };
 
+struct RetainArgsPass :
+  public mlir::PassWrapper<RetainArgsPass, mlir::FunctionPass>
+{
+  virtual void getDependentDialects(
+    mlir::DialectRegistry &registry) const override
+  {
+    registry.insert<plier::PlierDialect>();
+  }
+
+  void runOnFunction() override;
+};
+
+void RetainArgsPass::runOnFunction()
+{
+  auto func = getFunction();
+  if (func.isPrivate() || func.isDeclaration() || func.body().empty())
+  {
+    return;
+  }
+
+  mlir::OpBuilder builder(&getContext());
+  auto loc = builder.getUnknownLoc();
+  auto block = &func.body().front();
+  builder.setInsertionPointToStart(block);
+  for (auto arg : block->getArguments())
+  {
+    if (arg.getType().isa<mlir::MemRefType>())
+    {
+      auto retained = builder.create<plier::RetainOp>(loc, arg);
+      llvm::SmallPtrSet<mlir::Operation*, 1> except({retained});
+      arg.replaceAllUsesExcept(retained, except);
+    }
+  }
+}
+
 struct PostLinalgOptPass :
   public mlir::PassWrapper<PostLinalgOptPass, mlir::OperationPass<mlir::ModuleOp>>
 {
@@ -900,13 +935,14 @@ void PostLinalgOptPass::runOnOperation()
 void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm)
 {
   pm.addPass(std::make_unique<PlierToLinalgPass>());
+  pm.addPass(std::make_unique<CommonOptPass>());
   pm.addPass(mlir::createSymbolDCEPass());
 }
 
 void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
 {
   pm.addPass(mlir::createLinalgFusionOfTensorOpsPass());
-  pm.addPass(std::make_unique<PostFusionOptPass>());
+  pm.addPass(std::make_unique<CommonOptPass>());
 
   pm.addPass(mlir::createTensorConstantBufferizePass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createSCFBufferizePass());
@@ -920,6 +956,7 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm)
   pm.addNestedPass<mlir::FuncOp>(mlir::createBufferLoopHoistingPass());
   pm.addNestedPass<mlir::FuncOp>(mlir::createPromoteBuffersToStackPass());
 
+  pm.addNestedPass<mlir::FuncOp>(std::make_unique<RetainArgsPass>());
   pm.addNestedPass<mlir::FuncOp>(mlir::createBufferDeallocationPass());
 
   pm.addPass(std::make_unique<LowerLinalgPass>());
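RetainArgsPass is scheduled immediately before createBufferDeallocationPass on purpose: wrapping each memref argument in plier.retain, whose result is declared MemAlloc (see PlierOps.td below), makes the deallocation pass treat the argument as a local allocation and insert the matching dealloc on every exit path, balancing the incref emitted by LowerRetain. A before/after sketch of a function entry block (hand-written illustration; generic op syntax is used because plier.retain declares no custom assembly format):

// before RetainArgsPass
func @foo(%arg0: memref<?xf32>) {
  "some.use"(%arg0) : (memref<?xf32>) -> ()
  return
}

// after RetainArgsPass; BufferDeallocation will later pair %0 with a dealloc
func @foo(%arg0: memref<?xf32>) {
  %0 = "plier.retain"(%arg0) : (memref<?xf32>) -> memref<?xf32>
  "some.use"(%0) : (memref<?xf32>) -> ()
  return
}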
mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp (32 changes: 21 additions & 11 deletions)

@@ -104,6 +104,7 @@ void container_iterate(py::handle obj, F&& func)
 
 llvm::Optional<py::object> make_py_literal(mlir::Value val)
 {
+  assert(val);
   if (auto int_val = plier::getConstVal<mlir::IntegerAttr>(val))
   {
     return py::int_(int_val.getInt());
@@ -144,6 +145,7 @@ struct PyLinalgResolver::Context
 
   py::object create_var(py::capsule context, mlir::Value value)
   {
+    assert(value);
     if (auto literal = make_py_literal(value))
     {
       return *literal;
@@ -423,19 +425,17 @@ mlir::Value broadcast_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Va
   return builder.create<mlir::SelectOp>(loc, cond, val2, val1);
 }
 
-mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, unsigned dim, mlir::ValueRange target_shape)
+mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value initial, mlir::Value src, unsigned dim, mlir::ValueRange target_shape)
 {
   auto context = builder.getContext();
   auto src_type = src.getType().cast<mlir::ShapedType>();
   auto num_dims = static_cast<unsigned>(src_type.getRank());
   auto shape = llvm::to_vector<8>(src_type.getShape());
   shape[dim] = -1;
   mlir::Type target_type = mlir::RankedTensorType::get(shape, src_type.getElementType());
-  auto dim_val = builder.create<mlir::DimOp>(loc, src, dim);
+  auto dim_val = builder.create<mlir::DimOp>(loc, initial, dim);
   auto one = builder.create<mlir::ConstantIndexOp>(loc, 1);
   mlir::Value cond = builder.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::eq, one, dim_val);
-  mlir::Value cond2 = builder.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::ne, target_shape[dim], dim_val);
-  cond = builder.create<mlir::AndOp>(loc, cond, cond2);
   llvm::SmallVector<mlir::Value> new_shape(num_dims);
   for (unsigned i = 0 ; i < num_dims; ++i)
   {
@@ -498,11 +498,12 @@ mlir::Value expand_dims(mlir::OpBuilder& builder, mlir::Location loc, mlir::Valu
   {
     target_shape = target_shape.drop_front(target_shape.size() - num_dims);
   }
+  mlir::Value current = val;
   for (unsigned i = 0; i < num_dims; ++i)
   {
-    val = expand_dim(builder, loc, val, i, target_shape);
+    current = expand_dim(builder, loc, val, current, i, target_shape);
   }
-  return val;
+  return current;
 }
 
 py::object broadcast_impl(py::capsule context, py::tuple args)
@@ -632,11 +633,20 @@ py::object broadcast_impl(py::capsule context, py::tuple args)
       {
        val = builder.create<plier::CastOp>(loc, tensor_type.getElementType(), val);
       }
-      auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange /*indices*/)
+      val = builder.create<mlir::tensor::FromElementsOp>(loc, val);
+      auto num_dims = static_cast<unsigned>(tensor_type.getRank());
+      auto init = builder.create<mlir::linalg::InitTensorOp>(loc, shape_vals, tensor_type.getElementType()).getResult();
+      mlir::AffineMap maps[] = {
+        mlir::AffineMap::get(num_dims, 0, mlir::getAffineConstantExpr(0, builder.getContext())),
+        mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()),
+      };
+      llvm::SmallVector<llvm::StringRef> iterators(num_dims, "parallel");
+      auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values)
       {
-        builder.create<mlir::tensor::YieldOp>(loc, val);
+        assert(values.size() == 2);
+        builder.create<mlir::linalg::YieldOp>(loc, values[0]);
       };
-      val = builder.create<mlir::tensor::GenerateOp>(loc, tensor_type, shape_vals, body);
+      val = builder.create<mlir::linalg::GenericOp>(loc, tensor_type, val, init, maps, iterators, body).getResult(0);
     }
   }
   ret[it.index()] = ctx.context.create_var(context, val);
@@ -688,12 +698,12 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dt
   else
   {
     auto val = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, init_val), elem_type);
-    llvm::SmallVector<int64_t> shape(count, -1);
-    auto type = mlir::RankedTensorType::get(shape, elem_type);
     auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange /*indices*/)
     {
      builder.create<mlir::tensor::YieldOp>(loc, val);
    };
+    llvm::SmallVector<int64_t> shape(count, -1);
+    auto type = mlir::RankedTensorType::get(shape, elem_type);
     init = builder.create<mlir::tensor::GenerateOp>(loc, type, shape_val, body);
   }
   if (llvm::any_of(static_shape, [](auto val){ return val >= 0;}))
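The broadcast change replaces tensor.generate, which captured the scalar in its region and did not fuse with downstream linalg ops, with a linalg.generic that reads the scalar from a one-element tensor through the constant indexing map (d0, ..., d_{n-1}) -> (0). For a 1-D target shape the emitted IR looks roughly like this (an illustrative sketch; value names invented):

%elem = tensor.from_elements %val : tensor<1xf32>
%init = linalg.init_tensor [%n] : tensor<?xf32>
%bcast = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (0)>,
                     affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%elem : tensor<1xf32>) outs(%init : tensor<?xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
} -> tensor<?xf32>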
mlir-compiler/plier/include/plier/PlierOps.td (10 changes: 10 additions & 0 deletions)

@@ -214,6 +214,16 @@ def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> {
   ];
 }
 
+def RetainOp : Plier_Op<"retain"> {
+  let arguments = (ins AnyMemRef:$value);
+
+  let results = (outs Res<AnyMemRef, "", [MemAlloc<DefaultResource>]>:$memref);
+
+  let builders = [
+    OpBuilderDAG<(ins "::mlir::Value":$value)>
+  ];
+}
+
 def ParallelOp : Plier_Op<"parallel",
   [AttrSizedOperandSegments,
    DeclareOpInterfaceMethods<LoopLikeOpInterface>,
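Since plier.retain's result type always equals its operand type, the OpBuilderDAG declaration above plus the hand-written build() in dialect.cpp (below) let call sites pass just the value. A usage sketch (hypothetical helper mirroring RetainArgsPass):

// Wrap a memref value in plier.retain; the custom builder forwards
// value.getType() as the result type, so only the operand is needed.
static mlir::Value retainValue(mlir::OpBuilder &builder, mlir::Location loc,
                               mlir::Value memref) {
  return builder.create<plier::RetainOp>(loc, memref);
}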
mlir-compiler/plier/include/plier/dialect.hpp (2 changes: 2 additions & 0 deletions)

@@ -1,5 +1,6 @@
 #pragma once
 
+#include <mlir/IR/BuiltinTypes.h>
 #include <mlir/IR/Dialect.h>
 #include <mlir/IR/Types.h>
 #include <mlir/IR/OpDefinition.h>
@@ -14,6 +15,7 @@ using Value = ::mlir::Value;
 using Region = ::mlir::Region;
 using LogicalResult = ::mlir::LogicalResult;
 using Operation = ::mlir::Operation;
+namespace MemoryEffects = ::mlir::MemoryEffects;
 
 template<typename T>
 using ArrayRef = ::mlir::ArrayRef<T>;
mlir-compiler/plier/src/dialect.cpp (5 changes: 5 additions & 0 deletions)

@@ -336,6 +336,11 @@ void GetattrOp::getCanonicalizationPatterns(
   results.insert<GetattrGlobalRewrite>(context);
 }
 
+void RetainOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
+                     mlir::Value value) {
+  RetainOp::build(builder, state, value.getType(), value);
+}
+
 mlir::LogicalResult ParallelOp::moveOutOfLoop(mlir::ArrayRef<mlir::Operation *> ops)
 {
   for (mlir::Operation *op : ops)