1//===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>
#include <utility>
28
29#define DEBUG_TYPE "translate-to-cpp"
30
31using namespace mlir;
32using namespace mlir::emitc;
33using llvm::formatv;
34
35/// Convenience functions to produce interleaved output with functions returning
36/// a LogicalResult. This is different than those in STLExtras as functions used
37/// on each element doesn't return a string.
38template <typename ForwardIterator, typename UnaryFunctor,
39 typename NullaryFunctor>
40inline LogicalResult
41interleaveWithError(ForwardIterator begin, ForwardIterator end,
42 UnaryFunctor eachFn, NullaryFunctor betweenFn) {
43 if (begin == end)
44 return success();
45 if (failed(eachFn(*begin)))
46 return failure();
47 ++begin;
48 for (; begin != end; ++begin) {
49 betweenFn();
50 if (failed(eachFn(*begin)))
51 return failure();
52 }
53 return success();
54}
55
56template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
57inline LogicalResult interleaveWithError(const Container &c,
58 UnaryFunctor eachFn,
59 NullaryFunctor betweenFn) {
60 return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
61}
62
63template <typename Container, typename UnaryFunctor>
64inline LogicalResult interleaveCommaWithError(const Container &c,
65 raw_ostream &os,
66 UnaryFunctor eachFn) {
67 return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
68}
69
70/// Return the precedence of a operator as an integer, higher values
71/// imply higher precedence.
72static FailureOr<int> getOperatorPrecedence(Operation *operation) {
73 return llvm::TypeSwitch<Operation *, FailureOr<int>>(operation)
74 .Case<emitc::AddOp>([&](auto op) { return 11; })
75 .Case<emitc::ApplyOp>([&](auto op) { return 13; })
76 .Case<emitc::CastOp>([&](auto op) { return 13; })
77 .Case<emitc::CmpOp>([&](auto op) -> FailureOr<int> {
78 switch (op.getPredicate()) {
79 case emitc::CmpPredicate::eq:
80 case emitc::CmpPredicate::ne:
81 return 8;
82 case emitc::CmpPredicate::lt:
83 case emitc::CmpPredicate::le:
84 case emitc::CmpPredicate::gt:
85 case emitc::CmpPredicate::ge:
86 return 9;
87 case emitc::CmpPredicate::three_way:
88 return 10;
89 }
90 return op->emitError("unsupported cmp predicate");
91 })
92 .Case<emitc::DivOp>([&](auto op) { return 12; })
93 .Case<emitc::MulOp>([&](auto op) { return 12; })
94 .Case<emitc::RemOp>([&](auto op) { return 12; })
95 .Case<emitc::SubOp>([&](auto op) { return 11; })
96 .Case<emitc::CallOpaqueOp>([&](auto op) { return 14; })
97 .Default([](auto op) { return op->emitError("unsupported operation"); });
98}
99
100namespace {
101/// Emitter that uses dialect specific emitters to emit C++ code.
102struct CppEmitter {
103 explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
104
105 /// Emits attribute or returns failure.
106 LogicalResult emitAttribute(Location loc, Attribute attr);
107
108 /// Emits operation 'op' with/without training semicolon or returns failure.
109 LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
110
111 /// Emits type 'type' or returns failure.
112 LogicalResult emitType(Location loc, Type type);
113
114 /// Emits array of types as a std::tuple of the emitted types.
115 /// - emits void for an empty array;
116 /// - emits the type of the only element for arrays of size one;
117 /// - emits a std::tuple otherwise;
118 LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
119
120 /// Emits array of types as a std::tuple of the emitted types independently of
121 /// the array size.
122 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
123
124 /// Emits an assignment for a variable which has been declared previously.
125 LogicalResult emitVariableAssignment(OpResult result);
126
127 /// Emits a variable declaration for a result of an operation.
128 LogicalResult emitVariableDeclaration(OpResult result,
129 bool trailingSemicolon);
130
131 /// Emits the variable declaration and assignment prefix for 'op'.
132 /// - emits separate variable followed by std::tie for multi-valued operation;
133 /// - emits single type followed by variable for single result;
134 /// - emits nothing if no value produced by op;
135 /// Emits final '=' operator where a type is produced. Returns failure if
136 /// any result type could not be converted.
137 LogicalResult emitAssignPrefix(Operation &op);
138
139 /// Emits a label for the block.
140 LogicalResult emitLabel(Block &block);
141
142 /// Emits the operands and atttributes of the operation. All operands are
143 /// emitted first and then all attributes in alphabetical order.
144 LogicalResult emitOperandsAndAttributes(Operation &op,
145 ArrayRef<StringRef> exclude = {});
146
147 /// Emits the operands of the operation. All operands are emitted in order.
148 LogicalResult emitOperands(Operation &op);
149
150 /// Emits value as an operands of an operation
151 LogicalResult emitOperand(Value value);
152
153 /// Emit an expression as a C expression.
154 LogicalResult emitExpression(ExpressionOp expressionOp);
155
156 /// Return the existing or a new name for a Value.
157 StringRef getOrCreateName(Value val);
158
159 /// Return the existing or a new label of a Block.
160 StringRef getOrCreateName(Block &block);
161
162 /// Whether to map an mlir integer to a unsigned integer in C++.
163 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
164
165 /// RAII helper function to manage entering/exiting C++ scopes.
166 struct Scope {
167 Scope(CppEmitter &emitter)
168 : valueMapperScope(emitter.valueMapper),
169 blockMapperScope(emitter.blockMapper), emitter(emitter) {
170 emitter.valueInScopeCount.push(x: emitter.valueInScopeCount.top());
171 emitter.labelInScopeCount.push(x: emitter.labelInScopeCount.top());
172 }
173 ~Scope() {
174 emitter.valueInScopeCount.pop();
175 emitter.labelInScopeCount.pop();
176 }
177
178 private:
179 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
180 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
181 CppEmitter &emitter;
182 };
183
184 /// Returns wether the Value is assigned to a C++ variable in the scope.
185 bool hasValueInScope(Value val);
186
187 // Returns whether a label is assigned to the block.
188 bool hasBlockLabel(Block &block);
189
190 /// Returns the output stream.
191 raw_indented_ostream &ostream() { return os; };
192
193 /// Returns if all variables for op results and basic block arguments need to
194 /// be declared at the beginning of a function.
195 bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
196
197 /// Get expression currently being emitted.
198 ExpressionOp getEmittedExpression() { return emittedExpression; }
199
200 /// Determine whether given value is part of the expression potentially being
201 /// emitted.
202 bool isPartOfCurrentExpression(Value value) {
203 if (!emittedExpression)
204 return false;
205 Operation *def = value.getDefiningOp();
206 if (!def)
207 return false;
208 auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
209 return operandExpression == emittedExpression;
210 };
211
212private:
213 using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
214 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
215
216 /// Output stream to emit to.
217 raw_indented_ostream os;
218
219 /// Boolean to enforce that all variables for op results and block
220 /// arguments are declared at the beginning of the function. This also
221 /// includes results from ops located in nested regions.
222 bool declareVariablesAtTop;
223
224 /// Map from value to name of C++ variable that contain the name.
225 ValueMapper valueMapper;
226
227 /// Map from block to name of C++ label.
228 BlockMapper blockMapper;
229
230 /// The number of values in the current scope. This is used to declare the
231 /// names of values in a scope.
232 std::stack<int64_t> valueInScopeCount;
233 std::stack<int64_t> labelInScopeCount;
234
235 /// State of the current expression being emitted.
236 ExpressionOp emittedExpression;
237 SmallVector<int> emittedExpressionPrecedence;
238
239 void pushExpressionPrecedence(int precedence) {
240 emittedExpressionPrecedence.push_back(Elt: precedence);
241 }
242 void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); }
243 static int lowestPrecedence() { return 0; }
244 int getExpressionPrecedence() {
245 if (emittedExpressionPrecedence.empty())
246 return lowestPrecedence();
247 return emittedExpressionPrecedence.back();
248 }
249};
250} // namespace
251
252/// Determine whether expression \p expressionOp should be emitted inline, i.e.
253/// as part of its user. This function recommends inlining of any expressions
254/// that can be inlined unless it is used by another expression, under the
255/// assumption that any expression fusion/re-materialization was taken care of
256/// by transformations run by the backend.
257static bool shouldBeInlined(ExpressionOp expressionOp) {
258 // Do not inline if expression is marked as such.
259 if (expressionOp.getDoNotInline())
260 return false;
261
262 // Do not inline expressions with side effects to prevent side-effect
263 // reordering.
264 if (expressionOp.hasSideEffects())
265 return false;
266
267 // Do not inline expressions with multiple uses.
268 Value result = expressionOp.getResult();
269 if (!result.hasOneUse())
270 return false;
271
272 // Do not inline expressions used by other expressions, as any desired
273 // expression folding was taken care of by transformations.
274 Operation *user = *result.getUsers().begin();
275 return !user->getParentOfType<ExpressionOp>();
276}
277
278static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
279 Attribute value) {
280 OpResult result = operation->getResult(idx: 0);
281
282 // Only emit an assignment as the variable was already declared when printing
283 // the FuncOp.
284 if (emitter.shouldDeclareVariablesAtTop()) {
285 // Skip the assignment if the emitc.constant has no value.
286 if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
287 if (oAttr.getValue().empty())
288 return success();
289 }
290
291 if (failed(result: emitter.emitVariableAssignment(result)))
292 return failure();
293 return emitter.emitAttribute(loc: operation->getLoc(), attr: value);
294 }
295
296 // Emit a variable declaration for an emitc.constant op without value.
297 if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(value)) {
298 if (oAttr.getValue().empty())
299 // The semicolon gets printed by the emitOperation function.
300 return emitter.emitVariableDeclaration(result,
301 /*trailingSemicolon=*/false);
302 }
303
304 // Emit a variable declaration.
305 if (failed(result: emitter.emitAssignPrefix(op&: *operation)))
306 return failure();
307 return emitter.emitAttribute(loc: operation->getLoc(), attr: value);
308}
309
310static LogicalResult printOperation(CppEmitter &emitter,
311 emitc::ConstantOp constantOp) {
312 Operation *operation = constantOp.getOperation();
313 Attribute value = constantOp.getValue();
314
315 return printConstantOp(emitter, operation, value);
316}
317
318static LogicalResult printOperation(CppEmitter &emitter,
319 emitc::VariableOp variableOp) {
320 Operation *operation = variableOp.getOperation();
321 Attribute value = variableOp.getValue();
322
323 return printConstantOp(emitter, operation, value);
324}
325
326static LogicalResult printOperation(CppEmitter &emitter,
327 arith::ConstantOp constantOp) {
328 Operation *operation = constantOp.getOperation();
329 Attribute value = constantOp.getValue();
330
331 return printConstantOp(emitter, operation, value);
332}
333
334static LogicalResult printOperation(CppEmitter &emitter,
335 func::ConstantOp constantOp) {
336 Operation *operation = constantOp.getOperation();
337 Attribute value = constantOp.getValueAttr();
338
339 return printConstantOp(emitter, operation, value);
340}
341
342static LogicalResult printOperation(CppEmitter &emitter,
343 emitc::AssignOp assignOp) {
344 auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
345 OpResult result = variableOp->getResult(0);
346
347 if (failed(result: emitter.emitVariableAssignment(result)))
348 return failure();
349
350 return emitter.emitOperand(value: assignOp.getValue());
351}
352
353static LogicalResult printBinaryOperation(CppEmitter &emitter,
354 Operation *operation,
355 StringRef binaryOperator) {
356 raw_ostream &os = emitter.ostream();
357
358 if (failed(result: emitter.emitAssignPrefix(op&: *operation)))
359 return failure();
360
361 if (failed(result: emitter.emitOperand(value: operation->getOperand(idx: 0))))
362 return failure();
363
364 os << " " << binaryOperator << " ";
365
366 if (failed(result: emitter.emitOperand(value: operation->getOperand(idx: 1))))
367 return failure();
368
369 return success();
370}
371
372static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
373 Operation *operation = addOp.getOperation();
374
375 return printBinaryOperation(emitter, operation, binaryOperator: "+");
376}
377
378static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
379 Operation *operation = divOp.getOperation();
380
381 return printBinaryOperation(emitter, operation, binaryOperator: "/");
382}
383
384static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
385 Operation *operation = mulOp.getOperation();
386
387 return printBinaryOperation(emitter, operation, binaryOperator: "*");
388}
389
390static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
391 Operation *operation = remOp.getOperation();
392
393 return printBinaryOperation(emitter, operation, binaryOperator: "%");
394}
395
396static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
397 Operation *operation = subOp.getOperation();
398
399 return printBinaryOperation(emitter, operation, binaryOperator: "-");
400}
401
402static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
403 Operation *operation = cmpOp.getOperation();
404
405 StringRef binaryOperator;
406
407 switch (cmpOp.getPredicate()) {
408 case emitc::CmpPredicate::eq:
409 binaryOperator = "==";
410 break;
411 case emitc::CmpPredicate::ne:
412 binaryOperator = "!=";
413 break;
414 case emitc::CmpPredicate::lt:
415 binaryOperator = "<";
416 break;
417 case emitc::CmpPredicate::le:
418 binaryOperator = "<=";
419 break;
420 case emitc::CmpPredicate::gt:
421 binaryOperator = ">";
422 break;
423 case emitc::CmpPredicate::ge:
424 binaryOperator = ">=";
425 break;
426 case emitc::CmpPredicate::three_way:
427 binaryOperator = "<=>";
428 break;
429 }
430
431 return printBinaryOperation(emitter, operation, binaryOperator);
432}
433
434static LogicalResult printOperation(CppEmitter &emitter,
435 emitc::VerbatimOp verbatimOp) {
436 raw_ostream &os = emitter.ostream();
437
438 os << verbatimOp.getValue();
439
440 return success();
441}
442
443static LogicalResult printOperation(CppEmitter &emitter,
444 cf::BranchOp branchOp) {
445 raw_ostream &os = emitter.ostream();
446 Block &successor = *branchOp.getSuccessor();
447
448 for (auto pair :
449 llvm::zip(branchOp.getOperands(), successor.getArguments())) {
450 Value &operand = std::get<0>(pair);
451 BlockArgument &argument = std::get<1>(pair);
452 os << emitter.getOrCreateName(argument) << " = "
453 << emitter.getOrCreateName(operand) << ";\n";
454 }
455
456 os << "goto ";
457 if (!(emitter.hasBlockLabel(block&: successor)))
458 return branchOp.emitOpError("unable to find label for successor block");
459 os << emitter.getOrCreateName(block&: successor);
460 return success();
461}
462
463static LogicalResult printOperation(CppEmitter &emitter,
464 cf::CondBranchOp condBranchOp) {
465 raw_indented_ostream &os = emitter.ostream();
466 Block &trueSuccessor = *condBranchOp.getTrueDest();
467 Block &falseSuccessor = *condBranchOp.getFalseDest();
468
469 os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
470 << ") {\n";
471
472 os.indent();
473
474 // If condition is true.
475 for (auto pair : llvm::zip(condBranchOp.getTrueOperands(),
476 trueSuccessor.getArguments())) {
477 Value &operand = std::get<0>(pair);
478 BlockArgument &argument = std::get<1>(pair);
479 os << emitter.getOrCreateName(argument) << " = "
480 << emitter.getOrCreateName(operand) << ";\n";
481 }
482
483 os << "goto ";
484 if (!(emitter.hasBlockLabel(block&: trueSuccessor))) {
485 return condBranchOp.emitOpError("unable to find label for successor block");
486 }
487 os << emitter.getOrCreateName(block&: trueSuccessor) << ";\n";
488 os.unindent() << "} else {\n";
489 os.indent();
490 // If condition is false.
491 for (auto pair : llvm::zip(condBranchOp.getFalseOperands(),
492 falseSuccessor.getArguments())) {
493 Value &operand = std::get<0>(pair);
494 BlockArgument &argument = std::get<1>(pair);
495 os << emitter.getOrCreateName(argument) << " = "
496 << emitter.getOrCreateName(operand) << ";\n";
497 }
498
499 os << "goto ";
500 if (!(emitter.hasBlockLabel(block&: falseSuccessor))) {
501 return condBranchOp.emitOpError()
502 << "unable to find label for successor block";
503 }
504 os << emitter.getOrCreateName(block&: falseSuccessor) << ";\n";
505 os.unindent() << "}";
506 return success();
507}
508
509static LogicalResult printCallOperation(CppEmitter &emitter, Operation *callOp,
510 StringRef callee) {
511 if (failed(result: emitter.emitAssignPrefix(op&: *callOp)))
512 return failure();
513
514 raw_ostream &os = emitter.ostream();
515 os << callee << "(";
516 if (failed(result: emitter.emitOperands(op&: *callOp)))
517 return failure();
518 os << ")";
519 return success();
520}
521
522static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
523 Operation *operation = callOp.getOperation();
524 StringRef callee = callOp.getCallee();
525
526 return printCallOperation(emitter, callOp: operation, callee);
527}
528
529static LogicalResult printOperation(CppEmitter &emitter, emitc::CallOp callOp) {
530 Operation *operation = callOp.getOperation();
531 StringRef callee = callOp.getCallee();
532
533 return printCallOperation(emitter, callOp: operation, callee);
534}
535
536static LogicalResult printOperation(CppEmitter &emitter,
537 emitc::CallOpaqueOp callOpaqueOp) {
538 raw_ostream &os = emitter.ostream();
539 Operation &op = *callOpaqueOp.getOperation();
540
541 if (failed(result: emitter.emitAssignPrefix(op)))
542 return failure();
543 os << callOpaqueOp.getCallee();
544
545 auto emitArgs = [&](Attribute attr) -> LogicalResult {
546 if (auto t = dyn_cast<IntegerAttr>(attr)) {
547 // Index attributes are treated specially as operand index.
548 if (t.getType().isIndex()) {
549 int64_t idx = t.getInt();
550 Value operand = op.getOperand(idx);
551 auto literalDef =
552 dyn_cast_if_present<LiteralOp>(operand.getDefiningOp());
553 if (!literalDef && !emitter.hasValueInScope(val: operand))
554 return op.emitOpError(message: "operand ")
555 << idx << "'s value not defined in scope";
556 os << emitter.getOrCreateName(val: operand);
557 return success();
558 }
559 }
560 if (failed(result: emitter.emitAttribute(loc: op.getLoc(), attr)))
561 return failure();
562
563 return success();
564 };
565
566 if (callOpaqueOp.getTemplateArgs()) {
567 os << "<";
568 if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
569 emitArgs)))
570 return failure();
571 os << ">";
572 }
573
574 os << "(";
575
576 LogicalResult emittedArgs =
577 callOpaqueOp.getArgs()
578 ? interleaveCommaWithError(*callOpaqueOp.getArgs(), os, emitArgs)
579 : emitter.emitOperands(op);
580 if (failed(result: emittedArgs))
581 return failure();
582 os << ")";
583 return success();
584}
585
586static LogicalResult printOperation(CppEmitter &emitter,
587 emitc::ApplyOp applyOp) {
588 raw_ostream &os = emitter.ostream();
589 Operation &op = *applyOp.getOperation();
590
591 if (failed(result: emitter.emitAssignPrefix(op)))
592 return failure();
593 os << applyOp.getApplicableOperator();
594 os << emitter.getOrCreateName(applyOp.getOperand());
595
596 return success();
597}
598
599static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
600 raw_ostream &os = emitter.ostream();
601 Operation &op = *castOp.getOperation();
602
603 if (failed(result: emitter.emitAssignPrefix(op)))
604 return failure();
605 os << "(";
606 if (failed(result: emitter.emitType(loc: op.getLoc(), type: op.getResult(idx: 0).getType())))
607 return failure();
608 os << ") ";
609 return emitter.emitOperand(value: castOp.getOperand());
610}
611
612static LogicalResult printOperation(CppEmitter &emitter,
613 emitc::ExpressionOp expressionOp) {
614 if (shouldBeInlined(expressionOp))
615 return success();
616
617 Operation &op = *expressionOp.getOperation();
618
619 if (failed(result: emitter.emitAssignPrefix(op)))
620 return failure();
621
622 return emitter.emitExpression(expressionOp);
623}
624
625static LogicalResult printOperation(CppEmitter &emitter,
626 emitc::IncludeOp includeOp) {
627 raw_ostream &os = emitter.ostream();
628
629 os << "#include ";
630 if (includeOp.getIsStandardInclude())
631 os << "<" << includeOp.getInclude() << ">";
632 else
633 os << "\"" << includeOp.getInclude() << "\"";
634
635 return success();
636}
637
638static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
639
640 raw_indented_ostream &os = emitter.ostream();
641
642 // Utility function to determine whether a value is an expression that will be
643 // inlined, and as such should be wrapped in parentheses in order to guarantee
644 // its precedence and associativity.
645 auto requiresParentheses = [&](Value value) {
646 auto expressionOp =
647 dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
648 if (!expressionOp)
649 return false;
650 return shouldBeInlined(expressionOp);
651 };
652
653 os << "for (";
654 if (failed(
655 emitter.emitType(loc: forOp.getLoc(), type: forOp.getInductionVar().getType())))
656 return failure();
657 os << " ";
658 os << emitter.getOrCreateName(forOp.getInductionVar());
659 os << " = ";
660 if (failed(emitter.emitOperand(value: forOp.getLowerBound())))
661 return failure();
662 os << "; ";
663 os << emitter.getOrCreateName(forOp.getInductionVar());
664 os << " < ";
665 Value upperBound = forOp.getUpperBound();
666 bool upperBoundRequiresParentheses = requiresParentheses(upperBound);
667 if (upperBoundRequiresParentheses)
668 os << "(";
669 if (failed(result: emitter.emitOperand(value: upperBound)))
670 return failure();
671 if (upperBoundRequiresParentheses)
672 os << ")";
673 os << "; ";
674 os << emitter.getOrCreateName(forOp.getInductionVar());
675 os << " += ";
676 if (failed(emitter.emitOperand(value: forOp.getStep())))
677 return failure();
678 os << ") {\n";
679 os.indent();
680
681 Region &forRegion = forOp.getRegion();
682 auto regionOps = forRegion.getOps();
683
684 // We skip the trailing yield op.
685 for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
686 if (failed(emitter.emitOperation(op&: *it, /*trailingSemicolon=*/true)))
687 return failure();
688 }
689
690 os.unindent() << "}";
691
692 return success();
693}
694
695static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
696 raw_indented_ostream &os = emitter.ostream();
697
698 // Helper function to emit all ops except the last one, expected to be
699 // emitc::yield.
700 auto emitAllExceptLast = [&emitter](Region &region) {
701 Region::OpIterator it = region.op_begin(), end = region.op_end();
702 for (; std::next(x: it) != end; ++it) {
703 if (failed(result: emitter.emitOperation(op&: *it, /*trailingSemicolon=*/true)))
704 return failure();
705 }
706 assert(isa<emitc::YieldOp>(*it) &&
707 "Expected last operation in the region to be emitc::yield");
708 return success();
709 };
710
711 os << "if (";
712 if (failed(emitter.emitOperand(value: ifOp.getCondition())))
713 return failure();
714 os << ") {\n";
715 os.indent();
716 if (failed(emitAllExceptLast(ifOp.getThenRegion())))
717 return failure();
718 os.unindent() << "}";
719
720 Region &elseRegion = ifOp.getElseRegion();
721 if (!elseRegion.empty()) {
722 os << " else {\n";
723 os.indent();
724 if (failed(result: emitAllExceptLast(elseRegion)))
725 return failure();
726 os.unindent() << "}";
727 }
728
729 return success();
730}
731
732static LogicalResult printOperation(CppEmitter &emitter,
733 func::ReturnOp returnOp) {
734 raw_ostream &os = emitter.ostream();
735 os << "return";
736 switch (returnOp.getNumOperands()) {
737 case 0:
738 return success();
739 case 1:
740 os << " ";
741 if (failed(emitter.emitOperand(value: returnOp.getOperand(0))))
742 return failure();
743 return success();
744 default:
745 os << " std::make_tuple(";
746 if (failed(emitter.emitOperandsAndAttributes(op&: *returnOp.getOperation())))
747 return failure();
748 os << ")";
749 return success();
750 }
751}
752
753static LogicalResult printOperation(CppEmitter &emitter,
754 emitc::ReturnOp returnOp) {
755 raw_ostream &os = emitter.ostream();
756 os << "return";
757 if (returnOp.getNumOperands() == 0)
758 return success();
759
760 os << " ";
761 if (failed(emitter.emitOperand(value: returnOp.getOperand())))
762 return failure();
763 return success();
764}
765
766static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
767 CppEmitter::Scope scope(emitter);
768
769 for (Operation &op : moduleOp) {
770 if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
771 return failure();
772 }
773 return success();
774}
775
776static LogicalResult printFunctionArgs(CppEmitter &emitter,
777 Operation *functionOp,
778 ArrayRef<Type> arguments) {
779 raw_indented_ostream &os = emitter.ostream();
780
781 return (
782 interleaveCommaWithError(c: arguments, os, eachFn: [&](Type arg) -> LogicalResult {
783 return emitter.emitType(loc: functionOp->getLoc(), type: arg);
784 }));
785}
786
787static LogicalResult printFunctionArgs(CppEmitter &emitter,
788 Operation *functionOp,
789 Region::BlockArgListType arguments) {
790 raw_indented_ostream &os = emitter.ostream();
791
792 return (interleaveCommaWithError(
793 c: arguments, os, eachFn: [&](BlockArgument arg) -> LogicalResult {
794 if (failed(result: emitter.emitType(loc: functionOp->getLoc(), type: arg.getType())))
795 return failure();
796 os << " " << emitter.getOrCreateName(val: arg);
797 return success();
798 }));
799}
800
801static LogicalResult printFunctionBody(CppEmitter &emitter,
802 Operation *functionOp,
803 Region::BlockListType &blocks) {
804 raw_indented_ostream &os = emitter.ostream();
805 os.indent();
806
807 if (emitter.shouldDeclareVariablesAtTop()) {
808 // Declare all variables that hold op results including those from nested
809 // regions.
810 WalkResult result =
811 functionOp->walk<WalkOrder::PreOrder>(callback: [&](Operation *op) -> WalkResult {
812 if (isa<emitc::LiteralOp>(op) ||
813 isa<emitc::ExpressionOp>(op->getParentOp()) ||
814 (isa<emitc::ExpressionOp>(op) &&
815 shouldBeInlined(cast<emitc::ExpressionOp>(op))))
816 return WalkResult::skip();
817 for (OpResult result : op->getResults()) {
818 if (failed(result: emitter.emitVariableDeclaration(
819 result, /*trailingSemicolon=*/true))) {
820 return WalkResult(
821 op->emitError(message: "unable to declare result variable for op"));
822 }
823 }
824 return WalkResult::advance();
825 });
826 if (result.wasInterrupted())
827 return failure();
828 }
829
830 // Create label names for basic blocks.
831 for (Block &block : blocks) {
832 emitter.getOrCreateName(block);
833 }
834
835 // Declare variables for basic block arguments.
836 for (Block &block : llvm::drop_begin(RangeOrContainer&: blocks)) {
837 for (BlockArgument &arg : block.getArguments()) {
838 if (emitter.hasValueInScope(val: arg))
839 return functionOp->emitOpError(message: " block argument #")
840 << arg.getArgNumber() << " is out of scope";
841 if (failed(
842 result: emitter.emitType(loc: block.getParentOp()->getLoc(), type: arg.getType()))) {
843 return failure();
844 }
845 os << " " << emitter.getOrCreateName(val: arg) << ";\n";
846 }
847 }
848
849 for (Block &block : blocks) {
850 // Only print a label if the block has predecessors.
851 if (!block.hasNoPredecessors()) {
852 if (failed(result: emitter.emitLabel(block)))
853 return failure();
854 }
855 for (Operation &op : block.getOperations()) {
856 // When generating code for an emitc.if or cf.cond_br op no semicolon
857 // needs to be printed after the closing brace.
858 // When generating code for an emitc.for and emitc.verbatim op, printing a
859 // trailing semicolon is handled within the printOperation function.
860 bool trailingSemicolon =
861 !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
862 emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op);
863
864 if (failed(result: emitter.emitOperation(
865 op, /*trailingSemicolon=*/trailingSemicolon)))
866 return failure();
867 }
868 }
869
870 os.unindent();
871
872 return success();
873}
874
875static LogicalResult printOperation(CppEmitter &emitter,
876 func::FuncOp functionOp) {
877 // We need to declare variables at top if the function has multiple blocks.
878 if (!emitter.shouldDeclareVariablesAtTop() &&
879 functionOp.getBlocks().size() > 1) {
880 return functionOp.emitOpError(
881 "with multiple blocks needs variables declared at top");
882 }
883
884 CppEmitter::Scope scope(emitter);
885 raw_indented_ostream &os = emitter.ostream();
886 if (failed(emitter.emitTypes(loc: functionOp.getLoc(),
887 types: functionOp.getFunctionType().getResults())))
888 return failure();
889 os << " " << functionOp.getName();
890
891 os << "(";
892 Operation *operation = functionOp.getOperation();
893 if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
894 return failure();
895 os << ") {\n";
896 if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
897 return failure();
898 os << "}\n";
899
900 return success();
901}
902
903static LogicalResult printOperation(CppEmitter &emitter,
904 emitc::FuncOp functionOp) {
905 // We need to declare variables at top if the function has multiple blocks.
906 if (!emitter.shouldDeclareVariablesAtTop() &&
907 functionOp.getBlocks().size() > 1) {
908 return functionOp.emitOpError(
909 "with multiple blocks needs variables declared at top");
910 }
911
912 CppEmitter::Scope scope(emitter);
913 raw_indented_ostream &os = emitter.ostream();
914 if (functionOp.getSpecifiers()) {
915 for (Attribute specifier : functionOp.getSpecifiersAttr()) {
916 os << cast<StringAttr>(specifier).str() << " ";
917 }
918 }
919
920 if (failed(emitter.emitTypes(loc: functionOp.getLoc(),
921 types: functionOp.getFunctionType().getResults())))
922 return failure();
923 os << " " << functionOp.getName();
924
925 os << "(";
926 Operation *operation = functionOp.getOperation();
927 if (functionOp.isExternal()) {
928 if (failed(printFunctionArgs(emitter, operation,
929 functionOp.getArgumentTypes())))
930 return failure();
931 os << ");";
932 return success();
933 }
934 if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
935 return failure();
936 os << ") {\n";
937 if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
938 return failure();
939 os << "}\n";
940
941 return success();
942}
943
944static LogicalResult printOperation(CppEmitter &emitter,
945 DeclareFuncOp declareFuncOp) {
946 CppEmitter::Scope scope(emitter);
947 raw_indented_ostream &os = emitter.ostream();
948
949 auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
950 declareFuncOp, declareFuncOp.getSymNameAttr());
951
952 if (!functionOp)
953 return failure();
954
955 if (functionOp.getSpecifiers()) {
956 for (Attribute specifier : functionOp.getSpecifiersAttr()) {
957 os << cast<StringAttr>(specifier).str() << " ";
958 }
959 }
960
961 if (failed(emitter.emitTypes(loc: functionOp.getLoc(),
962 types: functionOp.getFunctionType().getResults())))
963 return failure();
964 os << " " << functionOp.getName();
965
966 os << "(";
967 Operation *operation = functionOp.getOperation();
968 if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
969 return failure();
970 os << ");";
971
972 return success();
973}
974
975CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
976 : os(os), declareVariablesAtTop(declareVariablesAtTop) {
977 valueInScopeCount.push(x: 0);
978 labelInScopeCount.push(x: 0);
979}
980
981/// Return the existing or a new name for a Value.
982StringRef CppEmitter::getOrCreateName(Value val) {
983 if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
984 return literal.getValue();
985 if (!valueMapper.count(Key: val))
986 valueMapper.insert(Key: val, Val: formatv(Fmt: "v{0}", Vals&: ++valueInScopeCount.top()));
987 return *valueMapper.begin(Key: val);
988}
989
990/// Return the existing or a new label for a Block.
991StringRef CppEmitter::getOrCreateName(Block &block) {
992 if (!blockMapper.count(Key: &block))
993 blockMapper.insert(Key: &block, Val: formatv(Fmt: "label{0}", Vals&: ++labelInScopeCount.top()));
994 return *blockMapper.begin(Key: &block);
995}
996
997bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
998 switch (val) {
999 case IntegerType::Signless:
1000 return false;
1001 case IntegerType::Signed:
1002 return false;
1003 case IntegerType::Unsigned:
1004 return true;
1005 }
1006 llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
1007}
1008
1009bool CppEmitter::hasValueInScope(Value val) { return valueMapper.count(Key: val); }
1010
1011bool CppEmitter::hasBlockLabel(Block &block) {
1012 return blockMapper.count(Key: &block);
1013}
1014
1015LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
1016 auto printInt = [&](const APInt &val, bool isUnsigned) {
1017 if (val.getBitWidth() == 1) {
1018 if (val.getBoolValue())
1019 os << "true";
1020 else
1021 os << "false";
1022 } else {
1023 SmallString<128> strValue;
1024 val.toString(Str&: strValue, Radix: 10, Signed: !isUnsigned, formatAsCLiteral: false);
1025 os << strValue;
1026 }
1027 };
1028
1029 auto printFloat = [&](const APFloat &val) {
1030 if (val.isFinite()) {
1031 SmallString<128> strValue;
1032 // Use default values of toString except don't truncate zeros.
1033 val.toString(Str&: strValue, FormatPrecision: 0, FormatMaxPadding: 0, TruncateZero: false);
1034 switch (llvm::APFloatBase::SemanticsToEnum(Sem: val.getSemantics())) {
1035 case llvm::APFloatBase::S_IEEEsingle:
1036 os << "(float)";
1037 break;
1038 case llvm::APFloatBase::S_IEEEdouble:
1039 os << "(double)";
1040 break;
1041 default:
1042 break;
1043 };
1044 os << strValue;
1045 } else if (val.isNaN()) {
1046 os << "NAN";
1047 } else if (val.isInfinity()) {
1048 if (val.isNegative())
1049 os << "-";
1050 os << "INFINITY";
1051 }
1052 };
1053
1054 // Print floating point attributes.
1055 if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
1056 printFloat(fAttr.getValue());
1057 return success();
1058 }
1059 if (auto dense = dyn_cast<DenseFPElementsAttr>(attr)) {
1060 os << '{';
1061 interleaveComma(c: dense, os, each_fn: [&](const APFloat &val) { printFloat(val); });
1062 os << '}';
1063 return success();
1064 }
1065
1066 // Print integer attributes.
1067 if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
1068 if (auto iType = dyn_cast<IntegerType>(iAttr.getType())) {
1069 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
1070 return success();
1071 }
1072 if (auto iType = dyn_cast<IndexType>(iAttr.getType())) {
1073 printInt(iAttr.getValue(), false);
1074 return success();
1075 }
1076 }
1077 if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
1078 if (auto iType = dyn_cast<IntegerType>(
1079 cast<TensorType>(dense.getType()).getElementType())) {
1080 os << '{';
1081 interleaveComma(c: dense, os, each_fn: [&](const APInt &val) {
1082 printInt(val, shouldMapToUnsigned(iType.getSignedness()));
1083 });
1084 os << '}';
1085 return success();
1086 }
1087 if (auto iType = dyn_cast<IndexType>(
1088 cast<TensorType>(dense.getType()).getElementType())) {
1089 os << '{';
1090 interleaveComma(c: dense, os,
1091 each_fn: [&](const APInt &val) { printInt(val, false); });
1092 os << '}';
1093 return success();
1094 }
1095 }
1096
1097 // Print opaque attributes.
1098 if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
1099 os << oAttr.getValue();
1100 return success();
1101 }
1102
1103 // Print symbolic reference attributes.
1104 if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
1105 if (sAttr.getNestedReferences().size() > 1)
1106 return emitError(loc, message: "attribute has more than 1 nested reference");
1107 os << sAttr.getRootReference().getValue();
1108 return success();
1109 }
1110
1111 // Print type attributes.
1112 if (auto type = dyn_cast<TypeAttr>(attr))
1113 return emitType(loc, type: type.getValue());
1114
1115 return emitError(loc, message: "cannot emit attribute: ") << attr;
1116}
1117
1118LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
1119 assert(emittedExpressionPrecedence.empty() &&
1120 "Expected precedence stack to be empty");
1121 Operation *rootOp = expressionOp.getRootOp();
1122
1123 emittedExpression = expressionOp;
1124 FailureOr<int> precedence = getOperatorPrecedence(operation: rootOp);
1125 if (failed(result: precedence))
1126 return failure();
1127 pushExpressionPrecedence(precedence: precedence.value());
1128
1129 if (failed(result: emitOperation(op&: *rootOp, /*trailingSemicolon=*/false)))
1130 return failure();
1131
1132 popExpressionPrecedence();
1133 assert(emittedExpressionPrecedence.empty() &&
1134 "Expected precedence stack to be empty");
1135 emittedExpression = nullptr;
1136
1137 return success();
1138}
1139
1140LogicalResult CppEmitter::emitOperand(Value value) {
1141 if (isPartOfCurrentExpression(value)) {
1142 Operation *def = value.getDefiningOp();
1143 assert(def && "Expected operand to be defined by an operation");
1144 FailureOr<int> precedence = getOperatorPrecedence(operation: def);
1145 if (failed(result: precedence))
1146 return failure();
1147 bool encloseInParenthesis = precedence.value() < getExpressionPrecedence();
1148 if (encloseInParenthesis) {
1149 os << "(";
1150 pushExpressionPrecedence(precedence: lowestPrecedence());
1151 } else
1152 pushExpressionPrecedence(precedence: precedence.value());
1153
1154 if (failed(result: emitOperation(op&: *def, /*trailingSemicolon=*/false)))
1155 return failure();
1156
1157 if (encloseInParenthesis)
1158 os << ")";
1159
1160 popExpressionPrecedence();
1161 return success();
1162 }
1163
1164 auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
1165 if (expressionOp && shouldBeInlined(expressionOp))
1166 return emitExpression(expressionOp);
1167
1168 auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
1169 if (!literalOp && !hasValueInScope(val: value))
1170 return failure();
1171 os << getOrCreateName(val: value);
1172 return success();
1173}
1174
1175LogicalResult CppEmitter::emitOperands(Operation &op) {
1176 return interleaveCommaWithError(c: op.getOperands(), os, eachFn: [&](Value operand) {
1177 // If an expression is being emitted, push lowest precedence as these
1178 // operands are either wrapped by parenthesis.
1179 if (getEmittedExpression())
1180 pushExpressionPrecedence(precedence: lowestPrecedence());
1181 if (failed(result: emitOperand(value: operand)))
1182 return failure();
1183 if (getEmittedExpression())
1184 popExpressionPrecedence();
1185 return success();
1186 });
1187}
1188
1189LogicalResult
1190CppEmitter::emitOperandsAndAttributes(Operation &op,
1191 ArrayRef<StringRef> exclude) {
1192 if (failed(result: emitOperands(op)))
1193 return failure();
1194 // Insert comma in between operands and non-filtered attributes if needed.
1195 if (op.getNumOperands() > 0) {
1196 for (NamedAttribute attr : op.getAttrs()) {
1197 if (!llvm::is_contained(exclude, attr.getName().strref())) {
1198 os << ", ";
1199 break;
1200 }
1201 }
1202 }
1203 // Emit attributes.
1204 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
1205 if (llvm::is_contained(exclude, attr.getName().strref()))
1206 return success();
1207 os << "/* " << attr.getName().getValue() << " */";
1208 if (failed(result: emitAttribute(loc: op.getLoc(), attr: attr.getValue())))
1209 return failure();
1210 return success();
1211 };
1212 return interleaveCommaWithError(c: op.getAttrs(), os, eachFn: emitNamedAttribute);
1213}
1214
1215LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
1216 if (!hasValueInScope(val: result)) {
1217 return result.getDefiningOp()->emitOpError(
1218 message: "result variable for the operation has not been declared");
1219 }
1220 os << getOrCreateName(val: result) << " = ";
1221 return success();
1222}
1223
1224LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
1225 bool trailingSemicolon) {
1226 if (hasValueInScope(val: result)) {
1227 return result.getDefiningOp()->emitError(
1228 message: "result variable for the operation already declared");
1229 }
1230 if (failed(result: emitType(loc: result.getOwner()->getLoc(), type: result.getType())))
1231 return failure();
1232 os << " " << getOrCreateName(val: result);
1233 if (trailingSemicolon)
1234 os << ";\n";
1235 return success();
1236}
1237
1238LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
1239 // If op is being emitted as part of an expression, bail out.
1240 if (getEmittedExpression())
1241 return success();
1242
1243 switch (op.getNumResults()) {
1244 case 0:
1245 break;
1246 case 1: {
1247 OpResult result = op.getResult(idx: 0);
1248 if (shouldDeclareVariablesAtTop()) {
1249 if (failed(result: emitVariableAssignment(result)))
1250 return failure();
1251 } else {
1252 if (failed(result: emitVariableDeclaration(result, /*trailingSemicolon=*/false)))
1253 return failure();
1254 os << " = ";
1255 }
1256 break;
1257 }
1258 default:
1259 if (!shouldDeclareVariablesAtTop()) {
1260 for (OpResult result : op.getResults()) {
1261 if (failed(result: emitVariableDeclaration(result, /*trailingSemicolon=*/true)))
1262 return failure();
1263 }
1264 }
1265 os << "std::tie(";
1266 interleaveComma(c: op.getResults(), os,
1267 each_fn: [&](Value result) { os << getOrCreateName(val: result); });
1268 os << ") = ";
1269 }
1270 return success();
1271}
1272
1273LogicalResult CppEmitter::emitLabel(Block &block) {
1274 if (!hasBlockLabel(block))
1275 return block.getParentOp()->emitError(message: "label for block not found");
1276 // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
1277 // label instead of using `getOStream`.
1278 os.getOStream() << getOrCreateName(block) << ":\n";
1279 return success();
1280}
1281
1282LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
1283 LogicalResult status =
1284 llvm::TypeSwitch<Operation *, LogicalResult>(&op)
1285 // Builtin ops.
1286 .Case<ModuleOp>([&](auto op) { return printOperation(*this, op); })
1287 // CF ops.
1288 .Case<cf::BranchOp, cf::CondBranchOp>(
1289 [&](auto op) { return printOperation(*this, op); })
1290 // EmitC ops.
1291 .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
1292 emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1293 emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
1294 emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
1295 emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1296 emitc::SubOp, emitc::VariableOp, emitc::VerbatimOp>(
1297 [&](auto op) { return printOperation(*this, op); })
1298 // Func ops.
1299 .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
1300 [&](auto op) { return printOperation(*this, op); })
1301 // Arithmetic ops.
1302 .Case<arith::ConstantOp>(
1303 [&](auto op) { return printOperation(*this, op); })
1304 .Case<emitc::LiteralOp>([&](auto op) { return success(); })
1305 .Default([&](Operation *) {
1306 return op.emitOpError("unable to find printer for op");
1307 });
1308
1309 if (failed(result: status))
1310 return failure();
1311
1312 if (isa<emitc::LiteralOp>(op))
1313 return success();
1314
1315 if (getEmittedExpression() ||
1316 (isa<emitc::ExpressionOp>(op) &&
1317 shouldBeInlined(cast<emitc::ExpressionOp>(op))))
1318 return success();
1319
1320 os << (trailingSemicolon ? ";\n" : "\n");
1321
1322 return success();
1323}
1324
1325LogicalResult CppEmitter::emitType(Location loc, Type type) {
1326 if (auto iType = dyn_cast<IntegerType>(type)) {
1327 switch (iType.getWidth()) {
1328 case 1:
1329 return (os << "bool"), success();
1330 case 8:
1331 case 16:
1332 case 32:
1333 case 64:
1334 if (shouldMapToUnsigned(iType.getSignedness()))
1335 return (os << "uint" << iType.getWidth() << "_t"), success();
1336 else
1337 return (os << "int" << iType.getWidth() << "_t"), success();
1338 default:
1339 return emitError(loc, message: "cannot emit integer type ") << type;
1340 }
1341 }
1342 if (auto fType = dyn_cast<FloatType>(Val&: type)) {
1343 switch (fType.getWidth()) {
1344 case 32:
1345 return (os << "float"), success();
1346 case 64:
1347 return (os << "double"), success();
1348 default:
1349 return emitError(loc, message: "cannot emit float type ") << type;
1350 }
1351 }
1352 if (auto iType = dyn_cast<IndexType>(type))
1353 return (os << "size_t"), success();
1354 if (auto tType = dyn_cast<TensorType>(Val&: type)) {
1355 if (!tType.hasRank())
1356 return emitError(loc, message: "cannot emit unranked tensor type");
1357 if (!tType.hasStaticShape())
1358 return emitError(loc, message: "cannot emit tensor type with non static shape");
1359 os << "Tensor<";
1360 if (failed(result: emitType(loc, type: tType.getElementType())))
1361 return failure();
1362 auto shape = tType.getShape();
1363 for (auto dimSize : shape) {
1364 os << ", ";
1365 os << dimSize;
1366 }
1367 os << ">";
1368 return success();
1369 }
1370 if (auto tType = dyn_cast<TupleType>(type))
1371 return emitTupleType(loc, types: tType.getTypes());
1372 if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1373 os << oType.getValue();
1374 return success();
1375 }
1376 if (auto pType = dyn_cast<emitc::PointerType>(type)) {
1377 if (failed(emitType(loc, type: pType.getPointee())))
1378 return failure();
1379 os << "*";
1380 return success();
1381 }
1382 return emitError(loc, message: "cannot emit type ") << type;
1383}
1384
1385LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
1386 switch (types.size()) {
1387 case 0:
1388 os << "void";
1389 return success();
1390 case 1:
1391 return emitType(loc, type: types.front());
1392 default:
1393 return emitTupleType(loc, types);
1394 }
1395}
1396
1397LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1398 os << "std::tuple<";
1399 if (failed(result: interleaveCommaWithError(
1400 c: types, os, eachFn: [&](Type type) { return emitType(loc, type); })))
1401 return failure();
1402 os << ">";
1403 return success();
1404}
1405
1406LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
1407 bool declareVariablesAtTop) {
1408 CppEmitter emitter(os, declareVariablesAtTop);
1409 return emitter.emitOperation(op&: *op, /*trailingSemicolon=*/false);
1410}
1411

source code of mlir/lib/Target/Cpp/TranslateToCpp.cpp