1//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert memref ops into emitc ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
14
15#include "mlir/Dialect/EmitC/IR/EmitC.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/DialectConversion.h"
20
21using namespace mlir;
22
23namespace {
24struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
25 using OpConversionPattern::OpConversionPattern;
26
27 LogicalResult
28 matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
29 ConversionPatternRewriter &rewriter) const override {
30
31 if (!op.getType().hasStaticShape()) {
32 return rewriter.notifyMatchFailure(
33 op.getLoc(), "cannot transform alloca with dynamic shape");
34 }
35
36 if (op.getAlignment().value_or(1) > 1) {
37 // TODO: Allow alignment if it is not more than the natural alignment
38 // of the C array.
39 return rewriter.notifyMatchFailure(
40 op.getLoc(), "cannot transform alloca with alignment requirement");
41 }
42
43 auto resultTy = getTypeConverter()->convertType(op.getType());
44 if (!resultTy) {
45 return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
46 }
47 auto noInit = emitc::OpaqueAttr::get(getContext(), "");
48 rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
49 return success();
50 }
51};
52
53struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
54 using OpConversionPattern::OpConversionPattern;
55
56 LogicalResult
57 matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
58 ConversionPatternRewriter &rewriter) const override {
59
60 if (!op.getType().hasStaticShape()) {
61 return rewriter.notifyMatchFailure(
62 op.getLoc(), "cannot transform global with dynamic shape");
63 }
64
65 if (op.getAlignment().value_or(1) > 1) {
66 // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
67 return rewriter.notifyMatchFailure(
68 op.getLoc(), "global variable with alignment requirement is "
69 "currently not supported");
70 }
71 auto resultTy = getTypeConverter()->convertType(op.getType());
72 if (!resultTy) {
73 return rewriter.notifyMatchFailure(op.getLoc(),
74 "cannot convert result type");
75 }
76
77 SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(symbol: op);
78 if (visibility != SymbolTable::Visibility::Public &&
79 visibility != SymbolTable::Visibility::Private) {
80 return rewriter.notifyMatchFailure(
81 op.getLoc(),
82 "only public and private visibility is currently supported");
83 }
84 // We are explicit in specifing the linkage because the default linkage
85 // for constants is different in C and C++.
86 bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
87 bool externSpecifier = !staticSpecifier;
88
89 Attribute initialValue = operands.getInitialValueAttr();
90 if (isa_and_present<UnitAttr>(Val: initialValue))
91 initialValue = {};
92
93 rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
94 op, operands.getSymName(), resultTy, initialValue, externSpecifier,
95 staticSpecifier, operands.getConstant());
96 return success();
97 }
98};
99
100struct ConvertGetGlobal final
101 : public OpConversionPattern<memref::GetGlobalOp> {
102 using OpConversionPattern::OpConversionPattern;
103
104 LogicalResult
105 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
106 ConversionPatternRewriter &rewriter) const override {
107
108 auto resultTy = getTypeConverter()->convertType(op.getType());
109 if (!resultTy) {
110 return rewriter.notifyMatchFailure(op.getLoc(),
111 "cannot convert result type");
112 }
113 rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
114 operands.getNameAttr());
115 return success();
116 }
117};
118
119struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
120 using OpConversionPattern::OpConversionPattern;
121
122 LogicalResult
123 matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
124 ConversionPatternRewriter &rewriter) const override {
125
126 auto resultTy = getTypeConverter()->convertType(op.getType());
127 if (!resultTy) {
128 return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
129 }
130
131 auto arrayValue =
132 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
133 if (!arrayValue) {
134 return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
135 }
136
137 auto subscript = rewriter.create<emitc::SubscriptOp>(
138 op.getLoc(), arrayValue, operands.getIndices());
139
140 auto noInit = emitc::OpaqueAttr::get(getContext(), "");
141 auto var =
142 rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit);
143
144 rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript);
145 rewriter.replaceOp(op, var);
146 return success();
147 }
148};
149
150struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
151 using OpConversionPattern::OpConversionPattern;
152
153 LogicalResult
154 matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
155 ConversionPatternRewriter &rewriter) const override {
156 auto arrayValue =
157 dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
158 if (!arrayValue) {
159 return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
160 }
161
162 auto subscript = rewriter.create<emitc::SubscriptOp>(
163 op.getLoc(), arrayValue, operands.getIndices());
164 rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
165 operands.getValue());
166 return success();
167 }
168};
169} // namespace
170
171void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
172 typeConverter.addConversion(
173 callback: [&](MemRefType memRefType) -> std::optional<Type> {
174 if (!memRefType.hasStaticShape() ||
175 !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0) {
176 return {};
177 }
178 Type convertedElementType =
179 typeConverter.convertType(memRefType.getElementType());
180 if (!convertedElementType)
181 return {};
182 return emitc::ArrayType::get(memRefType.getShape(),
183 convertedElementType);
184 });
185}
186
187void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
188 TypeConverter &converter) {
189 patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
190 ConvertStore>(arg&: converter, args: patterns.getContext());
191}
192

source code of mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp