Skip to content

Commit 7fd91bb

Browse files
authored
[mlir][EmitC]Expand the MemRefToEmitC pass - Adding scalars (llvm#148055)
This aims to expand the the MemRefToEmitC pass so that it can accept global scalars. From: ``` memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1> func.func @Globals() { memref.get_global @__constant_xi32 : memref<i32> } ``` To: ``` emitc.global static const @__constant_xi32 : i32 = -1 emitc.func @Globals() { %0 = get_global @__constant_xi32 : !emitc.lvalue<i32> %1 = apply "&"(%0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32> return } ```
1 parent ff225b5 commit 7fd91bb

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1818
#include "mlir/IR/Builders.h"
19+
#include "mlir/IR/BuiltinTypes.h"
1920
#include "mlir/IR/PatternMatch.h"
21+
#include "mlir/IR/TypeRange.h"
2022
#include "mlir/Transforms/DialectConversion.h"
2123

2224
using namespace mlir;
@@ -77,13 +79,23 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
7779
}
7880
};
7981

82+
Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
83+
Type resultTy;
84+
if (opTy.getRank() == 0) {
85+
resultTy = typeConverter->convertType(mlir::getElementTypeOrSelf(opTy));
86+
} else {
87+
resultTy = typeConverter->convertType(opTy);
88+
}
89+
return resultTy;
90+
}
91+
8092
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
8193
using OpConversionPattern::OpConversionPattern;
8294

8395
LogicalResult
8496
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
8597
ConversionPatternRewriter &rewriter) const override {
86-
98+
MemRefType opTy = op.getType();
8799
if (!op.getType().hasStaticShape()) {
88100
return rewriter.notifyMatchFailure(
89101
op.getLoc(), "cannot transform global with dynamic shape");
@@ -95,7 +107,9 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
95107
op.getLoc(), "global variable with alignment requirement is "
96108
"currently not supported");
97109
}
98-
auto resultTy = getTypeConverter()->convertType(op.getType());
110+
111+
Type resultTy = convertMemRefType(opTy, getTypeConverter());
112+
99113
if (!resultTy) {
100114
return rewriter.notifyMatchFailure(op.getLoc(),
101115
"cannot convert result type");
@@ -114,6 +128,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
114128
bool externSpecifier = !staticSpecifier;
115129

116130
Attribute initialValue = operands.getInitialValueAttr();
131+
if (opTy.getRank() == 0) {
132+
auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
133+
initialValue = elementsAttr.getSplatValue<Attribute>();
134+
}
117135
if (isa_and_present<UnitAttr>(initialValue))
118136
initialValue = {};
119137

@@ -132,11 +150,23 @@ struct ConvertGetGlobal final
132150
matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
133151
ConversionPatternRewriter &rewriter) const override {
134152

135-
auto resultTy = getTypeConverter()->convertType(op.getType());
153+
MemRefType opTy = op.getType();
154+
Type resultTy = convertMemRefType(opTy, getTypeConverter());
155+
136156
if (!resultTy) {
137157
return rewriter.notifyMatchFailure(op.getLoc(),
138158
"cannot convert result type");
139159
}
160+
161+
if (opTy.getRank() == 0) {
162+
emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
163+
emitc::GetGlobalOp globalLValue = rewriter.create<emitc::GetGlobalOp>(
164+
op.getLoc(), lvalueType, operands.getNameAttr());
165+
emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
166+
rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
167+
op, pointerType, rewriter.getStringAttr("&"), globalLValue);
168+
return success();
169+
}
140170
rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
141171
operands.getNameAttr());
142172
return success();

mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 {
4141
module @globals {
4242
memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
4343
// CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
44+
memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
45+
// CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1
4446
memref.global @public_global : memref<3x7xf32>
4547
// CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32>
4648
memref.global @uninitialized_global : memref<3x7xf32> = uninitialized
@@ -50,6 +52,9 @@ module @globals {
5052
func.func @use_global() {
5153
// CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
5254
%0 = memref.get_global @public_global : memref<3x7xf32>
55+
// CHECK-NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
56+
// CHECK-NEXT: emitc.apply "&"(%1) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
57+
%1 = memref.get_global @__constant_xi32 : memref<i32>
5358
return
5459
}
5560
}

0 commit comments

Comments
 (0)