16
16
#include " mlir/Dialect/EmitC/IR/EmitC.h"
17
17
#include " mlir/Dialect/MemRef/IR/MemRef.h"
18
18
#include " mlir/IR/Builders.h"
19
+ #include " mlir/IR/BuiltinTypes.h"
19
20
#include " mlir/IR/PatternMatch.h"
21
+ #include " mlir/IR/TypeRange.h"
20
22
#include " mlir/Transforms/DialectConversion.h"
21
23
22
24
using namespace mlir ;
@@ -77,13 +79,23 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
77
79
}
78
80
};
79
81
82
+ Type convertMemRefType (MemRefType opTy, const TypeConverter *typeConverter) {
83
+ Type resultTy;
84
+ if (opTy.getRank () == 0 ) {
85
+ resultTy = typeConverter->convertType (mlir::getElementTypeOrSelf (opTy));
86
+ } else {
87
+ resultTy = typeConverter->convertType (opTy);
88
+ }
89
+ return resultTy;
90
+ }
91
+
80
92
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
81
93
using OpConversionPattern::OpConversionPattern;
82
94
83
95
LogicalResult
84
96
matchAndRewrite (memref::GlobalOp op, OpAdaptor operands,
85
97
ConversionPatternRewriter &rewriter) const override {
86
-
98
+ MemRefType opTy = op. getType ();
87
99
if (!op.getType ().hasStaticShape ()) {
88
100
return rewriter.notifyMatchFailure (
89
101
op.getLoc (), " cannot transform global with dynamic shape" );
@@ -95,7 +107,9 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
95
107
op.getLoc (), " global variable with alignment requirement is "
96
108
" currently not supported" );
97
109
}
98
- auto resultTy = getTypeConverter ()->convertType (op.getType ());
110
+
111
+ Type resultTy = convertMemRefType (opTy, getTypeConverter ());
112
+
99
113
if (!resultTy) {
100
114
return rewriter.notifyMatchFailure (op.getLoc (),
101
115
" cannot convert result type" );
@@ -114,6 +128,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
114
128
bool externSpecifier = !staticSpecifier;
115
129
116
130
Attribute initialValue = operands.getInitialValueAttr ();
131
+ if (opTy.getRank () == 0 ) {
132
+ auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue ());
133
+ initialValue = elementsAttr.getSplatValue <Attribute>();
134
+ }
117
135
if (isa_and_present<UnitAttr>(initialValue))
118
136
initialValue = {};
119
137
@@ -132,11 +150,23 @@ struct ConvertGetGlobal final
132
150
matchAndRewrite (memref::GetGlobalOp op, OpAdaptor operands,
133
151
ConversionPatternRewriter &rewriter) const override {
134
152
135
- auto resultTy = getTypeConverter ()->convertType (op.getType ());
153
+ MemRefType opTy = op.getType ();
154
+ Type resultTy = convertMemRefType (opTy, getTypeConverter ());
155
+
136
156
if (!resultTy) {
137
157
return rewriter.notifyMatchFailure (op.getLoc (),
138
158
" cannot convert result type" );
139
159
}
160
+
161
+ if (opTy.getRank () == 0 ) {
162
+ emitc::LValueType lvalueType = emitc::LValueType::get (resultTy);
163
+ emitc::GetGlobalOp globalLValue = rewriter.create <emitc::GetGlobalOp>(
164
+ op.getLoc (), lvalueType, operands.getNameAttr ());
165
+ emitc::PointerType pointerType = emitc::PointerType::get (resultTy);
166
+ rewriter.replaceOpWithNewOp <emitc::ApplyOp>(
167
+ op, pointerType, rewriter.getStringAttr (" &" ), globalLValue);
168
+ return success ();
169
+ }
140
170
rewriter.replaceOpWithNewOp <emitc::GetGlobalOp>(op, resultTy,
141
171
operands.getNameAttr ());
142
172
return success ();
0 commit comments