1 | //===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <iterator>
#include <stack>
#include <string>
#include <utility>
28 | |
29 | #define DEBUG_TYPE "translate-to-cpp" |
30 | |
31 | using namespace mlir; |
32 | using namespace mlir::emitc; |
33 | using llvm::formatv; |
34 | |
35 | /// Convenience functions to produce interleaved output with functions returning |
36 | /// a LogicalResult. This is different than those in STLExtras as functions used |
37 | /// on each element doesn't return a string. |
38 | template <typename ForwardIterator, typename UnaryFunctor, |
39 | typename NullaryFunctor> |
40 | inline LogicalResult |
41 | interleaveWithError(ForwardIterator begin, ForwardIterator end, |
42 | UnaryFunctor eachFn, NullaryFunctor betweenFn) { |
43 | if (begin == end) |
44 | return success(); |
45 | if (failed(eachFn(*begin))) |
46 | return failure(); |
47 | ++begin; |
48 | for (; begin != end; ++begin) { |
49 | betweenFn(); |
50 | if (failed(eachFn(*begin))) |
51 | return failure(); |
52 | } |
53 | return success(); |
54 | } |
55 | |
56 | template <typename Container, typename UnaryFunctor, typename NullaryFunctor> |
57 | inline LogicalResult interleaveWithError(const Container &c, |
58 | UnaryFunctor eachFn, |
59 | NullaryFunctor betweenFn) { |
60 | return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn); |
61 | } |
62 | |
63 | template <typename Container, typename UnaryFunctor> |
64 | inline LogicalResult interleaveCommaWithError(const Container &c, |
65 | raw_ostream &os, |
66 | UnaryFunctor eachFn) { |
67 | return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", " ; }); |
68 | } |
69 | |
70 | /// Return the precedence of a operator as an integer, higher values |
71 | /// imply higher precedence. |
72 | static FailureOr<int> getOperatorPrecedence(Operation *operation) { |
73 | return llvm::TypeSwitch<Operation *, FailureOr<int>>(operation) |
74 | .Case<emitc::AddOp>([&](auto op) { return 11; }) |
75 | .Case<emitc::ApplyOp>([&](auto op) { return 13; }) |
76 | .Case<emitc::CastOp>([&](auto op) { return 13; }) |
77 | .Case<emitc::CmpOp>([&](auto op) -> FailureOr<int> { |
78 | switch (op.getPredicate()) { |
79 | case emitc::CmpPredicate::eq: |
80 | case emitc::CmpPredicate::ne: |
81 | return 8; |
82 | case emitc::CmpPredicate::lt: |
83 | case emitc::CmpPredicate::le: |
84 | case emitc::CmpPredicate::gt: |
85 | case emitc::CmpPredicate::ge: |
86 | return 9; |
87 | case emitc::CmpPredicate::three_way: |
88 | return 10; |
89 | } |
90 | return op->emitError("unsupported cmp predicate" ); |
91 | }) |
92 | .Case<emitc::DivOp>([&](auto op) { return 12; }) |
93 | .Case<emitc::MulOp>([&](auto op) { return 12; }) |
94 | .Case<emitc::RemOp>([&](auto op) { return 12; }) |
95 | .Case<emitc::SubOp>([&](auto op) { return 11; }) |
96 | .Case<emitc::CallOpaqueOp>([&](auto op) { return 14; }) |
97 | .Default([](auto op) { return op->emitError("unsupported operation" ); }); |
98 | } |
99 | |
100 | namespace { |
101 | /// Emitter that uses dialect specific emitters to emit C++ code. |
102 | struct CppEmitter { |
103 | explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop); |
104 | |
105 | /// Emits attribute or returns failure. |
106 | LogicalResult emitAttribute(Location loc, Attribute attr); |
107 | |
108 | /// Emits operation 'op' with/without training semicolon or returns failure. |
109 | LogicalResult emitOperation(Operation &op, bool trailingSemicolon); |
110 | |
111 | /// Emits type 'type' or returns failure. |
112 | LogicalResult emitType(Location loc, Type type); |
113 | |
114 | /// Emits array of types as a std::tuple of the emitted types. |
115 | /// - emits void for an empty array; |
116 | /// - emits the type of the only element for arrays of size one; |
117 | /// - emits a std::tuple otherwise; |
118 | LogicalResult emitTypes(Location loc, ArrayRef<Type> types); |
119 | |
120 | /// Emits array of types as a std::tuple of the emitted types independently of |
121 | /// the array size. |
122 | LogicalResult emitTupleType(Location loc, ArrayRef<Type> types); |
123 | |
124 | /// Emits an assignment for a variable which has been declared previously. |
125 | LogicalResult emitVariableAssignment(OpResult result); |
126 | |
127 | /// Emits a variable declaration for a result of an operation. |
128 | LogicalResult emitVariableDeclaration(OpResult result, |
129 | bool trailingSemicolon); |
130 | |
131 | /// Emits the variable declaration and assignment prefix for 'op'. |
132 | /// - emits separate variable followed by std::tie for multi-valued operation; |
133 | /// - emits single type followed by variable for single result; |
134 | /// - emits nothing if no value produced by op; |
135 | /// Emits final '=' operator where a type is produced. Returns failure if |
136 | /// any result type could not be converted. |
137 | LogicalResult emitAssignPrefix(Operation &op); |
138 | |
139 | /// Emits a label for the block. |
140 | LogicalResult emitLabel(Block &block); |
141 | |
142 | /// Emits the operands and atttributes of the operation. All operands are |
143 | /// emitted first and then all attributes in alphabetical order. |
144 | LogicalResult emitOperandsAndAttributes(Operation &op, |
145 | ArrayRef<StringRef> exclude = {}); |
146 | |
147 | /// Emits the operands of the operation. All operands are emitted in order. |
148 | LogicalResult emitOperands(Operation &op); |
149 | |
150 | /// Emits value as an operands of an operation |
151 | LogicalResult emitOperand(Value value); |
152 | |
153 | /// Emit an expression as a C expression. |
154 | LogicalResult emitExpression(ExpressionOp expressionOp); |
155 | |
156 | /// Return the existing or a new name for a Value. |
157 | StringRef getOrCreateName(Value val); |
158 | |
159 | /// Return the existing or a new label of a Block. |
160 | StringRef getOrCreateName(Block &block); |
161 | |
162 | /// Whether to map an mlir integer to a unsigned integer in C++. |
163 | bool shouldMapToUnsigned(IntegerType::SignednessSemantics val); |
164 | |
165 | /// RAII helper function to manage entering/exiting C++ scopes. |
166 | struct Scope { |
167 | Scope(CppEmitter &emitter) |
168 | : valueMapperScope(emitter.valueMapper), |
169 | blockMapperScope(emitter.blockMapper), emitter(emitter) { |
170 | emitter.valueInScopeCount.push(x: emitter.valueInScopeCount.top()); |
171 | emitter.labelInScopeCount.push(x: emitter.labelInScopeCount.top()); |
172 | } |
173 | ~Scope() { |
174 | emitter.valueInScopeCount.pop(); |
175 | emitter.labelInScopeCount.pop(); |
176 | } |
177 | |
178 | private: |
179 | llvm::ScopedHashTableScope<Value, std::string> valueMapperScope; |
180 | llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope; |
181 | CppEmitter &emitter; |
182 | }; |
183 | |
184 | /// Returns wether the Value is assigned to a C++ variable in the scope. |
185 | bool hasValueInScope(Value val); |
186 | |
187 | // Returns whether a label is assigned to the block. |
188 | bool hasBlockLabel(Block &block); |
189 | |
190 | /// Returns the output stream. |
191 | raw_indented_ostream &ostream() { return os; }; |
192 | |
193 | /// Returns if all variables for op results and basic block arguments need to |
194 | /// be declared at the beginning of a function. |
195 | bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; |
196 | |
197 | /// Get expression currently being emitted. |
198 | ExpressionOp getEmittedExpression() { return emittedExpression; } |
199 | |
200 | /// Determine whether given value is part of the expression potentially being |
201 | /// emitted. |
202 | bool isPartOfCurrentExpression(Value value) { |
203 | if (!emittedExpression) |
204 | return false; |
205 | Operation *def = value.getDefiningOp(); |
206 | if (!def) |
207 | return false; |
208 | auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp()); |
209 | return operandExpression == emittedExpression; |
210 | }; |
211 | |
212 | private: |
213 | using ValueMapper = llvm::ScopedHashTable<Value, std::string>; |
214 | using BlockMapper = llvm::ScopedHashTable<Block *, std::string>; |
215 | |
216 | /// Output stream to emit to. |
217 | raw_indented_ostream os; |
218 | |
219 | /// Boolean to enforce that all variables for op results and block |
220 | /// arguments are declared at the beginning of the function. This also |
221 | /// includes results from ops located in nested regions. |
222 | bool declareVariablesAtTop; |
223 | |
224 | /// Map from value to name of C++ variable that contain the name. |
225 | ValueMapper valueMapper; |
226 | |
227 | /// Map from block to name of C++ label. |
228 | BlockMapper blockMapper; |
229 | |
230 | /// The number of values in the current scope. This is used to declare the |
231 | /// names of values in a scope. |
232 | std::stack<int64_t> valueInScopeCount; |
233 | std::stack<int64_t> labelInScopeCount; |
234 | |
235 | /// State of the current expression being emitted. |
236 | ExpressionOp emittedExpression; |
237 | SmallVector<int> emittedExpressionPrecedence; |
238 | |
239 | void pushExpressionPrecedence(int precedence) { |
240 | emittedExpressionPrecedence.push_back(Elt: precedence); |
241 | } |
242 | void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); } |
243 | static int lowestPrecedence() { return 0; } |
244 | int getExpressionPrecedence() { |
245 | if (emittedExpressionPrecedence.empty()) |
246 | return lowestPrecedence(); |
247 | return emittedExpressionPrecedence.back(); |
248 | } |
249 | }; |
250 | } // namespace |
251 | |
252 | /// Determine whether expression \p expressionOp should be emitted inline, i.e. |
253 | /// as part of its user. This function recommends inlining of any expressions |
254 | /// that can be inlined unless it is used by another expression, under the |
255 | /// assumption that any expression fusion/re-materialization was taken care of |
256 | /// by transformations run by the backend. |
257 | static bool shouldBeInlined(ExpressionOp expressionOp) { |
258 | // Do not inline if expression is marked as such. |
259 | if (expressionOp.getDoNotInline()) |
260 | return false; |
261 | |
262 | // Do not inline expressions with side effects to prevent side-effect |
263 | // reordering. |
264 | if (expressionOp.hasSideEffects()) |
265 | return false; |
266 | |
267 | // Do not inline expressions with multiple uses. |
268 | Value result = expressionOp.getResult(); |
269 | if (!result.hasOneUse()) |
270 | return false; |
271 | |
272 | // Do not inline expressions used by other expressions, as any desired |
273 | // expression folding was taken care of by transformations. |
274 | Operation *user = *result.getUsers().begin(); |
275 | return !user->getParentOfType<ExpressionOp>(); |
276 | } |
277 | |
278 | static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, |
279 | Attribute value) { |
280 | OpResult result = operation->getResult(idx: 0); |
281 | |
282 | // Only emit an assignment as the variable was already declared when printing |
283 | // the FuncOp. |
284 | if (emitter.shouldDeclareVariablesAtTop()) { |
285 | // Skip the assignment if the emitc.constant has no value. |
286 | if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) { |
287 | if (oAttr.getValue().empty()) |
288 | return success(); |
289 | } |
290 | |
291 | if (failed(result: emitter.emitVariableAssignment(result))) |
292 | return failure(); |
293 | return emitter.emitAttribute(loc: operation->getLoc(), attr: value); |
294 | } |
295 | |
296 | // Emit a variable declaration for an emitc.constant op without value. |
297 | if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) { |
298 | if (oAttr.getValue().empty()) |
299 | // The semicolon gets printed by the emitOperation function. |
300 | return emitter.emitVariableDeclaration(result, |
301 | /*trailingSemicolon=*/false); |
302 | } |
303 | |
304 | // Emit a variable declaration. |
305 | if (failed(result: emitter.emitAssignPrefix(op&: *operation))) |
306 | return failure(); |
307 | return emitter.emitAttribute(loc: operation->getLoc(), attr: value); |
308 | } |
309 | |
310 | static LogicalResult printOperation(CppEmitter &emitter, |
311 | emitc::ConstantOp constantOp) { |
312 | Operation *operation = constantOp.getOperation(); |
313 | Attribute value = constantOp.getValue(); |
314 | |
315 | return printConstantOp(emitter, operation, value); |
316 | } |
317 | |
318 | static LogicalResult printOperation(CppEmitter &emitter, |
319 | emitc::VariableOp variableOp) { |
320 | Operation *operation = variableOp.getOperation(); |
321 | Attribute value = variableOp.getValue(); |
322 | |
323 | return printConstantOp(emitter, operation, value); |
324 | } |
325 | |
326 | static LogicalResult printOperation(CppEmitter &emitter, |
327 | arith::ConstantOp constantOp) { |
328 | Operation *operation = constantOp.getOperation(); |
329 | Attribute value = constantOp.getValue(); |
330 | |
331 | return printConstantOp(emitter, operation, value); |
332 | } |
333 | |
334 | static LogicalResult printOperation(CppEmitter &emitter, |
335 | func::ConstantOp constantOp) { |
336 | Operation *operation = constantOp.getOperation(); |
337 | Attribute value = constantOp.getValueAttr(); |
338 | |
339 | return printConstantOp(emitter, operation, value); |
340 | } |
341 | |
342 | static LogicalResult printOperation(CppEmitter &emitter, |
343 | emitc::AssignOp assignOp) { |
344 | auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp()); |
345 | OpResult result = variableOp->getResult(0); |
346 | |
347 | if (failed(result: emitter.emitVariableAssignment(result))) |
348 | return failure(); |
349 | |
350 | return emitter.emitOperand(value: assignOp.getValue()); |
351 | } |
352 | |
353 | static LogicalResult printBinaryOperation(CppEmitter &emitter, |
354 | Operation *operation, |
355 | StringRef binaryOperator) { |
356 | raw_ostream &os = emitter.ostream(); |
357 | |
358 | if (failed(result: emitter.emitAssignPrefix(op&: *operation))) |
359 | return failure(); |
360 | |
361 | if (failed(result: emitter.emitOperand(value: operation->getOperand(idx: 0)))) |
362 | return failure(); |
363 | |
364 | os << " " << binaryOperator << " " ; |
365 | |
366 | if (failed(result: emitter.emitOperand(value: operation->getOperand(idx: 1)))) |
367 | return failure(); |
368 | |
369 | return success(); |
370 | } |
371 | |
372 | static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) { |
373 | Operation *operation = addOp.getOperation(); |
374 | |
375 | return printBinaryOperation(emitter, operation, binaryOperator: "+" ); |
376 | } |
377 | |
378 | static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) { |
379 | Operation *operation = divOp.getOperation(); |
380 | |
381 | return printBinaryOperation(emitter, operation, binaryOperator: "/" ); |
382 | } |
383 | |
384 | static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) { |
385 | Operation *operation = mulOp.getOperation(); |
386 | |
387 | return printBinaryOperation(emitter, operation, binaryOperator: "*" ); |
388 | } |
389 | |
390 | static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) { |
391 | Operation *operation = remOp.getOperation(); |
392 | |
393 | return printBinaryOperation(emitter, operation, binaryOperator: "%" ); |
394 | } |
395 | |
396 | static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) { |
397 | Operation *operation = subOp.getOperation(); |
398 | |
399 | return printBinaryOperation(emitter, operation, binaryOperator: "-" ); |
400 | } |
401 | |
402 | static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) { |
403 | Operation *operation = cmpOp.getOperation(); |
404 | |
405 | StringRef binaryOperator; |
406 | |
407 | switch (cmpOp.getPredicate()) { |
408 | case emitc::CmpPredicate::eq: |
409 | binaryOperator = "==" ; |
410 | break; |
411 | case emitc::CmpPredicate::ne: |
412 | binaryOperator = "!=" ; |
413 | break; |
414 | case emitc::CmpPredicate::lt: |
415 | binaryOperator = "<" ; |
416 | break; |
417 | case emitc::CmpPredicate::le: |
418 | binaryOperator = "<=" ; |
419 | break; |
420 | case emitc::CmpPredicate::gt: |
421 | binaryOperator = ">" ; |
422 | break; |
423 | case emitc::CmpPredicate::ge: |
424 | binaryOperator = ">=" ; |
425 | break; |
426 | case emitc::CmpPredicate::three_way: |
427 | binaryOperator = "<=>" ; |
428 | break; |
429 | } |
430 | |
431 | return printBinaryOperation(emitter, operation, binaryOperator); |
432 | } |
433 | |
434 | static LogicalResult printOperation(CppEmitter &emitter, |
435 | emitc::VerbatimOp verbatimOp) { |
436 | raw_ostream &os = emitter.ostream(); |
437 | |
438 | os << verbatimOp.getValue(); |
439 | |
440 | return success(); |
441 | } |
442 | |
443 | static LogicalResult printOperation(CppEmitter &emitter, |
444 | cf::BranchOp branchOp) { |
445 | raw_ostream &os = emitter.ostream(); |
446 | Block &successor = *branchOp.getSuccessor(); |
447 | |
448 | for (auto pair : |
449 | llvm::zip(branchOp.getOperands(), successor.getArguments())) { |
450 | Value &operand = std::get<0>(pair); |
451 | BlockArgument &argument = std::get<1>(pair); |
452 | os << emitter.getOrCreateName(argument) << " = " |
453 | << emitter.getOrCreateName(operand) << ";\n" ; |
454 | } |
455 | |
456 | os << "goto " ; |
457 | if (!(emitter.hasBlockLabel(block&: successor))) |
458 | return branchOp.emitOpError("unable to find label for successor block" ); |
459 | os << emitter.getOrCreateName(block&: successor); |
460 | return success(); |
461 | } |
462 | |
463 | static LogicalResult printOperation(CppEmitter &emitter, |
464 | cf::CondBranchOp condBranchOp) { |
465 | raw_indented_ostream &os = emitter.ostream(); |
466 | Block &trueSuccessor = *condBranchOp.getTrueDest(); |
467 | Block &falseSuccessor = *condBranchOp.getFalseDest(); |
468 | |
469 | os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition()) |
470 | << ") {\n" ; |
471 | |
472 | os.indent(); |
473 | |
474 | // If condition is true. |
475 | for (auto pair : llvm::zip(condBranchOp.getTrueOperands(), |
476 | trueSuccessor.getArguments())) { |
477 | Value &operand = std::get<0>(pair); |
478 | BlockArgument &argument = std::get<1>(pair); |
479 | os << emitter.getOrCreateName(argument) << " = " |
480 | << emitter.getOrCreateName(operand) << ";\n" ; |
481 | } |
482 | |
483 | os << "goto " ; |
484 | if (!(emitter.hasBlockLabel(block&: trueSuccessor))) { |
485 | return condBranchOp.emitOpError("unable to find label for successor block" ); |
486 | } |
487 | os << emitter.getOrCreateName(block&: trueSuccessor) << ";\n" ; |
488 | os.unindent() << "} else {\n" ; |
489 | os.indent(); |
490 | // If condition is false. |
491 | for (auto pair : llvm::zip(condBranchOp.getFalseOperands(), |
492 | falseSuccessor.getArguments())) { |
493 | Value &operand = std::get<0>(pair); |
494 | BlockArgument &argument = std::get<1>(pair); |
495 | os << emitter.getOrCreateName(argument) << " = " |
496 | << emitter.getOrCreateName(operand) << ";\n" ; |
497 | } |
498 | |
499 | os << "goto " ; |
500 | if (!(emitter.hasBlockLabel(block&: falseSuccessor))) { |
501 | return condBranchOp.emitOpError() |
502 | << "unable to find label for successor block" ; |
503 | } |
504 | os << emitter.getOrCreateName(block&: falseSuccessor) << ";\n" ; |
505 | os.unindent() << "}" ; |
506 | return success(); |
507 | } |
508 | |
509 | static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp, |
510 | StringRef callee) { |
511 | if (failed(result: emitter.emitAssignPrefix(op&: *callOp))) |
512 | return failure(); |
513 | |
514 | raw_ostream &os = emitter.ostream(); |
515 | os << callee << "(" ; |
516 | if (failed(result: emitter.emitOperands(op&: *callOp))) |
517 | return failure(); |
518 | os << ")" ; |
519 | return success(); |
520 | } |
521 | |
522 | static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) { |
523 | Operation *operation = callOp.getOperation(); |
524 | StringRef callee = callOp.getCallee(); |
525 | |
526 | return printCallOperation(emitter, callOp: operation, callee); |
527 | } |
528 | |
529 | static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) { |
530 | Operation *operation = callOp.getOperation(); |
531 | StringRef callee = callOp.getCallee(); |
532 | |
533 | return printCallOperation(emitter, callOp: operation, callee); |
534 | } |
535 | |
536 | static LogicalResult printOperation(CppEmitter &emitter, |
537 | emitc::CallOpaqueOp callOpaqueOp) { |
538 | raw_ostream &os = emitter.ostream(); |
539 | Operation &op = *callOpaqueOp.getOperation(); |
540 | |
541 | if (failed(result: emitter.emitAssignPrefix(op))) |
542 | return failure(); |
543 | os << callOpaqueOp.getCallee(); |
544 | |
545 | auto emitArgs = [&](Attribute attr) -> LogicalResult { |
546 | if (auto t = dyn_cast<IntegerAttr>(attr)) { |
547 | // Index attributes are treated specially as operand index. |
548 | if (t.getType().isIndex()) { |
549 | int64_t idx = t.getInt(); |
550 | Value operand = op.getOperand(idx); |
551 | auto literalDef = |
552 | dyn_cast_if_present<LiteralOp>(operand.getDefiningOp()); |
553 | if (!literalDef && !emitter.hasValueInScope(val: operand)) |
554 | return op.emitOpError(message: "operand " ) |
555 | << idx << "'s value not defined in scope" ; |
556 | os << emitter.getOrCreateName(val: operand); |
557 | return success(); |
558 | } |
559 | } |
560 | if (failed(result: emitter.emitAttribute(loc: op.getLoc(), attr))) |
561 | return failure(); |
562 | |
563 | return success(); |
564 | }; |
565 | |
566 | if (callOpaqueOp.getTemplateArgs()) { |
567 | os << "<" ; |
568 | if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os, |
569 | emitArgs))) |
570 | return failure(); |
571 | os << ">" ; |
572 | } |
573 | |
574 | os << "(" ; |
575 | |
576 | LogicalResult emittedArgs = |
577 | callOpaqueOp.getArgs() |
578 | ? interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArgs) |
579 | : emitter.emitOperands(op); |
580 | if (failed(result: emittedArgs)) |
581 | return failure(); |
582 | os << ")" ; |
583 | return success(); |
584 | } |
585 | |
586 | static LogicalResult printOperation(CppEmitter &emitter, |
587 | emitc::ApplyOp applyOp) { |
588 | raw_ostream &os = emitter.ostream(); |
589 | Operation &op = *applyOp.getOperation(); |
590 | |
591 | if (failed(result: emitter.emitAssignPrefix(op))) |
592 | return failure(); |
593 | os << applyOp.getApplicableOperator(); |
594 | os << emitter.getOrCreateName(applyOp.getOperand()); |
595 | |
596 | return success(); |
597 | } |
598 | |
599 | static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { |
600 | raw_ostream &os = emitter.ostream(); |
601 | Operation &op = *castOp.getOperation(); |
602 | |
603 | if (failed(result: emitter.emitAssignPrefix(op))) |
604 | return failure(); |
605 | os << "(" ; |
606 | if (failed(result: emitter.emitType(loc: op.getLoc(), type: op.getResult(idx: 0).getType()))) |
607 | return failure(); |
608 | os << ") " ; |
609 | return emitter.emitOperand(value: castOp.getOperand()); |
610 | } |
611 | |
612 | static LogicalResult printOperation(CppEmitter &emitter, |
613 | emitc::ExpressionOp expressionOp) { |
614 | if (shouldBeInlined(expressionOp)) |
615 | return success(); |
616 | |
617 | Operation &op = *expressionOp.getOperation(); |
618 | |
619 | if (failed(result: emitter.emitAssignPrefix(op))) |
620 | return failure(); |
621 | |
622 | return emitter.emitExpression(expressionOp); |
623 | } |
624 | |
625 | static LogicalResult printOperation(CppEmitter &emitter, |
626 | emitc::IncludeOp includeOp) { |
627 | raw_ostream &os = emitter.ostream(); |
628 | |
629 | os << "#include " ; |
630 | if (includeOp.getIsStandardInclude()) |
631 | os << "<" << includeOp.getInclude() << ">" ; |
632 | else |
633 | os << "\"" << includeOp.getInclude() << "\"" ; |
634 | |
635 | return success(); |
636 | } |
637 | |
638 | static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { |
639 | |
640 | raw_indented_ostream &os = emitter.ostream(); |
641 | |
642 | // Utility function to determine whether a value is an expression that will be |
643 | // inlined, and as such should be wrapped in parentheses in order to guarantee |
644 | // its precedence and associativity. |
645 | auto requiresParentheses = [&](Value value) { |
646 | auto expressionOp = |
647 | dyn_cast_if_present<ExpressionOp>(value.getDefiningOp()); |
648 | if (!expressionOp) |
649 | return false; |
650 | return shouldBeInlined(expressionOp); |
651 | }; |
652 | |
653 | os << "for (" ; |
654 | if (failed( |
655 | emitter.emitType(loc: forOp.getLoc(), type: forOp.getInductionVar().getType()))) |
656 | return failure(); |
657 | os << " " ; |
658 | os << emitter.getOrCreateName(forOp.getInductionVar()); |
659 | os << " = " ; |
660 | if (failed(emitter.emitOperand(value: forOp.getLowerBound()))) |
661 | return failure(); |
662 | os << "; " ; |
663 | os << emitter.getOrCreateName(forOp.getInductionVar()); |
664 | os << " < " ; |
665 | Value upperBound = forOp.getUpperBound(); |
666 | bool upperBoundRequiresParentheses = requiresParentheses(upperBound); |
667 | if (upperBoundRequiresParentheses) |
668 | os << "(" ; |
669 | if (failed(result: emitter.emitOperand(value: upperBound))) |
670 | return failure(); |
671 | if (upperBoundRequiresParentheses) |
672 | os << ")" ; |
673 | os << "; " ; |
674 | os << emitter.getOrCreateName(forOp.getInductionVar()); |
675 | os << " += " ; |
676 | if (failed(emitter.emitOperand(value: forOp.getStep()))) |
677 | return failure(); |
678 | os << ") {\n" ; |
679 | os.indent(); |
680 | |
681 | Region &forRegion = forOp.getRegion(); |
682 | auto regionOps = forRegion.getOps(); |
683 | |
684 | // We skip the trailing yield op. |
685 | for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) { |
686 | if (failed(emitter.emitOperation(op&: *it, /*trailingSemicolon=*/true))) |
687 | return failure(); |
688 | } |
689 | |
690 | os.unindent() << "}" ; |
691 | |
692 | return success(); |
693 | } |
694 | |
695 | static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) { |
696 | raw_indented_ostream &os = emitter.ostream(); |
697 | |
698 | // Helper function to emit all ops except the last one, expected to be |
699 | // emitc::yield. |
700 | auto emitAllExceptLast = [&emitter](Region ®ion) { |
701 | Region::OpIterator it = region.op_begin(), end = region.op_end(); |
702 | for (; std::next(x: it) != end; ++it) { |
703 | if (failed(result: emitter.emitOperation(op&: *it, /*trailingSemicolon=*/true))) |
704 | return failure(); |
705 | } |
706 | assert(isa<emitc::YieldOp>(*it) && |
707 | "Expected last operation in the region to be emitc::yield" ); |
708 | return success(); |
709 | }; |
710 | |
711 | os << "if (" ; |
712 | if (failed(emitter.emitOperand(value: ifOp.getCondition()))) |
713 | return failure(); |
714 | os << ") {\n" ; |
715 | os.indent(); |
716 | if (failed(emitAllExceptLast(ifOp.getThenRegion()))) |
717 | return failure(); |
718 | os.unindent() << "}" ; |
719 | |
720 | Region &elseRegion = ifOp.getElseRegion(); |
721 | if (!elseRegion.empty()) { |
722 | os << " else {\n" ; |
723 | os.indent(); |
724 | if (failed(result: emitAllExceptLast(elseRegion))) |
725 | return failure(); |
726 | os.unindent() << "}" ; |
727 | } |
728 | |
729 | return success(); |
730 | } |
731 | |
732 | static LogicalResult printOperation(CppEmitter &emitter, |
733 | func::ReturnOp returnOp) { |
734 | raw_ostream &os = emitter.ostream(); |
735 | os << "return" ; |
736 | switch (returnOp.getNumOperands()) { |
737 | case 0: |
738 | return success(); |
739 | case 1: |
740 | os << " " ; |
741 | if (failed(emitter.emitOperand(value: returnOp.getOperand(0)))) |
742 | return failure(); |
743 | return success(); |
744 | default: |
745 | os << " std::make_tuple(" ; |
746 | if (failed(emitter.emitOperandsAndAttributes(op&: *returnOp.getOperation()))) |
747 | return failure(); |
748 | os << ")" ; |
749 | return success(); |
750 | } |
751 | } |
752 | |
753 | static LogicalResult printOperation(CppEmitter &emitter, |
754 | emitc::ReturnOp returnOp) { |
755 | raw_ostream &os = emitter.ostream(); |
756 | os << "return" ; |
757 | if (returnOp.getNumOperands() == 0) |
758 | return success(); |
759 | |
760 | os << " " ; |
761 | if (failed(emitter.emitOperand(value: returnOp.getOperand()))) |
762 | return failure(); |
763 | return success(); |
764 | } |
765 | |
766 | static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { |
767 | CppEmitter::Scope scope(emitter); |
768 | |
769 | for (Operation &op : moduleOp) { |
770 | if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) |
771 | return failure(); |
772 | } |
773 | return success(); |
774 | } |
775 | |
776 | static LogicalResult printFunctionArgs(CppEmitter &emitter, |
777 | Operation *functionOp, |
778 | ArrayRef<Type> arguments) { |
779 | raw_indented_ostream &os = emitter.ostream(); |
780 | |
781 | return ( |
782 | interleaveCommaWithError(c: arguments, os, eachFn: [&](Type arg) -> LogicalResult { |
783 | return emitter.emitType(loc: functionOp->getLoc(), type: arg); |
784 | })); |
785 | } |
786 | |
787 | static LogicalResult printFunctionArgs(CppEmitter &emitter, |
788 | Operation *functionOp, |
789 | Region::BlockArgListType arguments) { |
790 | raw_indented_ostream &os = emitter.ostream(); |
791 | |
792 | return (interleaveCommaWithError( |
793 | c: arguments, os, eachFn: [&](BlockArgument arg) -> LogicalResult { |
794 | if (failed(result: emitter.emitType(loc: functionOp->getLoc(), type: arg.getType()))) |
795 | return failure(); |
796 | os << " " << emitter.getOrCreateName(val: arg); |
797 | return success(); |
798 | })); |
799 | } |
800 | |
801 | static LogicalResult printFunctionBody(CppEmitter &emitter, |
802 | Operation *functionOp, |
803 | Region::BlockListType &blocks) { |
804 | raw_indented_ostream &os = emitter.ostream(); |
805 | os.indent(); |
806 | |
807 | if (emitter.shouldDeclareVariablesAtTop()) { |
808 | // Declare all variables that hold op results including those from nested |
809 | // regions. |
810 | WalkResult result = |
811 | functionOp->walk<WalkOrder::PreOrder>(callback: [&](Operation *op) -> WalkResult { |
812 | if (isa<emitc::LiteralOp>(op) || |
813 | isa<emitc::ExpressionOp>(op->getParentOp()) || |
814 | (isa<emitc::ExpressionOp>(op) && |
815 | shouldBeInlined(cast<emitc::ExpressionOp>(op)))) |
816 | return WalkResult::skip(); |
817 | for (OpResult result : op->getResults()) { |
818 | if (failed(result: emitter.emitVariableDeclaration( |
819 | result, /*trailingSemicolon=*/true))) { |
820 | return WalkResult( |
821 | op->emitError(message: "unable to declare result variable for op" )); |
822 | } |
823 | } |
824 | return WalkResult::advance(); |
825 | }); |
826 | if (result.wasInterrupted()) |
827 | return failure(); |
828 | } |
829 | |
830 | // Create label names for basic blocks. |
831 | for (Block &block : blocks) { |
832 | emitter.getOrCreateName(block); |
833 | } |
834 | |
835 | // Declare variables for basic block arguments. |
836 | for (Block &block : llvm::drop_begin(RangeOrContainer&: blocks)) { |
837 | for (BlockArgument &arg : block.getArguments()) { |
838 | if (emitter.hasValueInScope(val: arg)) |
839 | return functionOp->emitOpError(message: " block argument #" ) |
840 | << arg.getArgNumber() << " is out of scope" ; |
841 | if (failed( |
842 | result: emitter.emitType(loc: block.getParentOp()->getLoc(), type: arg.getType()))) { |
843 | return failure(); |
844 | } |
845 | os << " " << emitter.getOrCreateName(val: arg) << ";\n" ; |
846 | } |
847 | } |
848 | |
849 | for (Block &block : blocks) { |
850 | // Only print a label if the block has predecessors. |
851 | if (!block.hasNoPredecessors()) { |
852 | if (failed(result: emitter.emitLabel(block))) |
853 | return failure(); |
854 | } |
855 | for (Operation &op : block.getOperations()) { |
856 | // When generating code for an emitc.if or cf.cond_br op no semicolon |
857 | // needs to be printed after the closing brace. |
858 | // When generating code for an emitc.for and emitc.verbatim op, printing a |
859 | // trailing semicolon is handled within the printOperation function. |
860 | bool trailingSemicolon = |
861 | !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp, |
862 | emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op); |
863 | |
864 | if (failed(result: emitter.emitOperation( |
865 | op, /*trailingSemicolon=*/trailingSemicolon))) |
866 | return failure(); |
867 | } |
868 | } |
869 | |
870 | os.unindent(); |
871 | |
872 | return success(); |
873 | } |
874 | |
875 | static LogicalResult printOperation(CppEmitter &emitter, |
876 | func::FuncOp functionOp) { |
877 | // We need to declare variables at top if the function has multiple blocks. |
878 | if (!emitter.shouldDeclareVariablesAtTop() && |
879 | functionOp.getBlocks().size() > 1) { |
880 | return functionOp.emitOpError( |
881 | "with multiple blocks needs variables declared at top" ); |
882 | } |
883 | |
884 | CppEmitter::Scope scope(emitter); |
885 | raw_indented_ostream &os = emitter.ostream(); |
886 | if (failed(emitter.emitTypes(loc: functionOp.getLoc(), |
887 | types: functionOp.getFunctionType().getResults()))) |
888 | return failure(); |
889 | os << " " << functionOp.getName(); |
890 | |
891 | os << "(" ; |
892 | Operation *operation = functionOp.getOperation(); |
893 | if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) |
894 | return failure(); |
895 | os << ") {\n" ; |
896 | if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) |
897 | return failure(); |
898 | os << "}\n" ; |
899 | |
900 | return success(); |
901 | } |
902 | |
903 | static LogicalResult printOperation(CppEmitter &emitter, |
904 | emitc::FuncOp functionOp) { |
905 | // We need to declare variables at top if the function has multiple blocks. |
906 | if (!emitter.shouldDeclareVariablesAtTop() && |
907 | functionOp.getBlocks().size() > 1) { |
908 | return functionOp.emitOpError( |
909 | "with multiple blocks needs variables declared at top" ); |
910 | } |
911 | |
912 | CppEmitter::Scope scope(emitter); |
913 | raw_indented_ostream &os = emitter.ostream(); |
914 | if (functionOp.getSpecifiers()) { |
915 | for (Attribute specifier : functionOp.getSpecifiersAttr()) { |
916 | os << cast<StringAttr>(specifier).str() << " " ; |
917 | } |
918 | } |
919 | |
920 | if (failed(emitter.emitTypes(loc: functionOp.getLoc(), |
921 | types: functionOp.getFunctionType().getResults()))) |
922 | return failure(); |
923 | os << " " << functionOp.getName(); |
924 | |
925 | os << "(" ; |
926 | Operation *operation = functionOp.getOperation(); |
927 | if (functionOp.isExternal()) { |
928 | if (failed(printFunctionArgs(emitter, operation, |
929 | functionOp.getArgumentTypes()))) |
930 | return failure(); |
931 | os << ");" ; |
932 | return success(); |
933 | } |
934 | if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) |
935 | return failure(); |
936 | os << ") {\n" ; |
937 | if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks()))) |
938 | return failure(); |
939 | os << "}\n" ; |
940 | |
941 | return success(); |
942 | } |
943 | |
944 | static LogicalResult printOperation(CppEmitter &emitter, |
945 | DeclareFuncOp declareFuncOp) { |
946 | CppEmitter::Scope scope(emitter); |
947 | raw_indented_ostream &os = emitter.ostream(); |
948 | |
949 | auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>( |
950 | declareFuncOp, declareFuncOp.getSymNameAttr()); |
951 | |
952 | if (!functionOp) |
953 | return failure(); |
954 | |
955 | if (functionOp.getSpecifiers()) { |
956 | for (Attribute specifier : functionOp.getSpecifiersAttr()) { |
957 | os << cast<StringAttr>(specifier).str() << " " ; |
958 | } |
959 | } |
960 | |
961 | if (failed(emitter.emitTypes(loc: functionOp.getLoc(), |
962 | types: functionOp.getFunctionType().getResults()))) |
963 | return failure(); |
964 | os << " " << functionOp.getName(); |
965 | |
966 | os << "(" ; |
967 | Operation *operation = functionOp.getOperation(); |
968 | if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments()))) |
969 | return failure(); |
970 | os << ");" ; |
971 | |
972 | return success(); |
973 | } |
974 | |
975 | CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop) |
976 | : os(os), declareVariablesAtTop(declareVariablesAtTop) { |
977 | valueInScopeCount.push(x: 0); |
978 | labelInScopeCount.push(x: 0); |
979 | } |
980 | |
981 | /// Return the existing or a new name for a Value. |
982 | StringRef CppEmitter::getOrCreateName(Value val) { |
983 | if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp())) |
984 | return literal.getValue(); |
985 | if (!valueMapper.count(Key: val)) |
986 | valueMapper.insert(Key: val, Val: formatv(Fmt: "v{0}" , Vals&: ++valueInScopeCount.top())); |
987 | return *valueMapper.begin(Key: val); |
988 | } |
989 | |
990 | /// Return the existing or a new label for a Block. |
991 | StringRef CppEmitter::getOrCreateName(Block &block) { |
992 | if (!blockMapper.count(Key: &block)) |
993 | blockMapper.insert(Key: &block, Val: formatv(Fmt: "label{0}" , Vals&: ++labelInScopeCount.top())); |
994 | return *blockMapper.begin(Key: &block); |
995 | } |
996 | |
997 | bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) { |
998 | switch (val) { |
999 | case IntegerType::Signless: |
1000 | return false; |
1001 | case IntegerType::Signed: |
1002 | return false; |
1003 | case IntegerType::Unsigned: |
1004 | return true; |
1005 | } |
1006 | llvm_unreachable("Unexpected IntegerType::SignednessSemantics" ); |
1007 | } |
1008 | |
1009 | bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(Key: val); } |
1010 | |
1011 | bool CppEmitter::hasBlockLabel(Block &block) { |
1012 | return blockMapper.count(Key: &block); |
1013 | } |
1014 | |
1015 | LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { |
1016 | auto printInt = [&](const APInt &val, bool isUnsigned) { |
1017 | if (val.getBitWidth() == 1) { |
1018 | if (val.getBoolValue()) |
1019 | os << "true" ; |
1020 | else |
1021 | os << "false" ; |
1022 | } else { |
1023 | SmallString<128> strValue; |
1024 | val.toString(Str&: strValue, Radix: 10, Signed: !isUnsigned, formatAsCLiteral: false); |
1025 | os << strValue; |
1026 | } |
1027 | }; |
1028 | |
1029 | auto printFloat = [&](const APFloat &val) { |
1030 | if (val.isFinite()) { |
1031 | SmallString<128> strValue; |
1032 | // Use default values of toString except don't truncate zeros. |
1033 | val.toString(Str&: strValue, FormatPrecision: 0, FormatMaxPadding: 0, TruncateZero: false); |
1034 | switch (llvm::APFloatBase::SemanticsToEnum(Sem: val.getSemantics())) { |
1035 | case llvm::APFloatBase::S_IEEEsingle: |
1036 | os << "(float)" ; |
1037 | break; |
1038 | case llvm::APFloatBase::S_IEEEdouble: |
1039 | os << "(double)" ; |
1040 | break; |
1041 | default: |
1042 | break; |
1043 | }; |
1044 | os << strValue; |
1045 | } else if (val.isNaN()) { |
1046 | os << "NAN" ; |
1047 | } else if (val.isInfinity()) { |
1048 | if (val.isNegative()) |
1049 | os << "-" ; |
1050 | os << "INFINITY" ; |
1051 | } |
1052 | }; |
1053 | |
1054 | // Print floating point attributes. |
1055 | if (auto fAttr = dyn_cast<FloatAttr>(attr)) { |
1056 | printFloat(fAttr.getValue()); |
1057 | return success(); |
1058 | } |
1059 | if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) { |
1060 | os << '{'; |
1061 | interleaveComma(c: dense, os, each_fn: [&](const APFloat &val) { printFloat(val); }); |
1062 | os << '}'; |
1063 | return success(); |
1064 | } |
1065 | |
1066 | // Print integer attributes. |
1067 | if (auto iAttr = dyn_cast<IntegerAttr>(attr)) { |
1068 | if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) { |
1069 | printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness())); |
1070 | return success(); |
1071 | } |
1072 | if (auto iType = dyn_cast<IndexType>(iAttr.getType())) { |
1073 | printInt(iAttr.getValue(), false); |
1074 | return success(); |
1075 | } |
1076 | } |
1077 | if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) { |
1078 | if (auto iType = dyn_cast<IntegerType>( |
1079 | cast<TensorType>(dense.getType()).getElementType())) { |
1080 | os << '{'; |
1081 | interleaveComma(c: dense, os, each_fn: [&](const APInt &val) { |
1082 | printInt(val, shouldMapToUnsigned(iType.getSignedness())); |
1083 | }); |
1084 | os << '}'; |
1085 | return success(); |
1086 | } |
1087 | if (auto iType = dyn_cast<IndexType>( |
1088 | cast<TensorType>(dense.getType()).getElementType())) { |
1089 | os << '{'; |
1090 | interleaveComma(c: dense, os, |
1091 | each_fn: [&](const APInt &val) { printInt(val, false); }); |
1092 | os << '}'; |
1093 | return success(); |
1094 | } |
1095 | } |
1096 | |
1097 | // Print opaque attributes. |
1098 | if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) { |
1099 | os << oAttr.getValue(); |
1100 | return success(); |
1101 | } |
1102 | |
1103 | // Print symbolic reference attributes. |
1104 | if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) { |
1105 | if (sAttr.getNestedReferences().size() > 1) |
1106 | return emitError(loc, message: "attribute has more than 1 nested reference" ); |
1107 | os << sAttr.getRootReference().getValue(); |
1108 | return success(); |
1109 | } |
1110 | |
1111 | // Print type attributes. |
1112 | if (auto type = dyn_cast<TypeAttr>(attr)) |
1113 | return emitType(loc, type: type.getValue()); |
1114 | |
1115 | return emitError(loc, message: "cannot emit attribute: " ) << attr; |
1116 | } |
1117 | |
1118 | LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { |
1119 | assert(emittedExpressionPrecedence.empty() && |
1120 | "Expected precedence stack to be empty" ); |
1121 | Operation *rootOp = expressionOp.getRootOp(); |
1122 | |
1123 | emittedExpression = expressionOp; |
1124 | FailureOr<int> precedence = getOperatorPrecedence(operation: rootOp); |
1125 | if (failed(result: precedence)) |
1126 | return failure(); |
1127 | pushExpressionPrecedence(precedence: precedence.value()); |
1128 | |
1129 | if (failed(result: emitOperation(op&: *rootOp, /*trailingSemicolon=*/false))) |
1130 | return failure(); |
1131 | |
1132 | popExpressionPrecedence(); |
1133 | assert(emittedExpressionPrecedence.empty() && |
1134 | "Expected precedence stack to be empty" ); |
1135 | emittedExpression = nullptr; |
1136 | |
1137 | return success(); |
1138 | } |
1139 | |
1140 | LogicalResult CppEmitter::emitOperand(Value value) { |
1141 | if (isPartOfCurrentExpression(value)) { |
1142 | Operation *def = value.getDefiningOp(); |
1143 | assert(def && "Expected operand to be defined by an operation" ); |
1144 | FailureOr<int> precedence = getOperatorPrecedence(operation: def); |
1145 | if (failed(result: precedence)) |
1146 | return failure(); |
1147 | bool encloseInParenthesis = precedence.value() < getExpressionPrecedence(); |
1148 | if (encloseInParenthesis) { |
1149 | os << "(" ; |
1150 | pushExpressionPrecedence(precedence: lowestPrecedence()); |
1151 | } else |
1152 | pushExpressionPrecedence(precedence: precedence.value()); |
1153 | |
1154 | if (failed(result: emitOperation(op&: *def, /*trailingSemicolon=*/false))) |
1155 | return failure(); |
1156 | |
1157 | if (encloseInParenthesis) |
1158 | os << ")" ; |
1159 | |
1160 | popExpressionPrecedence(); |
1161 | return success(); |
1162 | } |
1163 | |
1164 | auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp()); |
1165 | if (expressionOp && shouldBeInlined(expressionOp)) |
1166 | return emitExpression(expressionOp); |
1167 | |
1168 | auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp()); |
1169 | if (!literalOp && !hasValueInScope(val: value)) |
1170 | return failure(); |
1171 | os << getOrCreateName(val: value); |
1172 | return success(); |
1173 | } |
1174 | |
1175 | LogicalResult CppEmitter::emitOperands(Operation &op) { |
1176 | return interleaveCommaWithError(c: op.getOperands(), os, eachFn: [&](Value operand) { |
1177 | // If an expression is being emitted, push lowest precedence as these |
1178 | // operands are either wrapped by parenthesis. |
1179 | if (getEmittedExpression()) |
1180 | pushExpressionPrecedence(precedence: lowestPrecedence()); |
1181 | if (failed(result: emitOperand(value: operand))) |
1182 | return failure(); |
1183 | if (getEmittedExpression()) |
1184 | popExpressionPrecedence(); |
1185 | return success(); |
1186 | }); |
1187 | } |
1188 | |
1189 | LogicalResult |
1190 | CppEmitter::emitOperandsAndAttributes(Operation &op, |
1191 | ArrayRef<StringRef> exclude) { |
1192 | if (failed(result: emitOperands(op))) |
1193 | return failure(); |
1194 | // Insert comma in between operands and non-filtered attributes if needed. |
1195 | if (op.getNumOperands() > 0) { |
1196 | for (NamedAttribute attr : op.getAttrs()) { |
1197 | if (!llvm::is_contained(exclude, attr.getName().strref())) { |
1198 | os << ", " ; |
1199 | break; |
1200 | } |
1201 | } |
1202 | } |
1203 | // Emit attributes. |
1204 | auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { |
1205 | if (llvm::is_contained(exclude, attr.getName().strref())) |
1206 | return success(); |
1207 | os << "/* " << attr.getName().getValue() << " */" ; |
1208 | if (failed(result: emitAttribute(loc: op.getLoc(), attr: attr.getValue()))) |
1209 | return failure(); |
1210 | return success(); |
1211 | }; |
1212 | return interleaveCommaWithError(c: op.getAttrs(), os, eachFn: emitNamedAttribute); |
1213 | } |
1214 | |
1215 | LogicalResult CppEmitter::emitVariableAssignment(OpResult result) { |
1216 | if (!hasValueInScope(val: result)) { |
1217 | return result.getDefiningOp()->emitOpError( |
1218 | message: "result variable for the operation has not been declared" ); |
1219 | } |
1220 | os << getOrCreateName(val: result) << " = " ; |
1221 | return success(); |
1222 | } |
1223 | |
1224 | LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, |
1225 | bool trailingSemicolon) { |
1226 | if (hasValueInScope(val: result)) { |
1227 | return result.getDefiningOp()->emitError( |
1228 | message: "result variable for the operation already declared" ); |
1229 | } |
1230 | if (failed(result: emitType(loc: result.getOwner()->getLoc(), type: result.getType()))) |
1231 | return failure(); |
1232 | os << " " << getOrCreateName(val: result); |
1233 | if (trailingSemicolon) |
1234 | os << ";\n" ; |
1235 | return success(); |
1236 | } |
1237 | |
1238 | LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { |
1239 | // If op is being emitted as part of an expression, bail out. |
1240 | if (getEmittedExpression()) |
1241 | return success(); |
1242 | |
1243 | switch (op.getNumResults()) { |
1244 | case 0: |
1245 | break; |
1246 | case 1: { |
1247 | OpResult result = op.getResult(idx: 0); |
1248 | if (shouldDeclareVariablesAtTop()) { |
1249 | if (failed(result: emitVariableAssignment(result))) |
1250 | return failure(); |
1251 | } else { |
1252 | if (failed(result: emitVariableDeclaration(result, /*trailingSemicolon=*/false))) |
1253 | return failure(); |
1254 | os << " = " ; |
1255 | } |
1256 | break; |
1257 | } |
1258 | default: |
1259 | if (!shouldDeclareVariablesAtTop()) { |
1260 | for (OpResult result : op.getResults()) { |
1261 | if (failed(result: emitVariableDeclaration(result, /*trailingSemicolon=*/true))) |
1262 | return failure(); |
1263 | } |
1264 | } |
1265 | os << "std::tie(" ; |
1266 | interleaveComma(c: op.getResults(), os, |
1267 | each_fn: [&](Value result) { os << getOrCreateName(val: result); }); |
1268 | os << ") = " ; |
1269 | } |
1270 | return success(); |
1271 | } |
1272 | |
1273 | LogicalResult CppEmitter::emitLabel(Block &block) { |
1274 | if (!hasBlockLabel(block)) |
1275 | return block.getParentOp()->emitError(message: "label for block not found" ); |
1276 | // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block |
1277 | // label instead of using `getOStream`. |
1278 | os.getOStream() << getOrCreateName(block) << ":\n" ; |
1279 | return success(); |
1280 | } |
1281 | |
1282 | LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { |
1283 | LogicalResult status = |
1284 | llvm::TypeSwitch<Operation *, LogicalResult>(&op) |
1285 | // Builtin ops. |
1286 | .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); }) |
1287 | // CF ops. |
1288 | .Case<cf::BranchOp, cf::CondBranchOp>( |
1289 | [&](auto op) { return printOperation(*this, op); }) |
1290 | // EmitC ops. |
1291 | .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp, |
1292 | emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp, |
1293 | emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp, |
1294 | emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp, |
1295 | emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp, |
1296 | emitc::SubOp, emitc::VariableOp, emitc::VerbatimOp>( |
1297 | [&](auto op) { return printOperation(*this, op); }) |
1298 | // Func ops. |
1299 | .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>( |
1300 | [&](auto op) { return printOperation(*this, op); }) |
1301 | // Arithmetic ops. |
1302 | .Case<arith::ConstantOp>( |
1303 | [&](auto op) { return printOperation(*this, op); }) |
1304 | .Case<emitc::LiteralOp>([&](auto op) { return success(); }) |
1305 | .Default([&](Operation *) { |
1306 | return op.emitOpError("unable to find printer for op" ); |
1307 | }); |
1308 | |
1309 | if (failed(result: status)) |
1310 | return failure(); |
1311 | |
1312 | if (isa<emitc::LiteralOp>(op)) |
1313 | return success(); |
1314 | |
1315 | if (getEmittedExpression() || |
1316 | (isa<emitc::ExpressionOp>(op) && |
1317 | shouldBeInlined(cast<emitc::ExpressionOp>(op)))) |
1318 | return success(); |
1319 | |
1320 | os << (trailingSemicolon ? ";\n" : "\n" ); |
1321 | |
1322 | return success(); |
1323 | } |
1324 | |
1325 | LogicalResult CppEmitter::emitType(Location loc, Type type) { |
1326 | if (auto iType = dyn_cast<IntegerType>(type)) { |
1327 | switch (iType.getWidth()) { |
1328 | case 1: |
1329 | return (os << "bool" ), success(); |
1330 | case 8: |
1331 | case 16: |
1332 | case 32: |
1333 | case 64: |
1334 | if (shouldMapToUnsigned(iType.getSignedness())) |
1335 | return (os << "uint" << iType.getWidth() << "_t" ), success(); |
1336 | else |
1337 | return (os << "int" << iType.getWidth() << "_t" ), success(); |
1338 | default: |
1339 | return emitError(loc, message: "cannot emit integer type " ) << type; |
1340 | } |
1341 | } |
1342 | if (auto fType = dyn_cast<FloatType>(Val&: type)) { |
1343 | switch (fType.getWidth()) { |
1344 | case 32: |
1345 | return (os << "float" ), success(); |
1346 | case 64: |
1347 | return (os << "double" ), success(); |
1348 | default: |
1349 | return emitError(loc, message: "cannot emit float type " ) << type; |
1350 | } |
1351 | } |
1352 | if (auto iType = dyn_cast<IndexType>(type)) |
1353 | return (os << "size_t" ), success(); |
1354 | if (auto tType = dyn_cast<TensorType>(Val&: type)) { |
1355 | if (!tType.hasRank()) |
1356 | return emitError(loc, message: "cannot emit unranked tensor type" ); |
1357 | if (!tType.hasStaticShape()) |
1358 | return emitError(loc, message: "cannot emit tensor type with non static shape" ); |
1359 | os << "Tensor<" ; |
1360 | if (failed(result: emitType(loc, type: tType.getElementType()))) |
1361 | return failure(); |
1362 | auto shape = tType.getShape(); |
1363 | for (auto dimSize : shape) { |
1364 | os << ", " ; |
1365 | os << dimSize; |
1366 | } |
1367 | os << ">" ; |
1368 | return success(); |
1369 | } |
1370 | if (auto tType = dyn_cast<TupleType>(type)) |
1371 | return emitTupleType(loc, types: tType.getTypes()); |
1372 | if (auto oType = dyn_cast<emitc::OpaqueType>(type)) { |
1373 | os << oType.getValue(); |
1374 | return success(); |
1375 | } |
1376 | if (auto pType = dyn_cast<emitc::PointerType>(type)) { |
1377 | if (failed(emitType(loc, type: pType.getPointee()))) |
1378 | return failure(); |
1379 | os << "*" ; |
1380 | return success(); |
1381 | } |
1382 | return emitError(loc, message: "cannot emit type " ) << type; |
1383 | } |
1384 | |
1385 | LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) { |
1386 | switch (types.size()) { |
1387 | case 0: |
1388 | os << "void" ; |
1389 | return success(); |
1390 | case 1: |
1391 | return emitType(loc, type: types.front()); |
1392 | default: |
1393 | return emitTupleType(loc, types); |
1394 | } |
1395 | } |
1396 | |
1397 | LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) { |
1398 | os << "std::tuple<" ; |
1399 | if (failed(result: interleaveCommaWithError( |
1400 | c: types, os, eachFn: [&](Type type) { return emitType(loc, type); }))) |
1401 | return failure(); |
1402 | os << ">" ; |
1403 | return success(); |
1404 | } |
1405 | |
1406 | LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, |
1407 | bool declareVariablesAtTop) { |
1408 | CppEmitter emitter(os, declareVariablesAtTop); |
1409 | return emitter.emitOperation(op&: *op, /*trailingSemicolon=*/false); |
1410 | } |
1411 | |