1//===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the MLIR AsmPrinter class, which is used to implement
10// the various print() methods on the core IR objects.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/IR/AffineExpr.h"
15#include "mlir/IR/AffineMap.h"
16#include "mlir/IR/AsmState.h"
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/BuiltinDialect.h"
21#include "mlir/IR/BuiltinTypeInterfaces.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/Dialect.h"
24#include "mlir/IR/DialectImplementation.h"
25#include "mlir/IR/DialectResourceBlobManager.h"
26#include "mlir/IR/IntegerSet.h"
27#include "mlir/IR/MLIRContext.h"
28#include "mlir/IR/OpImplementation.h"
29#include "mlir/IR/Operation.h"
30#include "mlir/IR/Verifier.h"
31#include "llvm/ADT/APFloat.h"
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/DenseMap.h"
34#include "llvm/ADT/MapVector.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/ScopeExit.h"
37#include "llvm/ADT/ScopedHashTable.h"
38#include "llvm/ADT/SetVector.h"
39#include "llvm/ADT/SmallString.h"
40#include "llvm/ADT/StringExtras.h"
41#include "llvm/ADT/StringSet.h"
42#include "llvm/ADT/TypeSwitch.h"
43#include "llvm/Support/CommandLine.h"
44#include "llvm/Support/Debug.h"
45#include "llvm/Support/Endian.h"
46#include "llvm/Support/Regex.h"
47#include "llvm/Support/SaveAndRestore.h"
48#include "llvm/Support/Threading.h"
49#include "llvm/Support/raw_ostream.h"
50#include <type_traits>
51
52#include <optional>
53#include <tuple>
54
55using namespace mlir;
56using namespace mlir::detail;
57
58#define DEBUG_TYPE "mlir-asm-printer"
59
60void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
61
62void OperationName::dump() const { print(os&: llvm::errs()); }
63
64//===--------------------------------------------------------------------===//
65// AsmParser
66//===--------------------------------------------------------------------===//
67
68AsmParser::~AsmParser() = default;
69DialectAsmParser::~DialectAsmParser() = default;
70OpAsmParser::~OpAsmParser() = default;
71
72MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
73
74/// Parse a type list.
75/// This is out-of-line to work-around https://github.com/llvm/llvm-project/issues/62918
76ParseResult AsmParser::parseTypeList(SmallVectorImpl<Type> &result) {
77 return parseCommaSeparatedList(
78 parseElementFn: [&]() { return parseType(result&: result.emplace_back()); });
79}
80
81//===----------------------------------------------------------------------===//
82// DialectAsmPrinter
83//===----------------------------------------------------------------------===//
84
85DialectAsmPrinter::~DialectAsmPrinter() = default;
86
87//===----------------------------------------------------------------------===//
88// OpAsmPrinter
89//===----------------------------------------------------------------------===//
90
91OpAsmPrinter::~OpAsmPrinter() = default;
92
93void OpAsmPrinter::printFunctionalType(Operation *op) {
94 auto &os = getStream();
95 os << '(';
96 llvm::interleaveComma(c: op->getOperands(), os, each_fn: [&](Value operand) {
97 // Print the types of null values as <<NULL TYPE>>.
98 *this << (operand ? operand.getType() : Type());
99 });
100 os << ") -> ";
101
102 // Print the result list. We don't parenthesize single result types unless
103 // it is a function (avoiding a grammar ambiguity).
104 bool wrapped = op->getNumResults() != 1;
105 if (!wrapped && op->getResult(idx: 0).getType() &&
106 llvm::isa<FunctionType>(Val: op->getResult(idx: 0).getType()))
107 wrapped = true;
108
109 if (wrapped)
110 os << '(';
111
112 llvm::interleaveComma(c: op->getResults(), os, each_fn: [&](const OpResult &result) {
113 // Print the types of null values as <<NULL TYPE>>.
114 *this << (result ? result.getType() : Type());
115 });
116
117 if (wrapped)
118 os << ')';
119}
120
121//===----------------------------------------------------------------------===//
122// Operation OpAsm interface.
123//===----------------------------------------------------------------------===//
124
125/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
126#include "mlir/IR/OpAsmInterface.cpp.inc"
127
128LogicalResult
129OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
130 return entry.emitError() << "unknown 'resource' key '" << entry.getKey()
131 << "' for dialect '" << getDialect()->getNamespace()
132 << "'";
133}
134
135//===----------------------------------------------------------------------===//
136// OpPrintingFlags
137//===----------------------------------------------------------------------===//
138
139namespace {
140/// This struct contains command line options that can be used to initialize
141/// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
142/// for global command line options.
143struct AsmPrinterOptions {
144 llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
145 "mlir-print-elementsattrs-with-hex-if-larger",
146 llvm::cl::desc(
147 "Print DenseElementsAttrs with a hex string that have "
148 "more elements than the given upper limit (use -1 to disable)")};
149
150 llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
151 "mlir-elide-elementsattrs-if-larger",
152 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
153 "more elements than the given upper limit")};
154
155 llvm::cl::opt<unsigned> elideResourceStringsIfLarger{
156 "mlir-elide-resource-strings-if-larger",
157 llvm::cl::desc(
158 "Elide printing value of resources if string is too long in chars.")};
159
160 llvm::cl::opt<bool> printDebugInfoOpt{
161 "mlir-print-debuginfo", llvm::cl::init(Val: false),
162 llvm::cl::desc("Print debug info in MLIR output")};
163
164 llvm::cl::opt<bool> printPrettyDebugInfoOpt{
165 "mlir-pretty-debuginfo", llvm::cl::init(Val: false),
166 llvm::cl::desc("Print pretty debug info in MLIR output")};
167
168 // Use the generic op output form in the operation printer even if the custom
169 // form is defined.
170 llvm::cl::opt<bool> printGenericOpFormOpt{
171 "mlir-print-op-generic", llvm::cl::init(Val: false),
172 llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
173
174 llvm::cl::opt<bool> assumeVerifiedOpt{
175 "mlir-print-assume-verified", llvm::cl::init(Val: false),
176 llvm::cl::desc("Skip op verification when using custom printers"),
177 llvm::cl::Hidden};
178
179 llvm::cl::opt<bool> printLocalScopeOpt{
180 "mlir-print-local-scope", llvm::cl::init(Val: false),
181 llvm::cl::desc("Print with local scope and inline information (eliding "
182 "aliases for attributes, types, and locations")};
183
184 llvm::cl::opt<bool> skipRegionsOpt{
185 "mlir-print-skip-regions", llvm::cl::init(Val: false),
186 llvm::cl::desc("Skip regions when printing ops.")};
187
188 llvm::cl::opt<bool> printValueUsers{
189 "mlir-print-value-users", llvm::cl::init(Val: false),
190 llvm::cl::desc(
191 "Print users of operation results and block arguments as a comment")};
192};
193} // namespace
194
195static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
196
197/// Register a set of useful command-line options that can be used to configure
198/// various flags within the AsmPrinter.
199void mlir::registerAsmPrinterCLOptions() {
200 // Make sure that the options struct has been initialized.
201 *clOptions;
202}
203
204/// Initialize the printing flags with default supplied by the cl::opts above.
205OpPrintingFlags::OpPrintingFlags()
206 : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
207 printGenericOpFormFlag(false), skipRegionsFlag(false),
208 assumeVerifiedFlag(false), printLocalScope(false),
209 printValueUsersFlag(false) {
210 // Initialize based upon command line options, if they are available.
211 if (!clOptions.isConstructed())
212 return;
213 if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
214 elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
215 if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences())
216 elementsAttrHexElementLimit =
217 clOptions->printElementsAttrWithHexIfLarger.getValue();
218 if (clOptions->elideResourceStringsIfLarger.getNumOccurrences())
219 resourceStringCharLimit = clOptions->elideResourceStringsIfLarger;
220 printDebugInfoFlag = clOptions->printDebugInfoOpt;
221 printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
222 printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
223 assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
224 printLocalScope = clOptions->printLocalScopeOpt;
225 skipRegionsFlag = clOptions->skipRegionsOpt;
226 printValueUsersFlag = clOptions->printValueUsers;
227}
228
229/// Enable the elision of large elements attributes, by printing a '...'
230/// instead of the element data, when the number of elements is greater than
231/// `largeElementLimit`. Note: The IR generated with this option is not
232/// parsable.
233OpPrintingFlags &
234OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
235 elementsAttrElementLimit = largeElementLimit;
236 return *this;
237}
238
239OpPrintingFlags &
240OpPrintingFlags::printLargeElementsAttrWithHex(int64_t largeElementLimit) {
241 elementsAttrHexElementLimit = largeElementLimit;
242 return *this;
243}
244
245OpPrintingFlags &
246OpPrintingFlags::elideLargeResourceString(int64_t largeResourceLimit) {
247 resourceStringCharLimit = largeResourceLimit;
248 return *this;
249}
250
251/// Enable printing of debug information. If 'prettyForm' is set to true,
252/// debug information is printed in a more readable 'pretty' form.
253OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable,
254 bool prettyForm) {
255 printDebugInfoFlag = enable;
256 printDebugInfoPrettyFormFlag = prettyForm;
257 return *this;
258}
259
260/// Always print operations in the generic form.
261OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) {
262 printGenericOpFormFlag = enable;
263 return *this;
264}
265
266/// Always skip Regions.
267OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) {
268 skipRegionsFlag = skip;
269 return *this;
270}
271
272/// Do not verify the operation when using custom operation printers.
273OpPrintingFlags &OpPrintingFlags::assumeVerified() {
274 assumeVerifiedFlag = true;
275 return *this;
276}
277
278/// Use local scope when printing the operation. This allows for using the
279/// printer in a more localized and thread-safe setting, but may not necessarily
280/// be identical of what the IR will look like when dumping the full module.
281OpPrintingFlags &OpPrintingFlags::useLocalScope() {
282 printLocalScope = true;
283 return *this;
284}
285
286/// Print users of values as comments.
287OpPrintingFlags &OpPrintingFlags::printValueUsers() {
288 printValueUsersFlag = true;
289 return *this;
290}
291
292/// Return if the given ElementsAttr should be elided.
293bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
294 return elementsAttrElementLimit &&
295 *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
296 !llvm::isa<SplatElementsAttr>(attr);
297}
298
299/// Return if the given ElementsAttr should be printed as hex string.
300bool OpPrintingFlags::shouldPrintElementsAttrWithHex(ElementsAttr attr) const {
301 // -1 is used to disable hex printing.
302 return (elementsAttrHexElementLimit != -1) &&
303 (elementsAttrHexElementLimit < int64_t(attr.getNumElements())) &&
304 !llvm::isa<SplatElementsAttr>(attr);
305}
306
307/// Return the size limit for printing large ElementsAttr.
308std::optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
309 return elementsAttrElementLimit;
310}
311
312/// Return the size limit for printing large ElementsAttr as hex string.
313int64_t OpPrintingFlags::getLargeElementsAttrHexLimit() const {
314 return elementsAttrHexElementLimit;
315}
316
317/// Return the size limit for printing large ElementsAttr.
318std::optional<uint64_t> OpPrintingFlags::getLargeResourceStringLimit() const {
319 return resourceStringCharLimit;
320}
321
322/// Return if debug information should be printed.
323bool OpPrintingFlags::shouldPrintDebugInfo() const {
324 return printDebugInfoFlag;
325}
326
327/// Return if debug information should be printed in the pretty form.
328bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
329 return printDebugInfoPrettyFormFlag;
330}
331
332/// Return if operations should be printed in the generic form.
333bool OpPrintingFlags::shouldPrintGenericOpForm() const {
334 return printGenericOpFormFlag;
335}
336
337/// Return if Region should be skipped.
338bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; }
339
340/// Return if operation verification should be skipped.
341bool OpPrintingFlags::shouldAssumeVerified() const {
342 return assumeVerifiedFlag;
343}
344
345/// Return if the printer should use local scope when dumping the IR.
346bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
347
348/// Return if the printer should print users of values.
349bool OpPrintingFlags::shouldPrintValueUsers() const {
350 return printValueUsersFlag;
351}
352
353//===----------------------------------------------------------------------===//
354// NewLineCounter
355//===----------------------------------------------------------------------===//
356
357namespace {
358/// This class is a simple formatter that emits a new line when inputted into a
359/// stream, that enables counting the number of newlines emitted. This class
360/// should be used whenever emitting newlines in the printer.
361struct NewLineCounter {
362 unsigned curLine = 1;
363};
364
365static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
366 ++newLine.curLine;
367 return os << '\n';
368}
369} // namespace
370
371//===----------------------------------------------------------------------===//
372// AsmPrinter::Impl
373//===----------------------------------------------------------------------===//
374
375namespace mlir {
376class AsmPrinter::Impl {
377public:
378 Impl(raw_ostream &os, AsmStateImpl &state);
379 explicit Impl(Impl &other) : Impl(other.os, other.state) {}
380
381 /// Returns the output stream of the printer.
382 raw_ostream &getStream() { return os; }
383
384 template <typename Container, typename UnaryFunctor>
385 inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
386 llvm::interleaveComma(c, os, eachFn);
387 }
388
389 /// This enum describes the different kinds of elision for the type of an
390 /// attribute when printing it.
391 enum class AttrTypeElision {
392 /// The type must not be elided,
393 Never,
394 /// The type may be elided when it matches the default used in the parser
395 /// (for example i64 is the default for integer attributes).
396 May,
397 /// The type must be elided.
398 Must
399 };
400
401 /// Print the given attribute or an alias.
402 void printAttribute(Attribute attr,
403 AttrTypeElision typeElision = AttrTypeElision::Never);
404 /// Print the given attribute without considering an alias.
405 void printAttributeImpl(Attribute attr,
406 AttrTypeElision typeElision = AttrTypeElision::Never);
407
408 /// Print the alias for the given attribute, return failure if no alias could
409 /// be printed.
410 LogicalResult printAlias(Attribute attr);
411
412 /// Print the given type or an alias.
413 void printType(Type type);
414 /// Print the given type.
415 void printTypeImpl(Type type);
416
417 /// Print the alias for the given type, return failure if no alias could
418 /// be printed.
419 LogicalResult printAlias(Type type);
420
421 /// Print the given location to the stream. If `allowAlias` is true, this
422 /// allows for the internal location to use an attribute alias.
423 void printLocation(LocationAttr loc, bool allowAlias = false);
424
425 /// Print a reference to the given resource that is owned by the given
426 /// dialect.
427 void printResourceHandle(const AsmDialectResourceHandle &resource);
428
429 void printAffineMap(AffineMap map);
430 void
431 printAffineExpr(AffineExpr expr,
432 function_ref<void(unsigned, bool)> printValueName = nullptr);
433 void printAffineConstraint(AffineExpr expr, bool isEq);
434 void printIntegerSet(IntegerSet set);
435
436 LogicalResult pushCyclicPrinting(const void *opaquePointer);
437
438 void popCyclicPrinting();
439
440 void printDimensionList(ArrayRef<int64_t> shape);
441
442protected:
443 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
444 ArrayRef<StringRef> elidedAttrs = {},
445 bool withKeyword = false);
446 void printNamedAttribute(NamedAttribute attr);
447 void printTrailingLocation(Location loc, bool allowAlias = true);
448 void printLocationInternal(LocationAttr loc, bool pretty = false,
449 bool isTopLevel = false);
450
451 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
452 /// used instead of individual elements when the elements attr is large.
453 void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
454
455 /// Print a dense string elements attribute.
456 void printDenseStringElementsAttr(DenseStringElementsAttr attr);
457
458 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
459 /// used instead of individual elements when the elements attr is large.
460 void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
461 bool allowHex);
462
463 /// Print a dense array attribute.
464 void printDenseArrayAttr(DenseArrayAttr attr);
465
466 void printDialectAttribute(Attribute attr);
467 void printDialectType(Type type);
468
469 /// Print an escaped string, wrapped with "".
470 void printEscapedString(StringRef str);
471
472 /// Print a hex string, wrapped with "".
473 void printHexString(StringRef str);
474 void printHexString(ArrayRef<char> data);
475
476 /// This enum is used to represent the binding strength of the enclosing
477 /// context that an AffineExprStorage is being printed in, so we can
478 /// intelligently produce parens.
479 enum class BindingStrength {
480 Weak, // + and -
481 Strong, // All other binary operators.
482 };
483 void printAffineExprInternal(
484 AffineExpr expr, BindingStrength enclosingTightness,
485 function_ref<void(unsigned, bool)> printValueName = nullptr);
486
487 /// The output stream for the printer.
488 raw_ostream &os;
489
490 /// An underlying assembly printer state.
491 AsmStateImpl &state;
492
493 /// A set of flags to control the printer's behavior.
494 OpPrintingFlags printerFlags;
495
496 /// A tracker for the number of new lines emitted during printing.
497 NewLineCounter newLine;
498};
499} // namespace mlir
500
501//===----------------------------------------------------------------------===//
502// AliasInitializer
503//===----------------------------------------------------------------------===//
504
505namespace {
506/// This class represents a specific instance of a symbol Alias.
507class SymbolAlias {
508public:
509 SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType,
510 bool isDeferrable)
511 : name(name), suffixIndex(suffixIndex), isType(isType),
512 isDeferrable(isDeferrable) {}
513
514 /// Print this alias to the given stream.
515 void print(raw_ostream &os) const {
516 os << (isType ? "!" : "#") << name;
517 if (suffixIndex)
518 os << suffixIndex;
519 }
520
521 /// Returns true if this is a type alias.
522 bool isTypeAlias() const { return isType; }
523
524 /// Returns true if this alias supports deferred resolution when parsing.
525 bool canBeDeferred() const { return isDeferrable; }
526
527private:
528 /// The main name of the alias.
529 StringRef name;
530 /// The suffix index of the alias.
531 uint32_t suffixIndex : 30;
532 /// A flag indicating whether this alias is for a type.
533 bool isType : 1;
534 /// A flag indicating whether this alias may be deferred or not.
535 bool isDeferrable : 1;
536};
537
538/// This class represents a utility that initializes the set of attribute and
539/// type aliases, without the need to store the extra information within the
540/// main AliasState class or pass it around via function arguments.
541class AliasInitializer {
542public:
543 AliasInitializer(
544 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
545 llvm::BumpPtrAllocator &aliasAllocator)
546 : interfaces(interfaces), aliasAllocator(aliasAllocator),
547 aliasOS(aliasBuffer) {}
548
549 void initialize(Operation *op, const OpPrintingFlags &printerFlags,
550 llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);
551
552 /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
553 /// set to true if the originator of this attribute can resolve the alias
554 /// after parsing has completed (e.g. in the case of operation locations).
555 /// `elideType` indicates if the type of the attribute should be skipped when
556 /// looking for nested aliases. Returns the maximum alias depth of the
557 /// attribute, and the alias index of this attribute.
558 std::pair<size_t, size_t> visit(Attribute attr, bool canBeDeferred = false,
559 bool elideType = false) {
560 return visitImpl(value: attr, aliases, canBeDeferred, printArgs&: elideType);
561 }
562
563 /// Visit the given type to see if it has an alias. `canBeDeferred` is
564 /// set to true if the originator of this attribute can resolve the alias
565 /// after parsing has completed. Returns the maximum alias depth of the type,
566 /// and the alias index of this type.
567 std::pair<size_t, size_t> visit(Type type, bool canBeDeferred = false) {
568 return visitImpl(value: type, aliases, canBeDeferred);
569 }
570
571private:
572 struct InProgressAliasInfo {
573 InProgressAliasInfo()
574 : aliasDepth(0), isType(false), canBeDeferred(false) {}
575 InProgressAliasInfo(StringRef alias, bool isType, bool canBeDeferred)
576 : alias(alias), aliasDepth(1), isType(isType),
577 canBeDeferred(canBeDeferred) {}
578
579 bool operator<(const InProgressAliasInfo &rhs) const {
580 // Order first by depth, then by attr/type kind, and then by name.
581 if (aliasDepth != rhs.aliasDepth)
582 return aliasDepth < rhs.aliasDepth;
583 if (isType != rhs.isType)
584 return isType;
585 return alias < rhs.alias;
586 }
587
588 /// The alias for the attribute or type, or std::nullopt if the value has no
589 /// alias.
590 std::optional<StringRef> alias;
591 /// The alias depth of this attribute or type, i.e. an indication of the
592 /// relative ordering of when to print this alias.
593 unsigned aliasDepth : 30;
594 /// If this alias represents a type or an attribute.
595 bool isType : 1;
596 /// If this alias can be deferred or not.
597 bool canBeDeferred : 1;
598 /// Indices for child aliases.
599 SmallVector<size_t> childIndices;
600 };
601
602 /// Visit the given attribute or type to see if it has an alias.
603 /// `canBeDeferred` is set to true if the originator of this value can resolve
604 /// the alias after parsing has completed (e.g. in the case of operation
605 /// locations). Returns the maximum alias depth of the value, and its alias
606 /// index.
607 template <typename T, typename... PrintArgs>
608 std::pair<size_t, size_t>
609 visitImpl(T value,
610 llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
611 bool canBeDeferred, PrintArgs &&...printArgs);
612
613 /// Mark the given alias as non-deferrable.
614 void markAliasNonDeferrable(size_t aliasIndex);
615
616 /// Try to generate an alias for the provided symbol. If an alias is
617 /// generated, the provided alias mapping and reverse mapping are updated.
618 template <typename T>
619 void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred);
620
621 /// Given a collection of aliases and symbols, initialize a mapping from a
622 /// symbol to a given alias.
623 static void initializeAliases(
624 llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
625 llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);
626
627 /// The set of asm interfaces within the context.
628 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
629
630 /// An allocator used for alias names.
631 llvm::BumpPtrAllocator &aliasAllocator;
632
633 /// The set of built aliases.
634 llvm::MapVector<const void *, InProgressAliasInfo> aliases;
635
636 /// Storage and stream used when generating an alias.
637 SmallString<32> aliasBuffer;
638 llvm::raw_svector_ostream aliasOS;
639};
640
641/// This class implements a dummy OpAsmPrinter that doesn't print any output,
642/// and merely collects the attributes and types that *would* be printed in a
643/// normal print invocation so that we can generate proper aliases. This allows
644/// for us to generate aliases only for the attributes and types that would be
645/// in the output, and trims down unnecessary output.
646class DummyAliasOperationPrinter : private OpAsmPrinter {
647public:
648 explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
649 AliasInitializer &initializer)
650 : printerFlags(printerFlags), initializer(initializer) {}
651
652 /// Prints the entire operation with the custom assembly form, if available,
653 /// or the generic assembly form, otherwise.
654 void printCustomOrGenericOp(Operation *op) override {
655 // Visit the operation location.
656 if (printerFlags.shouldPrintDebugInfo())
657 initializer.visit(attr: op->getLoc(), /*canBeDeferred=*/true);
658
659 // If requested, always print the generic form.
660 if (!printerFlags.shouldPrintGenericOpForm()) {
661 op->getName().printAssembly(op, p&: *this, /*defaultDialect=*/"");
662 return;
663 }
664
665 // Otherwise print with the generic assembly form.
666 printGenericOp(op);
667 }
668
669private:
670 /// Print the given operation in the generic form.
671 void printGenericOp(Operation *op, bool printOpName = true) override {
672 // Consider nested operations for aliases.
673 if (!printerFlags.shouldSkipRegions()) {
674 for (Region &region : op->getRegions())
675 printRegion(region, /*printEntryBlockArgs=*/true,
676 /*printBlockTerminators=*/true);
677 }
678
679 // Visit all the types used in the operation.
680 for (Type type : op->getOperandTypes())
681 printType(type);
682 for (Type type : op->getResultTypes())
683 printType(type);
684
685 // Consider the attributes of the operation for aliases.
686 for (const NamedAttribute &attr : op->getAttrs())
687 printAttribute(attr: attr.getValue());
688 }
689
690 /// Print the given block. If 'printBlockArgs' is false, the arguments of the
691 /// block are not printed. If 'printBlockTerminator' is false, the terminator
692 /// operation of the block is not printed.
693 void print(Block *block, bool printBlockArgs = true,
694 bool printBlockTerminator = true) {
695 // Consider the types of the block arguments for aliases if 'printBlockArgs'
696 // is set to true.
697 if (printBlockArgs) {
698 for (BlockArgument arg : block->getArguments()) {
699 printType(type: arg.getType());
700
701 // Visit the argument location.
702 if (printerFlags.shouldPrintDebugInfo())
703 // TODO: Allow deferring argument locations.
704 initializer.visit(attr: arg.getLoc(), /*canBeDeferred=*/false);
705 }
706 }
707
708 // Consider the operations within this block, ignoring the terminator if
709 // requested.
710 bool hasTerminator =
711 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
712 auto range = llvm::make_range(
713 x: block->begin(),
714 y: std::prev(x: block->end(),
715 n: (!hasTerminator || printBlockTerminator) ? 0 : 1));
716 for (Operation &op : range)
717 printCustomOrGenericOp(op: &op);
718 }
719
720 /// Print the given region.
721 void printRegion(Region &region, bool printEntryBlockArgs,
722 bool printBlockTerminators,
723 bool printEmptyBlock = false) override {
724 if (region.empty())
725 return;
726 if (printerFlags.shouldSkipRegions()) {
727 os << "{...}";
728 return;
729 }
730
731 auto *entryBlock = &region.front();
732 print(block: entryBlock, printBlockArgs: printEntryBlockArgs, printBlockTerminator: printBlockTerminators);
733 for (Block &b : llvm::drop_begin(RangeOrContainer&: region, N: 1))
734 print(block: &b);
735 }
736
737 void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
738 bool omitType) override {
739 printType(type: arg.getType());
740 // Visit the argument location.
741 if (printerFlags.shouldPrintDebugInfo())
742 // TODO: Allow deferring argument locations.
743 initializer.visit(attr: arg.getLoc(), /*canBeDeferred=*/false);
744 }
745
746 /// Consider the given type to be printed for an alias.
747 void printType(Type type) override { initializer.visit(type); }
748
749 /// Consider the given attribute to be printed for an alias.
750 void printAttribute(Attribute attr) override { initializer.visit(attr); }
751 void printAttributeWithoutType(Attribute attr) override {
752 printAttribute(attr);
753 }
754 LogicalResult printAlias(Attribute attr) override {
755 initializer.visit(attr);
756 return success();
757 }
758 LogicalResult printAlias(Type type) override {
759 initializer.visit(type);
760 return success();
761 }
762
763 /// Consider the given location to be printed for an alias.
764 void printOptionalLocationSpecifier(Location loc) override {
765 printAttribute(attr: loc);
766 }
767
768 /// Print the given set of attributes with names not included within
769 /// 'elidedAttrs'.
770 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
771 ArrayRef<StringRef> elidedAttrs = {}) override {
772 if (attrs.empty())
773 return;
774 if (elidedAttrs.empty()) {
775 for (const NamedAttribute &attr : attrs)
776 printAttribute(attr: attr.getValue());
777 return;
778 }
779 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
780 elidedAttrs.end());
781 for (const NamedAttribute &attr : attrs)
782 if (!elidedAttrsSet.contains(V: attr.getName().strref()))
783 printAttribute(attr: attr.getValue());
784 }
785 void printOptionalAttrDictWithKeyword(
786 ArrayRef<NamedAttribute> attrs,
787 ArrayRef<StringRef> elidedAttrs = {}) override {
788 printOptionalAttrDict(attrs, elidedAttrs);
789 }
790
791 /// Return a null stream as the output stream, this will ignore any data fed
792 /// to it.
793 raw_ostream &getStream() const override { return os; }
794
795 /// The following are hooks of `OpAsmPrinter` that are not necessary for
796 /// determining potential aliases.
797 void printFloat(const APFloat &) override {}
798 void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
799 void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
800 void printNewline() override {}
801 void increaseIndent() override {}
802 void decreaseIndent() override {}
803 void printOperand(Value) override {}
804 void printOperand(Value, raw_ostream &os) override {
805 // Users expect the output string to have at least the prefixed % to signal
806 // a value name. To maintain this invariant, emit a name even if it is
807 // guaranteed to go unused.
808 os << "%";
809 }
810 void printKeywordOrString(StringRef) override {}
811 void printString(StringRef) override {}
812 void printResourceHandle(const AsmDialectResourceHandle &) override {}
813 void printSymbolName(StringRef) override {}
814 void printSuccessor(Block *) override {}
815 void printSuccessorAndUseList(Block *, ValueRange) override {}
816 void shadowRegionArgs(Region &, ValueRange) override {}
817
/// The printer flags to use when determining potential aliases. Held by
/// reference; the referenced flags must outlive this printer.
const OpPrintingFlags &printerFlags;

/// The initializer to use when identifying aliases. Held by reference; the
/// initializer must outlive this printer.
AliasInitializer &initializer;

/// A dummy output stream that discards everything written to it. `mutable`
/// so that the const `getStream()` can hand out a non-const reference.
mutable llvm::raw_null_ostream os;
826};
827
/// A DialectAsmPrinter that writes nothing and exists only to discover which
/// attributes and types would be printed while printing a given attribute or
/// type. Every such nested element is forwarded to the AliasInitializer via
/// `visit`. The alias index each visit returns is appended to `childIndices`,
/// and the largest returned alias depth is tracked in `maxAliasDepth`.
class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
public:
  /// `initializer` and `childIndices` are held by reference and must outlive
  /// this printer. `canBeDeferred` is passed through to every `visit` call.
  explicit DummyAliasDialectAsmPrinter(AliasInitializer &initializer,
                                       bool canBeDeferred,
                                       SmallVectorImpl<size_t> &childIndices)
      : initializer(initializer), canBeDeferred(canBeDeferred),
        childIndices(childIndices) {}

  /// Print the given attribute/type, visiting any nested aliases that would be
  /// generated as part of printing. Returns the maximum alias depth found while
  /// printing the given value (0 if no nested visit reported a larger depth).
  /// `printArgs` are passed to the matching Impl overload (e.g. `elideType`
  /// for attributes).
  template <typename T, typename... PrintArgs>
  size_t printAndVisitNestedAliases(T value, PrintArgs &&...printArgs) {
    printAndVisitNestedAliasesImpl(value, printArgs...);
    return maxAliasDepth;
  }

private:
  /// Print the given attribute/type, visiting any nested aliases that would be
  /// generated as part of printing.
  ///
  /// Non-builtin attributes are printed through their owning dialect. Builtin
  /// attributes are handled explicitly below. Unless `elideType` is set, the
  /// type of a TypedAttr is visited afterwards (NoneType is skipped).
  void printAndVisitNestedAliasesImpl(Attribute attr, bool elideType) {
    if (!isa<BuiltinDialect>(Val: attr.getDialect())) {
      attr.getDialect().printAttribute(attr, *this);

      // Process the builtin attributes.
    } else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr,
                         IntegerSetAttr, UnitAttr>(Val: attr)) {
      // Nothing nested is visited for these builtin attributes. Note that the
      // early return also skips the TypedAttr type visit at the end.
      return;
    } else if (auto distinctAttr = dyn_cast<DistinctAttr>(attr)) {
      printAttribute(attr: distinctAttr.getReferencedAttr());
    } else if (auto dictAttr = dyn_cast<DictionaryAttr>(attr)) {
      // Both the entry names and the entry values are visited.
      for (const NamedAttribute &nestedAttr : dictAttr.getValue()) {
        printAttribute(nestedAttr.getName());
        printAttribute(nestedAttr.getValue());
      }
    } else if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
      for (Attribute nestedAttr : arrayAttr.getValue())
        printAttribute(nestedAttr);
    } else if (auto typeAttr = dyn_cast<TypeAttr>(attr)) {
      printType(type: typeAttr.getValue());
    } else if (auto locAttr = dyn_cast<OpaqueLoc>(attr)) {
      printAttribute(attr: locAttr.getFallbackLocation());
    } else if (auto locAttr = dyn_cast<NameLoc>(attr)) {
      // An UnknownLoc child is not visited.
      if (!isa<UnknownLoc>(locAttr.getChildLoc()))
        printAttribute(attr: locAttr.getChildLoc());
    } else if (auto locAttr = dyn_cast<CallSiteLoc>(attr)) {
      printAttribute(attr: locAttr.getCallee());
      printAttribute(attr: locAttr.getCaller());
    } else if (auto locAttr = dyn_cast<FusedLoc>(attr)) {
      if (Attribute metadata = locAttr.getMetadata())
        printAttribute(attr: metadata);
      for (Location nestedLoc : locAttr.getLocations())
        printAttribute(nestedLoc);
    }

    // Don't print the type if we must elide it, or if it is a None type.
    if (!elideType) {
      if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
        Type attrType = typedAttr.getType();
        if (!llvm::isa<NoneType>(Val: attrType))
          printType(type: attrType);
      }
    }
  }
  /// Visit the elements nested in `type`. Non-builtin types are printed by
  /// their dialect. Builtin types are handled here.
  void printAndVisitNestedAliasesImpl(Type type) {
    if (!isa<BuiltinDialect>(Val: type.getDialect()))
      return type.getDialect().printType(type, *this);

    // Only visit the layout of memref if it isn't the identity. More
    // precisely, the layout is skipped only when it is an identity
    // AffineMapAttr; any other layout attribute is always visited.
    if (auto memrefTy = llvm::dyn_cast<MemRefType>(type)) {
      printType(type: memrefTy.getElementType());
      MemRefLayoutAttrInterface layout = memrefTy.getLayout();
      if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity())
        printAttribute(attr: memrefTy.getLayout());
      if (memrefTy.getMemorySpace())
        printAttribute(attr: memrefTy.getMemorySpace());
      return;
    }

    // For most builtin types, we can simply walk the sub elements. Null
    // sub-elements are skipped.
    auto visitFn = [&](auto element) {
      if (element)
        (void)printAlias(element);
    };
    type.walkImmediateSubElements(walkAttrsFn: visitFn, walkTypesFn: visitFn);
  }

  /// Consider the given type to be printed for an alias.
  void printType(Type type) override {
    recordAliasResult(aliasDepthAndIndex: initializer.visit(type, canBeDeferred));
  }

  /// Consider the given attribute to be printed for an alias.
  void printAttribute(Attribute attr) override {
    recordAliasResult(aliasDepthAndIndex: initializer.visit(attr, canBeDeferred));
  }
  /// Same as printAttribute, but asks the initializer to elide the type.
  void printAttributeWithoutType(Attribute attr) override {
    recordAliasResult(
        aliasDepthAndIndex: initializer.visit(attr, canBeDeferred, /*elideType=*/true));
  }
  /// Alias lookups are treated as ordinary visits; they always succeed here.
  LogicalResult printAlias(Attribute attr) override {
    printAttribute(attr);
    return success();
  }
  LogicalResult printAlias(Type type) override {
    printType(type);
    return success();
  }

  /// Record the alias result of a child element. `aliasDepthAndIndex` is the
  /// `{depth, index}` pair produced by the initializer's visit.
  void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) {
    childIndices.push_back(Elt: aliasDepthAndIndex.second);
    if (aliasDepthAndIndex.first > maxAliasDepth)
      maxAliasDepth = aliasDepthAndIndex.first;
  }

  /// Return a null stream as the output stream, this will ignore any data fed
  /// to it.
  raw_ostream &getStream() const override { return os; }

  /// The following are hooks of `DialectAsmPrinter` that are not necessary for
  /// determining potential aliases.
  void printFloat(const APFloat &) override {}
  void printKeywordOrString(StringRef) override {}
  void printString(StringRef) override {}
  void printSymbolName(StringRef) override {}
  void printResourceHandle(const AsmDialectResourceHandle &) override {}

  /// Returns failure if `opaquePointer` is already on the printing stack,
  /// which signals a cycle to the caller.
  LogicalResult pushCyclicPrinting(const void *opaquePointer) override {
    return success(isSuccess: cyclicPrintingStack.insert(X: opaquePointer));
  }

  void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); }

  /// Stack of potentially cyclic mutable attributes or type currently being
  /// printed.
  SetVector<const void *> cyclicPrintingStack;

  /// The initializer to use when identifying aliases.
  AliasInitializer &initializer;

  /// If the aliases visited by this printer can be deferred.
  bool canBeDeferred;

  /// The indices of child aliases.
  SmallVectorImpl<size_t> &childIndices;

  /// The maximum alias depth found by the printer.
  size_t maxAliasDepth = 0;

  /// A dummy output stream.
  mutable llvm::raw_null_ostream os;
};
981} // namespace
982
983/// Sanitize the given name such that it can be used as a valid identifier. If
984/// the string needs to be modified in any way, the provided buffer is used to
985/// store the new copy,
986static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
987 StringRef allowedPunctChars = "$._-",
988 bool allowTrailingDigit = true) {
989 assert(!name.empty() && "Shouldn't have an empty name here");
990
991 auto copyNameToBuffer = [&] {
992 for (char ch : name) {
993 if (llvm::isAlnum(C: ch) || allowedPunctChars.contains(C: ch))
994 buffer.push_back(Elt: ch);
995 else if (ch == ' ')
996 buffer.push_back(Elt: '_');
997 else
998 buffer.append(RHS: llvm::utohexstr(X: (unsigned char)ch));
999 }
1000 };
1001
1002 // Check to see if this name is valid. If it starts with a digit, then it
1003 // could conflict with the autogenerated numeric ID's, so add an underscore
1004 // prefix to avoid problems.
1005 if (isdigit(name[0])) {
1006 buffer.push_back(Elt: '_');
1007 copyNameToBuffer();
1008 return buffer;
1009 }
1010
1011 // If the name ends with a trailing digit, add a '_' to avoid potential
1012 // conflicts with autogenerated ID's.
1013 if (!allowTrailingDigit && isdigit(name.back())) {
1014 copyNameToBuffer();
1015 buffer.push_back(Elt: '_');
1016 return buffer;
1017 }
1018
1019 // Check to see that the name consists of only valid identifier characters.
1020 for (char ch : name) {
1021 if (!llvm::isAlnum(C: ch) && !allowedPunctChars.contains(C: ch)) {
1022 copyNameToBuffer();
1023 return buffer;
1024 }
1025 }
1026
1027 // If there are no invalid characters, return the original name.
1028 return name;
1029}
1030
1031/// Given a collection of aliases and symbols, initialize a mapping from a
1032/// symbol to a given alias.
1033void AliasInitializer::initializeAliases(
1034 llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
1035 llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) {
1036 SmallVector<std::pair<const void *, InProgressAliasInfo>, 0>
1037 unprocessedAliases = visitedSymbols.takeVector();
1038 llvm::stable_sort(Range&: unprocessedAliases, C: [](const auto &lhs, const auto &rhs) {
1039 return lhs.second < rhs.second;
1040 });
1041
1042 llvm::StringMap<unsigned> nameCounts;
1043 for (auto &[symbol, aliasInfo] : unprocessedAliases) {
1044 if (!aliasInfo.alias)
1045 continue;
1046 StringRef alias = *aliasInfo.alias;
1047 unsigned nameIndex = nameCounts[alias]++;
1048 symbolToAlias.insert(
1049 KV: {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
1050 aliasInfo.canBeDeferred)});
1051 }
1052}
1053
1054void AliasInitializer::initialize(
1055 Operation *op, const OpPrintingFlags &printerFlags,
1056 llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
1057 // Use a dummy printer when walking the IR so that we can collect the
1058 // attributes/types that will actually be used during printing when
1059 // considering aliases.
1060 DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
1061 aliasPrinter.printCustomOrGenericOp(op);
1062
1063 // Initialize the aliases.
1064 initializeAliases(visitedSymbols&: aliases, symbolToAlias&: attrTypeToAlias);
1065}
1066
/// Visit `value` for alias generation. Returns `{aliasDepth, aliasIndex}`,
/// where `aliasIndex` is the position of `value`'s entry in `aliases`.
///
/// On the first visit, an alias name is generated for `value` and `value` is
/// "printed" through a DummyAliasDialectAsmPrinter. That visits all nested
/// attributes/types and records their indices as this entry's children. If
/// any nested element reported a non-zero depth, this entry's depth becomes
/// that maximum plus one. Otherwise the depth is left unchanged.
///
/// On later visits, the stored depth is returned as-is. If `canBeDeferred` is
/// false, the entry (and its children) is marked non-deferrable.
///
/// `printArgs` are forwarded to the dummy printer (e.g. `elideType` for
/// attributes).
///
/// NOTE(review): markAliasNonDeferrable operates on the member `aliases`,
/// while this function uses its `aliases` parameter; this presumes callers
/// always pass the member map -- confirm.
template <typename T, typename... PrintArgs>
std::pair<size_t, size_t> AliasInitializer::visitImpl(
    T value, llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
    bool canBeDeferred, PrintArgs &&...printArgs) {
  auto [it, inserted] =
      aliases.insert({value.getAsOpaquePointer(), InProgressAliasInfo()});
  size_t aliasIndex = std::distance(aliases.begin(), it);
  if (!inserted) {
    // Make sure that the alias isn't deferred if we don't permit it.
    if (!canBeDeferred)
      markAliasNonDeferrable(aliasIndex);
    return {static_cast<size_t>(it->second.aliasDepth), aliasIndex};
  }

  // Try to generate an alias for this value.
  generateAlias(value, it->second, canBeDeferred);

  // Print the value, capturing any nested elements that require aliases.
  SmallVector<size_t> childAliases;
  DummyAliasDialectAsmPrinter printer(*this, canBeDeferred, childAliases);
  size_t maxAliasDepth =
      printer.printAndVisitNestedAliases(value, printArgs...);

  // Make sure to recompute `it` in case the map was reallocated: the nested
  // visits made by the printer above may have inserted new entries.
  it = std::next(x: aliases.begin(), n: aliasIndex);

  // If we had sub elements, update to account for the depth.
  it->second.childIndices = std::move(childAliases);
  if (maxAliasDepth)
    it->second.aliasDepth = maxAliasDepth + 1;

  // Propagate the alias depth of the value.
  return {(size_t)it->second.aliasDepth, aliasIndex};
}
1101
1102void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
1103 auto *it = std::next(x: aliases.begin(), n: aliasIndex);
1104
1105 // If already marked non-deferrable stop the recursion.
1106 // All children should already be marked non-deferrable as well.
1107 if (!it->second.canBeDeferred)
1108 return;
1109
1110 it->second.canBeDeferred = false;
1111
1112 // Propagate the non-deferrable flag to any child aliases.
1113 for (size_t childIndex : it->second.childIndices)
1114 markAliasNonDeferrable(aliasIndex: childIndex);
1115}
1116
1117template <typename T>
1118void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
1119 bool canBeDeferred) {
1120 SmallString<32> nameBuffer;
1121 for (const auto &interface : interfaces) {
1122 OpAsmDialectInterface::AliasResult result =
1123 interface.getAlias(symbol, aliasOS);
1124 if (result == OpAsmDialectInterface::AliasResult::NoAlias)
1125 continue;
1126 nameBuffer = std::move(aliasBuffer);
1127 assert(!nameBuffer.empty() && "expected valid alias name");
1128 if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
1129 break;
1130 }
1131
1132 if (nameBuffer.empty())
1133 return;
1134
1135 SmallString<16> tempBuffer;
1136 StringRef name =
1137 sanitizeIdentifier(name: nameBuffer, buffer&: tempBuffer, /*allowedPunctChars=*/"$_-",
1138 /*allowTrailingDigit=*/false);
1139 name = name.copy(A&: aliasAllocator);
1140 alias = InProgressAliasInfo(name, /*isType=*/std::is_base_of_v<Type, T>,
1141 canBeDeferred);
1142}
1143
1144//===----------------------------------------------------------------------===//
1145// AliasState
1146//===----------------------------------------------------------------------===//
1147
namespace {
/// This class manages the state for type and attribute aliases.
///
/// It owns the attribute/type -> alias mapping computed by `initialize` and
/// the allocator backing the alias names, and it prints the alias definitions.
class AliasState {
public:
  // Initialize the internal aliases by running an AliasInitializer over `op`.
  void
  initialize(Operation *op, const OpPrintingFlags &printerFlags,
             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);

  /// Get an alias for the given attribute if it has one and print it in `os`.
  /// Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Attribute attr, raw_ostream &os) const;

  /// Get an alias for the given type if it has one and print it in `os`.
  /// Returns success if an alias was printed, failure otherwise.
  LogicalResult getAlias(Type ty, raw_ostream &os) const;

  /// Print all of the referenced aliases that can not be resolved in a deferred
  /// manner.
  void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
    printAliases(p, newLine, /*isDeferred=*/false);
  }

  /// Print all of the referenced aliases that support deferred resolution.
  void printDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
    printAliases(p, newLine, /*isDeferred=*/true);
  }

private:
  /// Print all of the referenced aliases that support the provided resolution
  /// behavior.
  void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
                    bool isDeferred);

  /// Mapping between attribute/type (keyed by opaque pointer) and alias.
  llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;

  /// An allocator used for alias names.
  llvm::BumpPtrAllocator aliasAllocator;
};
} // namespace
1189
1190void AliasState::initialize(
1191 Operation *op, const OpPrintingFlags &printerFlags,
1192 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
1193 AliasInitializer initializer(interfaces, aliasAllocator);
1194 initializer.initialize(op, printerFlags, attrTypeToAlias);
1195}
1196
1197LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
1198 const auto *it = attrTypeToAlias.find(Key: attr.getAsOpaquePointer());
1199 if (it == attrTypeToAlias.end())
1200 return failure();
1201 it->second.print(os);
1202 return success();
1203}
1204
1205LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
1206 const auto *it = attrTypeToAlias.find(Key: ty.getAsOpaquePointer());
1207 if (it == attrTypeToAlias.end())
1208 return failure();
1209
1210 it->second.print(os);
1211 return success();
1212}
1213
1214void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
1215 bool isDeferred) {
1216 auto filterFn = [=](const auto &aliasIt) {
1217 return aliasIt.second.canBeDeferred() == isDeferred;
1218 };
1219 for (auto &[opaqueSymbol, alias] :
1220 llvm::make_filter_range(Range&: attrTypeToAlias, Pred: filterFn)) {
1221 alias.print(os&: p.getStream());
1222 p.getStream() << " = ";
1223
1224 if (alias.isTypeAlias()) {
1225 // TODO: Support nested aliases in mutable types.
1226 Type type = Type::getFromOpaquePointer(pointer: opaqueSymbol);
1227 if (type.hasTrait<TypeTrait::IsMutable>())
1228 p.getStream() << type;
1229 else
1230 p.printTypeImpl(type);
1231 } else {
1232 // TODO: Support nested aliases in mutable attributes.
1233 Attribute attr = Attribute::getFromOpaquePointer(ptr: opaqueSymbol);
1234 if (attr.hasTrait<AttributeTrait::IsMutable>())
1235 p.getStream() << attr;
1236 else
1237 p.printAttributeImpl(attr);
1238 }
1239
1240 p.getStream() << newLine;
1241 }
1242}
1243
1244//===----------------------------------------------------------------------===//
1245// SSANameState
1246//===----------------------------------------------------------------------===//
1247
namespace {
/// Info about block printing: a number which is its position in the visitation
/// order, and a name that is used to print reference to it, e.g. ^bb42.
/// An `ordering` of -1 means no position has been assigned.
struct BlockInfo {
  int ordering;
  StringRef name;
};

/// This class manages the state of SSA value names.
///
/// Every value gets either a numeric ID (`valueIDs`) or a custom name. In the
/// latter case its `valueIDs` entry is `NameSentinel` and the name is stored in
/// `valueNames`. Blocks get a name and an ordering in `blockNames`.
class SSANameState {
public:
  /// A sentinel value used for values with names set.
  enum : unsigned { NameSentinel = ~0U };

  SSANameState(Operation *op, const OpPrintingFlags &printerFlags);
  SSANameState() = default;

  /// Print the SSA identifier for the given value to 'stream'. If
  /// 'printResultNo' is true, it also presents the result number ('#' number)
  /// of this value.
  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;

  /// Print the operation identifier.
  void printOperationID(Operation *op, raw_ostream &stream) const;

  /// Return the result indices for each of the result groups registered by this
  /// operation, or empty if none exist.
  ArrayRef<int> getOpResultGroups(Operation *op);

  /// Get the info for the given block.
  BlockInfo getBlockInfo(Block *block);

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
  /// details.
  void shadowRegionArgs(Region &region, ValueRange namesToUse);

private:
  /// Number the SSA values within the given IR unit.
  void numberValuesInRegion(Region &region);
  void numberValuesInBlock(Block &block);
  void numberValuesInOp(Operation &op);

  /// Given a result of an operation 'result', find the result group head
  /// 'lookupValue' and the result of 'result' within that group in
  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
  /// has more than 1 result.
  void getResultIDAndNumber(OpResult result, Value &lookupValue,
                            std::optional<int> &lookupResultNo) const;

  /// Set a special value name for the given value.
  void setValueName(Value value, StringRef name);

  /// Uniques the given value name within the printer. If the given name
  /// conflicts, it is automatically renamed.
  StringRef uniqueValueName(StringRef name);

  /// This is the value ID for each SSA value. If this returns NameSentinel,
  /// then the valueID has an entry in valueNames.
  DenseMap<Value, unsigned> valueIDs;
  DenseMap<Value, StringRef> valueNames;

  /// When printing users of values, an operation without a result might
  /// be the user. This map holds ids for such operations.
  DenseMap<Operation *, unsigned> operationIDs;

  /// This is a map of operations that contain multiple named result groups,
  /// i.e. there may be multiple names for the results of the operation. The
  /// value of this map are the result numbers that start a result group.
  DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;

  /// This maps blocks to their visitation number in the current region as well
  /// as the string representing their name.
  DenseMap<Block *, BlockInfo> blockNames;

  /// This keeps track of all of the non-numeric names that are in flight,
  /// allowing us to check for duplicates.
  /// Note: the value of the map is unused.
  llvm::ScopedHashTable<StringRef, char> usedNames;
  /// Backing storage for the name strings referenced by the maps above.
  llvm::BumpPtrAllocator usedNameAllocator;

  /// This is the next value ID to assign in numbering.
  unsigned nextValueID = 0;
  /// This is the next ID to assign to a region entry block argument.
  unsigned nextArgumentID = 0;
  /// This is the next ID to assign when a name conflict is detected.
  unsigned nextConflictID = 0;

  /// These are the printing flags. They control, e.g., whether to print in
  /// generic form.
  OpPrintingFlags printerFlags;
};
} // namespace
1341
/// Number every value and block nested under `op`.
///
/// Regions are processed with an explicit LIFO worklist (depth-first) rather
/// than recursion. Each worklist entry carries the naming context the region
/// inherits: the values of `nextValueID`, `nextArgumentID` and
/// `nextConflictID` when the region was queued, plus the `usedNames` scope of
/// its parent. The regions of `op` itself are queued before `op`'s own results
/// are numbered. Nested regions are queued after their enclosing region has
/// been fully numbered. All regions queued together therefore start from the
/// same counters. Names recorded while numbering one region are removed
/// (their scope is popped) before a sibling region is numbered.
SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags)
    : printerFlags(printerFlags) {
  // Restore the three counters to their values at entry when the constructor
  // returns.
  llvm::SaveAndRestore valueIDSaver(nextValueID);
  llvm::SaveAndRestore argumentIDSaver(nextArgumentID);
  llvm::SaveAndRestore conflictIDSaver(nextConflictID);

  // The naming context includes `nextValueID`, `nextArgumentID`,
  // `nextConflictID` and `usedNames` scoped HashTable. This information is
  // carried from the parent region.
  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
  using NamingContext =
      std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;

  // Allocator for UsedNamesScopeTy. It only provides storage: the scopes are
  // destroyed explicitly below, since the allocator does not run destructors.
  llvm::BumpPtrAllocator allocator;

  // Add a scope for the top level operation.
  auto *topLevelNamesScope =
      new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);

  SmallVector<NamingContext, 8> nameContext;
  for (Region &region : op->getRegions())
    nameContext.push_back(Elt: std::make_tuple(args: &region, args&: nextValueID, args&: nextArgumentID,
                                          args&: nextConflictID, args&: topLevelNamesScope));

  numberValuesInOp(op&: *op);

  while (!nameContext.empty()) {
    Region *region;
    UsedNamesScopeTy *parentScope;
    // Restore the counters captured when this region was queued.
    std::tie(args&: region, args&: nextValueID, args&: nextArgumentID, args&: nextConflictID, args&: parentScope) =
        nameContext.pop_back_val();

    // When we switch from one subtree to another, pop the scopes that are no
    // longer needed until we are back at the parent scope. Destroying a scope
    // removes the names it recorded and makes its parent current again.
    while (usedNames.getCurScope() != parentScope) {
      usedNames.getCurScope()->~UsedNamesScopeTy();
      assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
             "top level parentScope must be a nullptr");
    }

    // Add a scope for the current region.
    auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
        UsedNamesScopeTy(usedNames);

    numberValuesInRegion(region&: *region);

    // Queue the regions nested directly in this region's operations. (The
    // inner `region` variable shadows the outer one.)
    for (Operation &op : region->getOps())
      for (Region &region : op.getRegions())
        nameContext.push_back(Elt: std::make_tuple(args: &region, args&: nextValueID,
                                              args&: nextArgumentID, args&: nextConflictID,
                                              args&: curNamesScope));
  }

  // Manually remove all the scopes.
  while (usedNames.getCurScope() != nullptr)
    usedNames.getCurScope()->~UsedNamesScopeTy();
}
1400
1401void SSANameState::printValueID(Value value, bool printResultNo,
1402 raw_ostream &stream) const {
1403 if (!value) {
1404 stream << "<<NULL VALUE>>";
1405 return;
1406 }
1407
1408 std::optional<int> resultNo;
1409 auto lookupValue = value;
1410
1411 // If this is an operation result, collect the head lookup value of the result
1412 // group and the result number of 'result' within that group.
1413 if (OpResult result = dyn_cast<OpResult>(Val&: value))
1414 getResultIDAndNumber(result, lookupValue, lookupResultNo&: resultNo);
1415
1416 auto it = valueIDs.find(Val: lookupValue);
1417 if (it == valueIDs.end()) {
1418 stream << "<<UNKNOWN SSA VALUE>>";
1419 return;
1420 }
1421
1422 stream << '%';
1423 if (it->second != NameSentinel) {
1424 stream << it->second;
1425 } else {
1426 auto nameIt = valueNames.find(Val: lookupValue);
1427 assert(nameIt != valueNames.end() && "Didn't have a name entry?");
1428 stream << nameIt->second;
1429 }
1430
1431 if (resultNo && printResultNo)
1432 stream << '#' << *resultNo;
1433}
1434
1435void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
1436 auto it = operationIDs.find(Val: op);
1437 if (it == operationIDs.end()) {
1438 stream << "<<UNKNOWN OPERATION>>";
1439 } else {
1440 stream << '%' << it->second;
1441 }
1442}
1443
1444ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
1445 auto it = opResultGroups.find(Val: op);
1446 return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
1447}
1448
1449BlockInfo SSANameState::getBlockInfo(Block *block) {
1450 auto it = blockNames.find(Val: block);
1451 BlockInfo invalidBlock{.ordering: -1, .name: "INVALIDBLOCK"};
1452 return it != blockNames.end() ? it->second : invalidBlock;
1453}
1454
1455void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
1456 assert(!region.empty() && "cannot shadow arguments of an empty region");
1457 assert(region.getNumArguments() == namesToUse.size() &&
1458 "incorrect number of names passed in");
1459 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
1460 "only KnownIsolatedFromAbove ops can shadow names");
1461
1462 SmallVector<char, 16> nameStr;
1463 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
1464 auto nameToUse = namesToUse[i];
1465 if (nameToUse == nullptr)
1466 continue;
1467 auto nameToReplace = region.getArgument(i);
1468
1469 nameStr.clear();
1470 llvm::raw_svector_ostream nameStream(nameStr);
1471 printValueID(value: nameToUse, /*printResultNo=*/true, stream&: nameStream);
1472
1473 // Entry block arguments should already have a pretty "arg" name.
1474 assert(valueIDs[nameToReplace] == NameSentinel);
1475
1476 // Use the name without the leading %.
1477 auto name = StringRef(nameStream.str()).drop_front();
1478
1479 // Overwrite the name.
1480 valueNames[nameToReplace] = name.copy(A&: usedNameAllocator);
1481 }
1482}
1483
1484void SSANameState::numberValuesInRegion(Region &region) {
1485 auto setBlockArgNameFn = [&](Value arg, StringRef name) {
1486 assert(!valueIDs.count(arg) && "arg numbered multiple times");
1487 assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
1488 "arg not defined in current region");
1489 setValueName(value: arg, name);
1490 };
1491
1492 if (!printerFlags.shouldPrintGenericOpForm()) {
1493 if (Operation *op = region.getParentOp()) {
1494 if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
1495 asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1496 }
1497 }
1498
1499 // Number the values within this region in a breadth-first order.
1500 unsigned nextBlockID = 0;
1501 for (auto &block : region) {
1502 // Each block gets a unique ID, and all of the operations within it get
1503 // numbered as well.
1504 auto blockInfoIt = blockNames.insert(KV: {&block, {.ordering: -1, .name: ""}});
1505 if (blockInfoIt.second) {
1506 // This block hasn't been named through `getAsmBlockArgumentNames`, use
1507 // default `^bbNNN` format.
1508 std::string name;
1509 llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
1510 blockInfoIt.first->second.name = StringRef(name).copy(A&: usedNameAllocator);
1511 }
1512 blockInfoIt.first->second.ordering = nextBlockID++;
1513
1514 numberValuesInBlock(block);
1515 }
1516}
1517
1518void SSANameState::numberValuesInBlock(Block &block) {
1519 // Number the block arguments. We give entry block arguments a special name
1520 // 'arg'.
1521 bool isEntryBlock = block.isEntryBlock();
1522 SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
1523 llvm::raw_svector_ostream specialName(specialNameBuffer);
1524 for (auto arg : block.getArguments()) {
1525 if (valueIDs.count(Val: arg))
1526 continue;
1527 if (isEntryBlock) {
1528 specialNameBuffer.resize(N: strlen(s: "arg"));
1529 specialName << nextArgumentID++;
1530 }
1531 setValueName(value: arg, name: specialName.str());
1532 }
1533
1534 // Number the operations in this block.
1535 for (auto &op : block)
1536 numberValuesInOp(op);
1537}
1538
/// Name/number the results of `op` and record custom block names for its
/// regions.
///
/// In custom-form printing, OpAsmOpInterface may name the op's blocks and
/// results. If the first result is still unnumbered afterwards, it takes the
/// next numeric ID. With value-user printing enabled, a result-less op is
/// given an operation ID from the same counter. Ops whose results were named
/// in several groups have their group start indices recorded in
/// `opResultGroups`.
void SSANameState::numberValuesInOp(Operation &op) {
  // Function used to set the special result names for the operation.
  // `resultGroups` always contains group 0; each named result with a non-zero
  // result number starts an additional group.
  SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
  auto setResultNameFn = [&](Value result, StringRef name) {
    assert(!valueIDs.count(result) && "result numbered multiple times");
    assert(result.getDefiningOp() == &op && "result not defined by 'op'");
    setValueName(value: result, name);

    // Record the result number for groups not anchored at 0.
    if (int resultNo = llvm::cast<OpResult>(Val&: result).getResultNumber())
      resultGroups.push_back(Elt: resultNo);
  };
  // Operations can customize the printing of block names in OpAsmOpInterface.
  auto setBlockNameFn = [&](Block *block, StringRef name) {
    assert(block->getParentOp() == &op &&
           "getAsmBlockArgumentNames callback invoked on a block not directly "
           "nested under the current operation");
    assert(!blockNames.count(block) && "block numbered multiple times");
    // Pre-seed the buffer with "^". If sanitizeIdentifier had to rewrite the
    // name, it appended into this buffer, so the result already starts with
    // "^". Otherwise the unmodified name is appended here.
    SmallString<16> tmpBuffer{"^"};
    name = sanitizeIdentifier(name, buffer&: tmpBuffer);
    if (name.data() != tmpBuffer.data()) {
      tmpBuffer.append(RHS: name);
      name = tmpBuffer.str();
    }
    name = name.copy(A&: usedNameAllocator);
    // The ordering is filled in later, when the block's region is numbered.
    blockNames[block] = {.ordering: -1, .name: name};
  };

  if (!printerFlags.shouldPrintGenericOpForm()) {
    if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
      asmInterface.getAsmBlockNames(setBlockNameFn);
      asmInterface.getAsmResultNames(setResultNameFn);
    }
  }

  unsigned numResults = op.getNumResults();
  if (numResults == 0) {
    // If value users should be printed, operations with no result need an id.
    if (printerFlags.shouldPrintValueUsers()) {
      if (operationIDs.try_emplace(Key: &op, Args&: nextValueID).second)
        ++nextValueID;
    }
    return;
  }
  Value resultBegin = op.getResult(idx: 0);

  // If the first result wasn't numbered, give it a default number.
  if (valueIDs.try_emplace(Key: resultBegin, Args&: nextValueID).second)
    ++nextValueID;

  // If this operation has multiple result groups, mark it. The groups are
  // sorted so getResultIDAndNumber can binary search them.
  if (resultGroups.size() != 1) {
    llvm::array_pod_sort(Start: resultGroups.begin(), End: resultGroups.end());
    opResultGroups.try_emplace(Key: &op, Args: std::move(resultGroups));
  }
}
1595
/// Map `result` to the head of its result group (`lookupValue`) and, for
/// groups with more than one result, its index within that group
/// (`lookupResultNo`). Single-result operations leave both outputs untouched.
/// Without recorded groups, all results form one group headed by result 0.
void SSANameState::getResultIDAndNumber(
    OpResult result, Value &lookupValue,
    std::optional<int> &lookupResultNo) const {
  Operation *owner = result.getOwner();
  if (owner->getNumResults() == 1)
    return;
  int resultNo = result.getResultNumber();

  // If this operation has multiple result groups, we will need to find the
  // one corresponding to this result.
  auto resultGroupIt = opResultGroups.find(Val: owner);
  if (resultGroupIt == opResultGroups.end()) {
    // If not, just use the first result.
    lookupResultNo = resultNo;
    lookupValue = owner->getResult(idx: 0);
    return;
  }

  // Find the correct index using a binary search, as the groups are ordered.
  // `it` is the first group start strictly greater than `resultNo`. Since the
  // first group always starts at 0, `it` is never `begin()`.
  ArrayRef<int> resultGroups = resultGroupIt->second;
  const auto *it = llvm::upper_bound(Range&: resultGroups, Value&: resultNo);
  int groupResultNo = 0, groupSize = 0;

  // If there are no larger group starts, the last result group is the lookup.
  if (it == resultGroups.end()) {
    groupResultNo = resultGroups.back();
    groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
  } else {
    // Otherwise, the previous element is the lookup.
    groupResultNo = *std::prev(x: it);
    groupSize = *it - groupResultNo;
  }

  // We only record the result number for a group of size greater than 1.
  if (groupSize != 1)
    lookupResultNo = resultNo - groupResultNo;
  lookupValue = owner->getResult(idx: groupResultNo);
}
1634
1635void SSANameState::setValueName(Value value, StringRef name) {
1636 // If the name is empty, the value uses the default numbering.
1637 if (name.empty()) {
1638 valueIDs[value] = nextValueID++;
1639 return;
1640 }
1641
1642 valueIDs[value] = NameSentinel;
1643 valueNames[value] = uniqueValueName(name);
1644}
1645
1646StringRef SSANameState::uniqueValueName(StringRef name) {
1647 SmallString<16> tmpBuffer;
1648 name = sanitizeIdentifier(name, buffer&: tmpBuffer);
1649
1650 // Check to see if this name is already unique.
1651 if (!usedNames.count(Key: name)) {
1652 name = name.copy(A&: usedNameAllocator);
1653 } else {
1654 // Otherwise, we had a conflict - probe until we find a unique name. This
1655 // is guaranteed to terminate (and usually in a single iteration) because it
1656 // generates new names by incrementing nextConflictID.
1657 SmallString<64> probeName(name);
1658 probeName.push_back(Elt: '_');
1659 while (true) {
1660 probeName += llvm::utostr(X: nextConflictID++);
1661 if (!usedNames.count(Key: probeName)) {
1662 name = probeName.str().copy(A&: usedNameAllocator);
1663 break;
1664 }
1665 probeName.resize(N: name.size() + 1);
1666 }
1667 }
1668
1669 usedNames.insert(Key: name, Val: char());
1670 return name;
1671}
1672
1673//===----------------------------------------------------------------------===//
1674// DistinctState
1675//===----------------------------------------------------------------------===//
1676
1677namespace {
1678/// This class manages the state for distinct attributes.
1679class DistinctState {
1680public:
1681 /// Returns a unique identifier for the given distinct attribute.
1682 uint64_t getId(DistinctAttr distinctAttr);
1683
1684private:
1685 uint64_t distinctCounter = 0;
1686 DenseMap<DistinctAttr, uint64_t> distinctAttrMap;
1687};
1688} // namespace
1689
1690uint64_t DistinctState::getId(DistinctAttr distinctAttr) {
1691 auto [it, inserted] =
1692 distinctAttrMap.try_emplace(Key: distinctAttr, Args&: distinctCounter);
1693 if (inserted)
1694 distinctCounter++;
1695 return it->getSecond();
1696}
1697
1698//===----------------------------------------------------------------------===//
1699// Resources
1700//===----------------------------------------------------------------------===//
1701
1702AsmParsedResourceEntry::~AsmParsedResourceEntry() = default;
1703AsmResourceBuilder::~AsmResourceBuilder() = default;
1704AsmResourceParser::~AsmResourceParser() = default;
1705AsmResourcePrinter::~AsmResourcePrinter() = default;
1706
1707StringRef mlir::toString(AsmResourceEntryKind kind) {
1708 switch (kind) {
1709 case AsmResourceEntryKind::Blob:
1710 return "blob";
1711 case AsmResourceEntryKind::Bool:
1712 return "bool";
1713 case AsmResourceEntryKind::String:
1714 return "string";
1715 }
1716 llvm_unreachable("unknown AsmResourceEntryKind");
1717}
1718
1719AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) {
1720 std::unique_ptr<ResourceCollection> &collection = keyToResources[key.str()];
1721 if (!collection)
1722 collection = std::make_unique<ResourceCollection>(args&: key);
1723 return *collection;
1724}
1725
1726std::vector<std::unique_ptr<AsmResourcePrinter>>
1727FallbackAsmResourceMap::getPrinters() {
1728 std::vector<std::unique_ptr<AsmResourcePrinter>> printers;
1729 for (auto &it : keyToResources) {
1730 ResourceCollection *collection = it.second.get();
1731 auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) {
1732 return collection->buildResources(op, builder);
1733 };
1734 printers.emplace_back(
1735 args: AsmResourcePrinter::fromCallable(name: collection->getName(), printFn&: buildValues));
1736 }
1737 return printers;
1738}
1739
1740LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource(
1741 AsmParsedResourceEntry &entry) {
1742 switch (entry.getKind()) {
1743 case AsmResourceEntryKind::Blob: {
1744 FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
1745 if (failed(result: blob))
1746 return failure();
1747 resources.emplace_back(Args: entry.getKey(), Args: std::move(*blob));
1748 return success();
1749 }
1750 case AsmResourceEntryKind::Bool: {
1751 FailureOr<bool> value = entry.parseAsBool();
1752 if (failed(result: value))
1753 return failure();
1754 resources.emplace_back(Args: entry.getKey(), Args&: *value);
1755 break;
1756 }
1757 case AsmResourceEntryKind::String: {
1758 FailureOr<std::string> str = entry.parseAsString();
1759 if (failed(result: str))
1760 return failure();
1761 resources.emplace_back(Args: entry.getKey(), Args: std::move(*str));
1762 break;
1763 }
1764 }
1765 return success();
1766}
1767
1768void FallbackAsmResourceMap::ResourceCollection::buildResources(
1769 Operation *op, AsmResourceBuilder &builder) const {
1770 for (const auto &entry : resources) {
1771 if (const auto *value = std::get_if<AsmResourceBlob>(ptr: &entry.value))
1772 builder.buildBlob(key: entry.key, blob: *value);
1773 else if (const auto *value = std::get_if<bool>(ptr: &entry.value))
1774 builder.buildBool(key: entry.key, data: *value);
1775 else if (const auto *value = std::get_if<std::string>(ptr: &entry.value))
1776 builder.buildString(key: entry.key, data: *value);
1777 else
1778 llvm_unreachable("unknown AsmResourceEntryKind");
1779 }
1780}
1781
1782//===----------------------------------------------------------------------===//
1783// AsmState
1784//===----------------------------------------------------------------------===//
1785
1786namespace mlir {
1787namespace detail {
1788class AsmStateImpl {
1789public:
1790 explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
1791 AsmState::LocationMap *locationMap)
1792 : interfaces(op->getContext()), nameState(op, printerFlags),
1793 printerFlags(printerFlags), locationMap(locationMap) {}
1794 explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
1795 AsmState::LocationMap *locationMap)
1796 : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {}
1797
1798 /// Initialize the alias state to enable the printing of aliases.
1799 void initializeAliases(Operation *op) {
1800 aliasState.initialize(op, printerFlags, interfaces);
1801 }
1802
1803 /// Get the state used for aliases.
1804 AliasState &getAliasState() { return aliasState; }
1805
1806 /// Get the state used for SSA names.
1807 SSANameState &getSSANameState() { return nameState; }
1808
1809 /// Get the state used for distinct attribute identifiers.
1810 DistinctState &getDistinctState() { return distinctState; }
1811
1812 /// Return the dialects within the context that implement
1813 /// OpAsmDialectInterface.
1814 DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() {
1815 return interfaces;
1816 }
1817
1818 /// Return the non-dialect resource printers.
1819 auto getResourcePrinters() {
1820 return llvm::make_pointee_range(Range&: externalResourcePrinters);
1821 }
1822
1823 /// Get the printer flags.
1824 const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }
1825
1826 /// Register the location, line and column, within the buffer that the given
1827 /// operation was printed at.
1828 void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
1829 if (locationMap)
1830 (*locationMap)[op] = std::make_pair(x&: line, y&: col);
1831 }
1832
1833 /// Return the referenced dialect resources within the printer.
1834 DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
1835 getDialectResources() {
1836 return dialectResources;
1837 }
1838
1839 LogicalResult pushCyclicPrinting(const void *opaquePointer) {
1840 return success(isSuccess: cyclicPrintingStack.insert(X: opaquePointer));
1841 }
1842
1843 void popCyclicPrinting() { cyclicPrintingStack.pop_back(); }
1844
1845private:
1846 /// Collection of OpAsm interfaces implemented in the context.
1847 DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
1848
1849 /// A collection of non-dialect resource printers.
1850 SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
1851
1852 /// A set of dialect resources that were referenced during printing.
1853 DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;
1854
1855 /// The state used for attribute and type aliases.
1856 AliasState aliasState;
1857
1858 /// The state used for SSA value names.
1859 SSANameState nameState;
1860
1861 /// The state used for distinct attribute identifiers.
1862 DistinctState distinctState;
1863
1864 /// Flags that control op output.
1865 OpPrintingFlags printerFlags;
1866
1867 /// An optional location map to be populated.
1868 AsmState::LocationMap *locationMap;
1869
1870 /// Stack of potentially cyclic mutable attributes or type currently being
1871 /// printed.
1872 SetVector<const void *> cyclicPrintingStack;
1873
1874 // Allow direct access to the impl fields.
1875 friend AsmState;
1876};
1877
1878template <typename Range>
1879void printDimensionList(raw_ostream &stream, Range &&shape) {
1880 llvm::interleave(
1881 shape, stream,
1882 [&stream](const auto &dimSize) {
1883 if (ShapedType::isDynamic(dimSize))
1884 stream << "?";
1885 else
1886 stream << dimSize;
1887 },
1888 "x");
1889}
1890
1891} // namespace detail
1892} // namespace mlir
1893
1894/// Verifies the operation and switches to generic op printing if verification
1895/// fails. We need to do this because custom print functions may fail for
1896/// invalid ops.
1897static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
1898 OpPrintingFlags printerFlags) {
1899 if (printerFlags.shouldPrintGenericOpForm() ||
1900 printerFlags.shouldAssumeVerified())
1901 return printerFlags;
1902
1903 // Ignore errors emitted by the verifier. We check the thread id to avoid
1904 // consuming other threads' errors.
1905 auto parentThreadId = llvm::get_threadid();
1906 ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) {
1907 if (parentThreadId == llvm::get_threadid()) {
1908 LLVM_DEBUG({
1909 diag.print(llvm::dbgs());
1910 llvm::dbgs() << "\n";
1911 });
1912 return success();
1913 }
1914 return failure();
1915 });
1916 if (failed(result: verify(op))) {
1917 LLVM_DEBUG(llvm::dbgs()
1918 << DEBUG_TYPE << ": '" << op->getName()
1919 << "' failed to verify and will be printed in generic form\n");
1920 printerFlags.printGenericOpForm();
1921 }
1922
1923 return printerFlags;
1924}
1925
1926AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
1927 LocationMap *locationMap, FallbackAsmResourceMap *map)
1928 : impl(std::make_unique<AsmStateImpl>(
1929 args&: op, args: verifyOpAndAdjustFlags(op, printerFlags), args&: locationMap)) {
1930 if (map)
1931 attachFallbackResourcePrinter(map&: *map);
1932}
1933AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
1934 LocationMap *locationMap, FallbackAsmResourceMap *map)
1935 : impl(std::make_unique<AsmStateImpl>(args&: ctx, args: printerFlags, args&: locationMap)) {
1936 if (map)
1937 attachFallbackResourcePrinter(map&: *map);
1938}
1939AsmState::~AsmState() = default;
1940
1941const OpPrintingFlags &AsmState::getPrinterFlags() const {
1942 return impl->getPrinterFlags();
1943}
1944
1945void AsmState::attachResourcePrinter(
1946 std::unique_ptr<AsmResourcePrinter> printer) {
1947 impl->externalResourcePrinters.emplace_back(Args: std::move(printer));
1948}
1949
1950DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
1951AsmState::getDialectResources() const {
1952 return impl->getDialectResources();
1953}
1954
1955//===----------------------------------------------------------------------===//
1956// AsmPrinter::Impl
1957//===----------------------------------------------------------------------===//
1958
1959AsmPrinter::Impl::Impl(raw_ostream &os, AsmStateImpl &state)
1960 : os(os), state(state), printerFlags(state.getPrinterFlags()) {}
1961
1962void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
1963 // Check to see if we are printing debug information.
1964 if (!printerFlags.shouldPrintDebugInfo())
1965 return;
1966
1967 os << " ";
1968 printLocation(loc, /*allowAlias=*/allowAlias);
1969}
1970
1971void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
1972 bool isTopLevel) {
1973 // If this isn't a top-level location, check for an alias.
1974 if (!isTopLevel && succeeded(result: state.getAliasState().getAlias(attr: loc, os)))
1975 return;
1976
1977 TypeSwitch<LocationAttr>(loc)
1978 .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1979 printLocationInternal(loc.getFallbackLocation(), pretty);
1980 })
1981 .Case<UnknownLoc>([&](UnknownLoc loc) {
1982 if (pretty)
1983 os << "[unknown]";
1984 else
1985 os << "unknown";
1986 })
1987 .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1988 if (pretty)
1989 os << loc.getFilename().getValue();
1990 else
1991 printEscapedString(loc.getFilename());
1992 os << ':' << loc.getLine() << ':' << loc.getColumn();
1993 })
1994 .Case<NameLoc>([&](NameLoc loc) {
1995 printEscapedString(loc.getName());
1996
1997 // Print the child if it isn't unknown.
1998 auto childLoc = loc.getChildLoc();
1999 if (!llvm::isa<UnknownLoc>(childLoc)) {
2000 os << '(';
2001 printLocationInternal(childLoc, pretty);
2002 os << ')';
2003 }
2004 })
2005 .Case<CallSiteLoc>([&](CallSiteLoc loc) {
2006 Location caller = loc.getCaller();
2007 Location callee = loc.getCallee();
2008 if (!pretty)
2009 os << "callsite(";
2010 printLocationInternal(callee, pretty);
2011 if (pretty) {
2012 if (llvm::isa<NameLoc>(callee)) {
2013 if (llvm::isa<FileLineColLoc>(caller)) {
2014 os << " at ";
2015 } else {
2016 os << newLine << " at ";
2017 }
2018 } else {
2019 os << newLine << " at ";
2020 }
2021 } else {
2022 os << " at ";
2023 }
2024 printLocationInternal(caller, pretty);
2025 if (!pretty)
2026 os << ")";
2027 })
2028 .Case<FusedLoc>([&](FusedLoc loc) {
2029 if (!pretty)
2030 os << "fused";
2031 if (Attribute metadata = loc.getMetadata()) {
2032 os << '<';
2033 printAttribute(metadata);
2034 os << '>';
2035 }
2036 os << '[';
2037 interleave(
2038 loc.getLocations(),
2039 [&](Location loc) { printLocationInternal(loc, pretty); },
2040 [&]() { os << ", "; });
2041 os << ']';
2042 });
2043}
2044
2045/// Print a floating point value in a way that the parser will be able to
2046/// round-trip losslessly.
2047static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
2048 // We would like to output the FP constant value in exponential notation,
2049 // but we cannot do this if doing so will lose precision. Check here to
2050 // make sure that we only output it in exponential format if we can parse
2051 // the value back and get the same value.
2052 bool isInf = apValue.isInfinity();
2053 bool isNaN = apValue.isNaN();
2054 if (!isInf && !isNaN) {
2055 SmallString<128> strValue;
2056 apValue.toString(Str&: strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
2057 /*TruncateZero=*/false);
2058
2059 // Check to make sure that the stringized number is not some string like
2060 // "Inf" or NaN, that atof will accept, but the lexer will not. Check
2061 // that the string matches the "[-+]?[0-9]" regex.
2062 assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
2063 ((strValue[0] == '-' || strValue[0] == '+') &&
2064 (strValue[1] >= '0' && strValue[1] <= '9'))) &&
2065 "[-+]?[0-9] regex does not match!");
2066
2067 // Parse back the stringized version and check that the value is equal
2068 // (i.e., there is no precision loss).
2069 if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(RHS: apValue)) {
2070 os << strValue;
2071 return;
2072 }
2073
2074 // If it is not, use the default format of APFloat instead of the
2075 // exponential notation.
2076 strValue.clear();
2077 apValue.toString(Str&: strValue);
2078
2079 // Make sure that we can parse the default form as a float.
2080 if (strValue.str().contains(C: '.')) {
2081 os << strValue;
2082 return;
2083 }
2084 }
2085
2086 // Print special values in hexadecimal format. The sign bit should be included
2087 // in the literal.
2088 SmallVector<char, 16> str;
2089 APInt apInt = apValue.bitcastToAPInt();
2090 apInt.toString(Str&: str, /*Radix=*/16, /*Signed=*/false,
2091 /*formatAsCLiteral=*/true);
2092 os << str;
2093}
2094
2095void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
2096 if (printerFlags.shouldPrintDebugInfoPrettyForm())
2097 return printLocationInternal(loc, /*pretty=*/true, /*isTopLevel=*/true);
2098
2099 os << "loc(";
2100 if (!allowAlias || failed(result: printAlias(attr: loc)))
2101 printLocationInternal(loc, /*pretty=*/false, /*isTopLevel=*/true);
2102 os << ')';
2103}
2104
2105void AsmPrinter::Impl::printResourceHandle(
2106 const AsmDialectResourceHandle &resource) {
2107 auto *interface = cast<OpAsmDialectInterface>(Val: resource.getDialect());
2108 os << interface->getResourceKey(handle: resource);
2109 state.getDialectResources()[resource.getDialect()].insert(X: resource);
2110}
2111
2112/// Returns true if the given dialect symbol data is simple enough to print in
2113/// the pretty form. This is essentially when the symbol takes the form:
2114/// identifier (`<` body `>`)?
2115static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
2116 // The name must start with an identifier.
2117 if (symName.empty() || !isalpha(symName.front()))
2118 return false;
2119
2120 // Ignore all the characters that are valid in an identifier in the symbol
2121 // name.
2122 symName = symName.drop_while(
2123 F: [](char c) { return llvm::isAlnum(C: c) || c == '.' || c == '_'; });
2124 if (symName.empty())
2125 return true;
2126
2127 // If we got to an unexpected character, then it must be a <>. Check that the
2128 // rest of the symbol is wrapped within <>.
2129 return symName.front() == '<' && symName.back() == '>';
2130}
2131
2132/// Print the given dialect symbol to the stream.
2133static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
2134 StringRef dialectName, StringRef symString) {
2135 os << symPrefix << dialectName;
2136
2137 // If this symbol name is simple enough, print it directly in pretty form,
2138 // otherwise, we print it as an escaped string.
2139 if (isDialectSymbolSimpleEnoughForPrettyForm(symName: symString)) {
2140 os << '.' << symString;
2141 return;
2142 }
2143
2144 os << '<' << symString << '>';
2145}
2146
2147/// Returns true if the given string can be represented as a bare identifier.
2148static bool isBareIdentifier(StringRef name) {
2149 // By making this unsigned, the value passed in to isalnum will always be
2150 // in the range 0-255. This is important when building with MSVC because
2151 // its implementation will assert. This situation can arise when dealing
2152 // with UTF-8 multibyte characters.
2153 if (name.empty() || (!isalpha(name[0]) && name[0] != '_'))
2154 return false;
2155 return llvm::all_of(Range: name.drop_front(), P: [](unsigned char c) {
2156 return isalnum(c) || c == '_' || c == '$' || c == '.';
2157 });
2158}
2159
2160/// Print the given string as a keyword, or a quoted and escaped string if it
2161/// has any special or non-printable characters in it.
2162static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
2163 // If it can be represented as a bare identifier, write it directly.
2164 if (isBareIdentifier(name: keyword)) {
2165 os << keyword;
2166 return;
2167 }
2168
2169 // Otherwise, output the keyword wrapped in quotes with proper escaping.
2170 os << "\"";
2171 printEscapedString(Name: keyword, Out&: os);
2172 os << '"';
2173}
2174
2175/// Print the given string as a symbol reference. A symbol reference is
2176/// represented as a string prefixed with '@'. The reference is surrounded with
2177/// ""'s and escaped if it has any special or non-printable characters in it.
2178static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
2179 if (symbolRef.empty()) {
2180 os << "@<<INVALID EMPTY SYMBOL>>";
2181 return;
2182 }
2183 os << '@';
2184 printKeywordOrString(keyword: symbolRef, os);
2185}
2186
2187// Print out a valid ElementsAttr that is succinct and can represent any
2188// potential shape/type, for use when eliding a large ElementsAttr.
2189//
2190// We choose to use a dense resource ElementsAttr literal with conspicuous
2191// content to hopefully alert readers to the fact that this has been elided.
2192static void printElidedElementsAttr(raw_ostream &os) {
2193 os << R"(dense_resource<__elided__>)";
2194}
2195
2196LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
2197 return state.getAliasState().getAlias(attr, os);
2198}
2199
2200LogicalResult AsmPrinter::Impl::printAlias(Type type) {
2201 return state.getAliasState().getAlias(ty: type, os);
2202}
2203
2204void AsmPrinter::Impl::printAttribute(Attribute attr,
2205 AttrTypeElision typeElision) {
2206 if (!attr) {
2207 os << "<<NULL ATTRIBUTE>>";
2208 return;
2209 }
2210
2211 // Try to print an alias for this attribute.
2212 if (succeeded(result: printAlias(attr)))
2213 return;
2214 return printAttributeImpl(attr, typeElision);
2215}
2216
2217void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
2218 AttrTypeElision typeElision) {
2219 if (!isa<BuiltinDialect>(Val: attr.getDialect())) {
2220 printDialectAttribute(attr);
2221 } else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(attr)) {
2222 printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
2223 opaqueAttr.getAttrData());
2224 } else if (llvm::isa<UnitAttr>(Val: attr)) {
2225 os << "unit";
2226 return;
2227 } else if (auto distinctAttr = llvm::dyn_cast<DistinctAttr>(attr)) {
2228 os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<";
2229 if (!llvm::isa<UnitAttr>(Val: distinctAttr.getReferencedAttr())) {
2230 printAttribute(attr: distinctAttr.getReferencedAttr());
2231 }
2232 os << '>';
2233 return;
2234 } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
2235 os << '{';
2236 interleaveComma(dictAttr.getValue(),
2237 [&](NamedAttribute attr) { printNamedAttribute(attr); });
2238 os << '}';
2239
2240 } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
2241 Type intType = intAttr.getType();
2242 if (intType.isSignlessInteger(width: 1)) {
2243 os << (intAttr.getValue().getBoolValue() ? "true" : "false");
2244
2245 // Boolean integer attributes always elides the type.
2246 return;
2247 }
2248
2249 // Only print attributes as unsigned if they are explicitly unsigned or are
2250 // signless 1-bit values. Indexes, signed values, and multi-bit signless
2251 // values print as signed.
2252 bool isUnsigned =
2253 intType.isUnsignedInteger() || intType.isSignlessInteger(width: 1);
2254 intAttr.getValue().print(os, !isUnsigned);
2255
2256 // IntegerAttr elides the type if I64.
2257 if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(width: 64))
2258 return;
2259
2260 } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) {
2261 printFloatValue(floatAttr.getValue(), os);
2262
2263 // FloatAttr elides the type if F64.
2264 if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64())
2265 return;
2266
2267 } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attr)) {
2268 printEscapedString(str: strAttr.getValue());
2269
2270 } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) {
2271 os << '[';
2272 interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
2273 printAttribute(attr, typeElision: AttrTypeElision::May);
2274 });
2275 os << ']';
2276
2277 } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) {
2278 os << "affine_map<";
2279 affineMapAttr.getValue().print(os);
2280 os << '>';
2281
2282 // AffineMap always elides the type.
2283 return;
2284
2285 } else if (auto integerSetAttr = llvm::dyn_cast<IntegerSetAttr>(attr)) {
2286 os << "affine_set<";
2287 integerSetAttr.getValue().print(os);
2288 os << '>';
2289
2290 // IntegerSet always elides the type.
2291 return;
2292
2293 } else if (auto typeAttr = llvm::dyn_cast<TypeAttr>(attr)) {
2294 printType(type: typeAttr.getValue());
2295
2296 } else if (auto refAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) {
2297 printSymbolReference(refAttr.getRootReference().getValue(), os);
2298 for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
2299 os << "::";
2300 printSymbolReference(nestedRef.getValue(), os);
2301 }
2302
2303 } else if (auto intOrFpEltAttr =
2304 llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
2305 if (printerFlags.shouldElideElementsAttr(attr: intOrFpEltAttr)) {
2306 printElidedElementsAttr(os);
2307 } else {
2308 os << "dense<";
2309 printDenseIntOrFPElementsAttr(attr: intOrFpEltAttr, /*allowHex=*/true);
2310 os << '>';
2311 }
2312
2313 } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) {
2314 if (printerFlags.shouldElideElementsAttr(attr: strEltAttr)) {
2315 printElidedElementsAttr(os);
2316 } else {
2317 os << "dense<";
2318 printDenseStringElementsAttr(attr: strEltAttr);
2319 os << '>';
2320 }
2321
2322 } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) {
2323 if (printerFlags.shouldElideElementsAttr(attr: sparseEltAttr.getIndices()) ||
2324 printerFlags.shouldElideElementsAttr(attr: sparseEltAttr.getValues())) {
2325 printElidedElementsAttr(os);
2326 } else {
2327 os << "sparse<";
2328 DenseIntElementsAttr indices = sparseEltAttr.getIndices();
2329 if (indices.getNumElements() != 0) {
2330 printDenseIntOrFPElementsAttr(attr: indices, /*allowHex=*/false);
2331 os << ", ";
2332 printDenseElementsAttr(attr: sparseEltAttr.getValues(), /*allowHex=*/true);
2333 }
2334 os << '>';
2335 }
2336 } else if (auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(attr)) {
2337 stridedLayoutAttr.print(os);
2338 } else if (auto denseArrayAttr = llvm::dyn_cast<DenseArrayAttr>(attr)) {
2339 os << "array<";
2340 printType(type: denseArrayAttr.getElementType());
2341 if (!denseArrayAttr.empty()) {
2342 os << ": ";
2343 printDenseArrayAttr(attr: denseArrayAttr);
2344 }
2345 os << ">";
2346 return;
2347 } else if (auto resourceAttr =
2348 llvm::dyn_cast<DenseResourceElementsAttr>(attr)) {
2349 os << "dense_resource<";
2350 printResourceHandle(resource: resourceAttr.getRawHandle());
2351 os << ">";
2352 } else if (auto locAttr = llvm::dyn_cast<LocationAttr>(Val&: attr)) {
2353 printLocation(loc: locAttr);
2354 } else {
2355 llvm::report_fatal_error(reason: "Unknown builtin attribute");
2356 }
2357 // Don't print the type if we must elide it, or if it is a None type.
2358 if (typeElision != AttrTypeElision::Must) {
2359 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
2360 Type attrType = typedAttr.getType();
2361 if (!llvm::isa<NoneType>(Val: attrType)) {
2362 os << " : ";
2363 printType(type: attrType);
2364 }
2365 }
2366 }
2367}
2368
2369/// Print the integer element of a DenseElementsAttr.
2370static void printDenseIntElement(const APInt &value, raw_ostream &os,
2371 Type type) {
2372 if (type.isInteger(width: 1))
2373 os << (value.getBoolValue() ? "true" : "false");
2374 else
2375 value.print(OS&: os, isSigned: !type.isUnsignedInteger());
2376}
2377
2378static void
2379printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
2380 function_ref<void(unsigned)> printEltFn) {
2381 // Special case for 0-d and splat tensors.
2382 if (isSplat)
2383 return printEltFn(0);
2384
2385 // Special case for degenerate tensors.
2386 auto numElements = type.getNumElements();
2387 if (numElements == 0)
2388 return;
2389
2390 // We use a mixed-radix counter to iterate through the shape. When we bump a
2391 // non-least-significant digit, we emit a close bracket. When we next emit an
2392 // element we re-open all closed brackets.
2393
2394 // The mixed-radix counter, with radices in 'shape'.
2395 int64_t rank = type.getRank();
2396 SmallVector<unsigned, 4> counter(rank, 0);
2397 // The number of brackets that have been opened and not closed.
2398 unsigned openBrackets = 0;
2399
2400 auto shape = type.getShape();
2401 auto bumpCounter = [&] {
2402 // Bump the least significant digit.
2403 ++counter[rank - 1];
2404 // Iterate backwards bubbling back the increment.
2405 for (unsigned i = rank - 1; i > 0; --i)
2406 if (counter[i] >= shape[i]) {
2407 // Index 'i' is rolled over. Bump (i-1) and close a bracket.
2408 counter[i] = 0;
2409 ++counter[i - 1];
2410 --openBrackets;
2411 os << ']';
2412 }
2413 };
2414
2415 for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
2416 if (idx != 0)
2417 os << ", ";
2418 while (openBrackets++ < rank)
2419 os << '[';
2420 openBrackets = rank;
2421 printEltFn(idx);
2422 bumpCounter();
2423 }
2424 while (openBrackets-- > 0)
2425 os << ']';
2426}
2427
2428void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
2429 bool allowHex) {
2430 if (auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr))
2431 return printDenseStringElementsAttr(attr: stringAttr);
2432
2433 printDenseIntOrFPElementsAttr(llvm::cast<DenseIntOrFPElementsAttr>(attr),
2434 allowHex);
2435}
2436
2437void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
2438 DenseIntOrFPElementsAttr attr, bool allowHex) {
2439 auto type = attr.getType();
2440 auto elementType = type.getElementType();
2441
2442 // Check to see if we should format this attribute as a hex string.
2443 if (allowHex && printerFlags.shouldPrintElementsAttrWithHex(attr: attr)) {
2444 ArrayRef<char> rawData = attr.getRawData();
2445 if (llvm::endianness::native == llvm::endianness::big) {
2446 // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
2447 // machines. It is converted here to print in LE format.
2448 SmallVector<char, 64> outDataVec(rawData.size());
2449 MutableArrayRef<char> convRawData(outDataVec);
2450 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
2451 rawData, convRawData, type);
2452 printHexString(data: convRawData);
2453 } else {
2454 printHexString(data: rawData);
2455 }
2456
2457 return;
2458 }
2459
2460 if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) {
2461 Type complexElementType = complexTy.getElementType();
2462 // Note: The if and else below had a common lambda function which invoked
2463 // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
2464 // and hence was replaced.
2465 if (llvm::isa<IntegerType>(Val: complexElementType)) {
2466 auto valueIt = attr.value_begin<std::complex<APInt>>();
2467 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2468 auto complexValue = *(valueIt + index);
2469 os << "(";
2470 printDenseIntElement(complexValue.real(), os, complexElementType);
2471 os << ",";
2472 printDenseIntElement(complexValue.imag(), os, complexElementType);
2473 os << ")";
2474 });
2475 } else {
2476 auto valueIt = attr.value_begin<std::complex<APFloat>>();
2477 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2478 auto complexValue = *(valueIt + index);
2479 os << "(";
2480 printFloatValue(complexValue.real(), os);
2481 os << ",";
2482 printFloatValue(complexValue.imag(), os);
2483 os << ")";
2484 });
2485 }
2486 } else if (elementType.isIntOrIndex()) {
2487 auto valueIt = attr.value_begin<APInt>();
2488 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2489 printDenseIntElement(*(valueIt + index), os, elementType);
2490 });
2491 } else {
2492 assert(llvm::isa<FloatType>(elementType) && "unexpected element type");
2493 auto valueIt = attr.value_begin<APFloat>();
2494 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2495 printFloatValue(*(valueIt + index), os);
2496 });
2497 }
2498}
2499
2500void AsmPrinter::Impl::printDenseStringElementsAttr(
2501 DenseStringElementsAttr attr) {
2502 ArrayRef<StringRef> data = attr.getRawStringData();
2503 auto printFn = [&](unsigned index) { printEscapedString(str: data[index]); };
2504 printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
2505}
2506
2507void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
2508 Type type = attr.getElementType();
2509 unsigned bitwidth = type.isInteger(width: 1) ? 8 : type.getIntOrFloatBitWidth();
2510 unsigned byteSize = bitwidth / 8;
2511 ArrayRef<char> data = attr.getRawData();
2512
2513 auto printElementAt = [&](unsigned i) {
2514 APInt value(bitwidth, 0);
2515 if (bitwidth) {
2516 llvm::LoadIntFromMemory(
2517 IntVal&: value, Src: reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i),
2518 LoadBytes: byteSize);
2519 }
2520 // Print the data as-is or as a float.
2521 if (type.isIntOrIndex()) {
2522 printDenseIntElement(value, os&: getStream(), type);
2523 } else {
2524 APFloat fltVal(llvm::cast<FloatType>(Val&: type).getFloatSemantics(), value);
2525 printFloatValue(apValue: fltVal, os&: getStream());
2526 }
2527 };
2528 llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
2529 printElementAt);
2530}
2531
2532void AsmPrinter::Impl::printType(Type type) {
2533 if (!type) {
2534 os << "<<NULL TYPE>>";
2535 return;
2536 }
2537
2538 // Try to print an alias for this type.
2539 if (succeeded(result: printAlias(type)))
2540 return;
2541 return printTypeImpl(type);
2542}
2543
/// Prints the textual form of `type` without first trying an alias.
/// Builtin types are spelled out directly here. Opaque types are printed via
/// printDialectSymbol with the `!` prefix. Every other type is delegated to
/// its owning dialect through printDialectType.
void AsmPrinter::Impl::printTypeImpl(Type type) {
  TypeSwitch<Type>(type)
      .Case<OpaqueType>([&](OpaqueType opaqueTy) {
        printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
                           opaqueTy.getTypeData());
      })
      .Case<IndexType>([&](Type) { os << "index"; })
      // Floating point types are printed by their fixed keyword.
      .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
      .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
      .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
      .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
      .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
      .Case<BFloat16Type>([&](Type) { os << "bf16"; })
      .Case<Float16Type>([&](Type) { os << "f16"; })
      .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
      .Case<Float32Type>([&](Type) { os << "f32"; })
      .Case<Float64Type>([&](Type) { os << "f64"; })
      .Case<Float80Type>([&](Type) { os << "f80"; })
      .Case<Float128Type>([&](Type) { os << "f128"; })
      .Case<IntegerType>([&](IntegerType integerTy) {
        // `s`/`u` prefix for signed/unsigned integers; signless gets none.
        if (integerTy.isSigned())
          os << 's';
        else if (integerTy.isUnsigned())
          os << 'u';
        os << 'i' << integerTy.getWidth();
      })
      .Case<FunctionType>([&](FunctionType funcTy) {
        os << '(';
        interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
        os << ") -> ";
        ArrayRef<Type> results = funcTy.getResults();
        // A single non-function result is printed bare; zero, multiple, or a
        // function-typed result are wrapped in parentheses.
        if (results.size() == 1 && !llvm::isa<FunctionType>(results[0])) {
          printType(results[0]);
        } else {
          os << '(';
          interleaveComma(results, [&](Type ty) { printType(ty); });
          os << ')';
        }
      })
      .Case<VectorType>([&](VectorType vectorTy) {
        auto scalableDims = vectorTy.getScalableDims();
        os << "vector<";
        auto vShape = vectorTy.getShape();
        unsigned lastDim = vShape.size();
        unsigned dimIdx = 0;
        // Each dimension is followed by `x`. Dimensions flagged as scalable
        // (when scalableDims is non-empty) are wrapped in square brackets.
        for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
          if (!scalableDims.empty() && scalableDims[dimIdx])
            os << '[';
          os << vShape[dimIdx];
          if (!scalableDims.empty() && scalableDims[dimIdx])
            os << ']';
          os << 'x';
        }
        printType(vectorTy.getElementType());
        os << '>';
      })
      .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
        os << "tensor<";
        printDimensionList(tensorTy.getShape());
        // Rank-0 tensors have no dimensions and therefore no `x` separator.
        if (!tensorTy.getShape().empty())
          os << 'x';
        printType(tensorTy.getElementType());
        // Only print the encoding attribute value if set.
        if (tensorTy.getEncoding()) {
          os << ", ";
          printAttribute(tensorTy.getEncoding());
        }
        os << '>';
      })
      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
        os << "tensor<*x";
        printType(tensorTy.getElementType());
        os << '>';
      })
      .Case<MemRefType>([&](MemRefType memrefTy) {
        os << "memref<";
        printDimensionList(memrefTy.getShape());
        if (!memrefTy.getShape().empty())
          os << 'x';
        printType(memrefTy.getElementType());
        MemRefLayoutAttrInterface layout = memrefTy.getLayout();
        // The layout is omitted only when it is an identity AffineMapAttr.
        if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
          os << ", ";
          printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
        }
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
        os << "memref<*x";
        printType(memrefTy.getElementType());
        // Only print the memory space if it is the non-default one.
        if (memrefTy.getMemorySpace()) {
          os << ", ";
          printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
        }
        os << '>';
      })
      .Case<ComplexType>([&](ComplexType complexTy) {
        os << "complex<";
        printType(complexTy.getElementType());
        os << '>';
      })
      .Case<TupleType>([&](TupleType tupleTy) {
        os << "tuple<";
        interleaveComma(tupleTy.getTypes(),
                        [&](Type type) { printType(type); });
        os << '>';
      })
      .Case<NoneType>([&](Type) { os << "none"; })
      // Any non-builtin type is printed by its dialect.
      .Default([&](Type type) { return printDialectType(type); });
}
2660
2661void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2662 ArrayRef<StringRef> elidedAttrs,
2663 bool withKeyword) {
2664 // If there are no attributes, then there is nothing to be done.
2665 if (attrs.empty())
2666 return;
2667
2668 // Functor used to print a filtered attribute list.
2669 auto printFilteredAttributesFn = [&](auto filteredAttrs) {
2670 // Print the 'attributes' keyword if necessary.
2671 if (withKeyword)
2672 os << " attributes";
2673
2674 // Otherwise, print them all out in braces.
2675 os << " {";
2676 interleaveComma(filteredAttrs,
2677 [&](NamedAttribute attr) { printNamedAttribute(attr); });
2678 os << '}';
2679 };
2680
2681 // If no attributes are elided, we can directly print with no filtering.
2682 if (elidedAttrs.empty())
2683 return printFilteredAttributesFn(attrs);
2684
2685 // Otherwise, filter out any attributes that shouldn't be included.
2686 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2687 elidedAttrs.end());
2688 auto filteredAttrs = llvm::make_filter_range(Range&: attrs, Pred: [&](NamedAttribute attr) {
2689 return !elidedAttrsSet.contains(attr.getName().strref());
2690 });
2691 if (!filteredAttrs.empty())
2692 printFilteredAttributesFn(filteredAttrs);
2693}
2694void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
2695 // Print the name without quotes if possible.
2696 ::printKeywordOrString(keyword: attr.getName().strref(), os);
2697
2698 // Pretty printing elides the attribute value for unit attributes.
2699 if (llvm::isa<UnitAttr>(Val: attr.getValue()))
2700 return;
2701
2702 os << " = ";
2703 printAttribute(attr: attr.getValue());
2704}
2705
2706void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
2707 auto &dialect = attr.getDialect();
2708
2709 // Ask the dialect to serialize the attribute to a string.
2710 std::string attrName;
2711 {
2712 llvm::raw_string_ostream attrNameStr(attrName);
2713 Impl subPrinter(attrNameStr, state);
2714 DialectAsmPrinter printer(subPrinter);
2715 dialect.printAttribute(attr, printer);
2716 }
2717 printDialectSymbol(os, symPrefix: "#", dialectName: dialect.getNamespace(), symString: attrName);
2718}
2719
2720void AsmPrinter::Impl::printDialectType(Type type) {
2721 auto &dialect = type.getDialect();
2722
2723 // Ask the dialect to serialize the type to a string.
2724 std::string typeName;
2725 {
2726 llvm::raw_string_ostream typeNameStr(typeName);
2727 Impl subPrinter(typeNameStr, state);
2728 DialectAsmPrinter printer(subPrinter);
2729 dialect.printType(type, printer);
2730 }
2731 printDialectSymbol(os, symPrefix: "!", dialectName: dialect.getNamespace(), symString: typeName);
2732}
2733
2734void AsmPrinter::Impl::printEscapedString(StringRef str) {
2735 os << "\"";
2736 llvm::printEscapedString(Name: str, Out&: os);
2737 os << "\"";
2738}
2739
2740void AsmPrinter::Impl::printHexString(StringRef str) {
2741 os << "\"0x" << llvm::toHex(Input: str) << "\"";
2742}
2743void AsmPrinter::Impl::printHexString(ArrayRef<char> data) {
2744 printHexString(str: StringRef(data.data(), data.size()));
2745}
2746
2747LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) {
2748 return state.pushCyclicPrinting(opaquePointer);
2749}
2750
2751void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); }
2752
2753void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) {
2754 detail::printDimensionList(stream&: os, shape);
2755}
2756
2757//===--------------------------------------------------------------------===//
2758// AsmPrinter
2759//===--------------------------------------------------------------------===//
2760
// Out-of-line defaulted destructor for AsmPrinter.
// NOTE(review): presumably defined here rather than in the header so that it
// is emitted in this translation unit with the rest of the class — confirm.
AsmPrinter::~AsmPrinter() = default;
2762
2763raw_ostream &AsmPrinter::getStream() const {
2764 assert(impl && "expected AsmPrinter::getStream to be overriden");
2765 return impl->getStream();
2766}
2767
2768/// Print the given floating point value in a stablized form.
2769void AsmPrinter::printFloat(const APFloat &value) {
2770 assert(impl && "expected AsmPrinter::printFloat to be overriden");
2771 printFloatValue(apValue: value, os&: impl->getStream());
2772}
2773
2774void AsmPrinter::printType(Type type) {
2775 assert(impl && "expected AsmPrinter::printType to be overriden");
2776 impl->printType(type);
2777}
2778
2779void AsmPrinter::printAttribute(Attribute attr) {
2780 assert(impl && "expected AsmPrinter::printAttribute to be overriden");
2781 impl->printAttribute(attr);
2782}
2783
2784LogicalResult AsmPrinter::printAlias(Attribute attr) {
2785 assert(impl && "expected AsmPrinter::printAlias to be overriden");
2786 return impl->printAlias(attr);
2787}
2788
2789LogicalResult AsmPrinter::printAlias(Type type) {
2790 assert(impl && "expected AsmPrinter::printAlias to be overriden");
2791 return impl->printAlias(type);
2792}
2793
2794void AsmPrinter::printAttributeWithoutType(Attribute attr) {
2795 assert(impl &&
2796 "expected AsmPrinter::printAttributeWithoutType to be overriden");
2797 impl->printAttribute(attr, typeElision: Impl::AttrTypeElision::Must);
2798}
2799
2800void AsmPrinter::printKeywordOrString(StringRef keyword) {
2801 assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
2802 ::printKeywordOrString(keyword, os&: impl->getStream());
2803}
2804
2805void AsmPrinter::printString(StringRef keyword) {
2806 assert(impl && "expected AsmPrinter::printString to be overriden");
2807 *this << '"';
2808 printEscapedString(Name: keyword, Out&: getStream());
2809 *this << '"';
2810}
2811
2812void AsmPrinter::printSymbolName(StringRef symbolRef) {
2813 assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
2814 ::printSymbolReference(symbolRef, os&: impl->getStream());
2815}
2816
2817void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
2818 assert(impl && "expected AsmPrinter::printResourceHandle to be overriden");
2819 impl->printResourceHandle(resource);
2820}
2821
2822void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) {
2823 detail::printDimensionList(stream&: getStream(), shape);
2824}
2825
2826LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
2827 return impl->pushCyclicPrinting(opaquePointer);
2828}
2829
2830void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); }
2831
2832//===----------------------------------------------------------------------===//
2833// Affine expressions and maps
2834//===----------------------------------------------------------------------===//
2835
2836void AsmPrinter::Impl::printAffineExpr(
2837 AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
2838 printAffineExprInternal(expr, enclosingTightness: BindingStrength::Weak, printValueName);
2839}
2840
2841void AsmPrinter::Impl::printAffineExprInternal(
2842 AffineExpr expr, BindingStrength enclosingTightness,
2843 function_ref<void(unsigned, bool)> printValueName) {
2844 const char *binopSpelling = nullptr;
2845 switch (expr.getKind()) {
2846 case AffineExprKind::SymbolId: {
2847 unsigned pos = cast<AffineSymbolExpr>(Val&: expr).getPosition();
2848 if (printValueName)
2849 printValueName(pos, /*isSymbol=*/true);
2850 else
2851 os << 's' << pos;
2852 return;
2853 }
2854 case AffineExprKind::DimId: {
2855 unsigned pos = cast<AffineDimExpr>(Val&: expr).getPosition();
2856 if (printValueName)
2857 printValueName(pos, /*isSymbol=*/false);
2858 else
2859 os << 'd' << pos;
2860 return;
2861 }
2862 case AffineExprKind::Constant:
2863 os << cast<AffineConstantExpr>(Val&: expr).getValue();
2864 return;
2865 case AffineExprKind::Add:
2866 binopSpelling = " + ";
2867 break;
2868 case AffineExprKind::Mul:
2869 binopSpelling = " * ";
2870 break;
2871 case AffineExprKind::FloorDiv:
2872 binopSpelling = " floordiv ";
2873 break;
2874 case AffineExprKind::CeilDiv:
2875 binopSpelling = " ceildiv ";
2876 break;
2877 case AffineExprKind::Mod:
2878 binopSpelling = " mod ";
2879 break;
2880 }
2881
2882 auto binOp = cast<AffineBinaryOpExpr>(Val&: expr);
2883 AffineExpr lhsExpr = binOp.getLHS();
2884 AffineExpr rhsExpr = binOp.getRHS();
2885
2886 // Handle tightly binding binary operators.
2887 if (binOp.getKind() != AffineExprKind::Add) {
2888 if (enclosingTightness == BindingStrength::Strong)
2889 os << '(';
2890
2891 // Pretty print multiplication with -1.
2892 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhsExpr);
2893 if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
2894 rhsConst.getValue() == -1) {
2895 os << "-";
2896 printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Strong, printValueName);
2897 if (enclosingTightness == BindingStrength::Strong)
2898 os << ')';
2899 return;
2900 }
2901
2902 printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Strong, printValueName);
2903
2904 os << binopSpelling;
2905 printAffineExprInternal(expr: rhsExpr, enclosingTightness: BindingStrength::Strong, printValueName);
2906
2907 if (enclosingTightness == BindingStrength::Strong)
2908 os << ')';
2909 return;
2910 }
2911
2912 // Print out special "pretty" forms for add.
2913 if (enclosingTightness == BindingStrength::Strong)
2914 os << '(';
2915
2916 // Pretty print addition to a product that has a negative operand as a
2917 // subtraction.
2918 if (auto rhs = dyn_cast<AffineBinaryOpExpr>(Val&: rhsExpr)) {
2919 if (rhs.getKind() == AffineExprKind::Mul) {
2920 AffineExpr rrhsExpr = rhs.getRHS();
2921 if (auto rrhs = dyn_cast<AffineConstantExpr>(Val&: rrhsExpr)) {
2922 if (rrhs.getValue() == -1) {
2923 printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak,
2924 printValueName);
2925 os << " - ";
2926 if (rhs.getLHS().getKind() == AffineExprKind::Add) {
2927 printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Strong,
2928 printValueName);
2929 } else {
2930 printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Weak,
2931 printValueName);
2932 }
2933
2934 if (enclosingTightness == BindingStrength::Strong)
2935 os << ')';
2936 return;
2937 }
2938
2939 if (rrhs.getValue() < -1) {
2940 printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak,
2941 printValueName);
2942 os << " - ";
2943 printAffineExprInternal(expr: rhs.getLHS(), enclosingTightness: BindingStrength::Strong,
2944 printValueName);
2945 os << " * " << -rrhs.getValue();
2946 if (enclosingTightness == BindingStrength::Strong)
2947 os << ')';
2948 return;
2949 }
2950 }
2951 }
2952 }
2953
2954 // Pretty print addition to a negative number as a subtraction.
2955 if (auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhsExpr)) {
2956 if (rhsConst.getValue() < 0) {
2957 printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, printValueName);
2958 os << " - " << -rhsConst.getValue();
2959 if (enclosingTightness == BindingStrength::Strong)
2960 os << ')';
2961 return;
2962 }
2963 }
2964
2965 printAffineExprInternal(expr: lhsExpr, enclosingTightness: BindingStrength::Weak, printValueName);
2966
2967 os << " + ";
2968 printAffineExprInternal(expr: rhsExpr, enclosingTightness: BindingStrength::Weak, printValueName);
2969
2970 if (enclosingTightness == BindingStrength::Strong)
2971 os << ')';
2972}
2973
2974void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
2975 printAffineExprInternal(expr, enclosingTightness: BindingStrength::Weak);
2976 isEq ? os << " == 0" : os << " >= 0";
2977}
2978
2979void AsmPrinter::Impl::printAffineMap(AffineMap map) {
2980 // Dimension identifiers.
2981 os << '(';
2982 for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2983 os << 'd' << i << ", ";
2984 if (map.getNumDims() >= 1)
2985 os << 'd' << map.getNumDims() - 1;
2986 os << ')';
2987
2988 // Symbolic identifiers.
2989 if (map.getNumSymbols() != 0) {
2990 os << '[';
2991 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2992 os << 's' << i << ", ";
2993 if (map.getNumSymbols() >= 1)
2994 os << 's' << map.getNumSymbols() - 1;
2995 os << ']';
2996 }
2997
2998 // Result affine expressions.
2999 os << " -> (";
3000 interleaveComma(c: map.getResults(),
3001 eachFn: [&](AffineExpr expr) { printAffineExpr(expr); });
3002 os << ')';
3003}
3004
3005void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
3006 // Dimension identifiers.
3007 os << '(';
3008 for (unsigned i = 1; i < set.getNumDims(); ++i)
3009 os << 'd' << i - 1 << ", ";
3010 if (set.getNumDims() >= 1)
3011 os << 'd' << set.getNumDims() - 1;
3012 os << ')';
3013
3014 // Symbolic identifiers.
3015 if (set.getNumSymbols() != 0) {
3016 os << '[';
3017 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
3018 os << 's' << i << ", ";
3019 if (set.getNumSymbols() >= 1)
3020 os << 's' << set.getNumSymbols() - 1;
3021 os << ']';
3022 }
3023
3024 // Print constraints.
3025 os << " : (";
3026 int numConstraints = set.getNumConstraints();
3027 for (int i = 1; i < numConstraints; ++i) {
3028 printAffineConstraint(expr: set.getConstraint(idx: i - 1), isEq: set.isEq(idx: i - 1));
3029 os << ", ";
3030 }
3031 if (numConstraints >= 1)
3032 printAffineConstraint(expr: set.getConstraint(idx: numConstraints - 1),
3033 isEq: set.isEq(idx: numConstraints - 1));
3034 os << ')';
3035}
3036
3037//===----------------------------------------------------------------------===//
3038// OperationPrinter
3039//===----------------------------------------------------------------------===//
3040
namespace {
/// This class contains the logic for printing operations, regions, and blocks.
/// It combines the printing machinery of AsmPrinter::Impl with the
/// OpAsmPrinter interface, which it inherits privately.
class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
public:
  using Impl = AsmPrinter::Impl;
  // Both bases declare `printType`; make unqualified calls use Impl's.
  using Impl::printType;

  /// The OpAsmPrinter base is constructed on top of this object's own Impl
  /// subobject, so OpAsmPrinter calls forward back into this printer.
  explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
      : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {}

  /// Print the given top-level operation.
  void printTopLevelOperation(Operation *op);

  /// Print the given operation, including its left-hand side and its right-hand
  /// side, with its indent and location.
  void printFullOpWithIndentAndLoc(Operation *op);
  /// Print the given operation, including its left-hand side and its right-hand
  /// side, but not including indentation and location.
  void printFullOp(Operation *op);
  /// Print the right-hand side of the given operation in the custom or generic
  /// form.
  void printCustomOrGenericOp(Operation *op) override;
  /// Print the right-hand side of the given operation in the generic form.
  void printGenericOp(Operation *op, bool printOpName) override;

  /// Print the name of the given block.
  void printBlockName(Block *block);

  /// Print the given block. If 'printBlockArgs' is false, the arguments of the
  /// block are not printed. If 'printBlockTerminator' is false, the terminator
  /// operation of the block is not printed.
  void print(Block *block, bool printBlockArgs = true,
             bool printBlockTerminator = true);

  /// Print the ID of the given value, optionally with its result number.
  /// If `streamOverride` is non-null, it is written to instead of `os`.
  void printValueID(Value value, bool printResultNo = true,
                    raw_ostream *streamOverride = nullptr) const;

  /// Print the ID of the given operation.
  /// If `streamOverride` is non-null, it is written to instead of `os`.
  void printOperationID(Operation *op,
                        raw_ostream *streamOverride = nullptr) const;

  //===--------------------------------------------------------------------===//
  // OpAsmPrinter methods
  //===--------------------------------------------------------------------===//

  /// Print a loc(...) specifier if printing debug info is enabled. Locations
  /// may be deferred with an alias.
  void printOptionalLocationSpecifier(Location loc) override {
    printTrailingLocation(loc);
  }

  /// Print a newline and indent the printer to the start of the current
  /// operation.
  void printNewline() override {
    os << newLine;
    os.indent(NumSpaces: currentIndent);
  }

  /// Increase indentation by `indentWidth` spaces.
  void increaseIndent() override { currentIndent += indentWidth; }

  /// Decrease indentation by `indentWidth` spaces. Callers are expected to
  /// balance this with a prior increaseIndent(); `currentIndent` is unsigned.
  void decreaseIndent() override { currentIndent -= indentWidth; }

  /// Print a block argument in the usual format of:
  ///   %ssaName : type {attr1=42} loc("here")
  /// where location printing is controlled by the standard internal option.
  /// You may pass omitType=true to not print a type, and pass an empty
  /// attribute list if you don't care for attributes.
  void printRegionArgument(BlockArgument arg,
                           ArrayRef<NamedAttribute> argAttrs = {},
                           bool omitType = false) override;

  /// Print the ID for the given value (including its result number), either
  /// to this printer's stream or to the given `os`.
  void printOperand(Value value) override { printValueID(value); }
  void printOperand(Value value, raw_ostream &os) override {
    printValueID(value, /*printResultNo=*/true, streamOverride: &os);
  }

  /// Print an optional attribute dictionary with a given set of elided values.
  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                             ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs);
  }
  /// Same as printOptionalAttrDict, but prefixes the dictionary with the
  /// `attributes` keyword.
  void printOptionalAttrDictWithKeyword(
      ArrayRef<NamedAttribute> attrs,
      ArrayRef<StringRef> elidedAttrs = {}) override {
    Impl::printOptionalAttrDict(attrs, elidedAttrs,
                                /*withKeyword=*/true);
  }

  /// Print the given successor.
  void printSuccessor(Block *successor) override;

  /// Print an operation successor with the operands used for the block
  /// arguments.
  void printSuccessorAndUseList(Block *successor,
                                ValueRange succOperands) override;

  /// Print the given region.
  void printRegion(Region &region, bool printEntryBlockArgs,
                   bool printBlockTerminators, bool printEmptyBlock) override;

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
  /// operations. If any entry in namesToUse is null, the corresponding
  /// argument name is left alone.
  void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
    state.getSSANameState().shadowRegionArgs(region, namesToUse);
  }

  /// Print the given affine map with the symbol and dimension operands printed
  /// inline with the map.
  void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                              ValueRange operands) override;

  /// Print the given affine expression with the symbol and dimension operands
  /// printed inline with the expression.
  void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
                               ValueRange symOperands) override;

  /// Print users of this operation or id of this operation if it has no result.
  void printUsersComment(Operation *op);

  /// Print users of this block arg.
  void printUsersComment(BlockArgument arg);

  /// Print the users of a value.
  void printValueUsers(Value value);

  /// Print either the ids of the result values or the id of the operation if
  /// the operation has no results.
  void printUserIDs(Operation *user, bool prefixComma = false);

private:
  /// This class represents a resource builder implementation for the MLIR
  /// textual assembly format. Each built resource is handed to `printFn` as a
  /// key plus a callback that writes the value's textual form.
  class ResourceBuilder : public AsmResourceBuilder {
  public:
    using ValueFn = function_ref<void(raw_ostream &)>;
    using PrintFn = function_ref<void(StringRef, ValueFn)>;

    ResourceBuilder(PrintFn printFn) : printFn(printFn) {}
    ~ResourceBuilder() override = default;

    /// Booleans are printed as the bare keywords `true` / `false`.
    void buildBool(StringRef key, bool data) final {
      printFn(key, [&](raw_ostream &os) { os << (data ? "true" : "false"); });
    }

    /// Strings are printed quoted, with their contents escaped.
    void buildString(StringRef key, StringRef data) final {
      printFn(key, [&](raw_ostream &os) {
        os << "\"";
        llvm::printEscapedString(Name: data, Out&: os);
        os << "\"";
      });
    }

    /// Blobs are printed as one quoted `0x...` hex string: the alignment as a
    /// 4-byte little-endian integer, followed by the raw data bytes.
    void buildBlob(StringRef key, ArrayRef<char> data,
                   uint32_t dataAlignment) final {
      printFn(key, [&](raw_ostream &os) {
        // Store the blob in a hex string containing the alignment and the data.
        llvm::support::ulittle32_t dataAlignmentLE(dataAlignment);
        os << "\"0x"
           << llvm::toHex(Input: StringRef(reinterpret_cast<char *>(&dataAlignmentLE),
                                  sizeof(dataAlignment)))
           << llvm::toHex(Input: StringRef(data.data(), data.size())) << "\"";
      });
    }

  private:
    /// Callback that emits one `key: value` resource entry.
    PrintFn printFn;
  };

  /// Print the metadata dictionary for the file, eliding it if it is empty.
  void printFileMetadataDictionary(Operation *op);

  /// Print the resource sections for the file metadata dictionary.
  /// `checkAddMetadataDict` is used to indicate that metadata is going to be
  /// added, and the file metadata dictionary should be started if it hasn't
  /// yet.
  void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict,
                                 Operation *op);

  // Contains the stack of default dialects to use when printing regions.
  // A new dialect is pushed to the stack before printing regions nested under
  // an operation implementing `OpAsmOpInterface`, and popped when done. At the
  // top-level we start with "builtin" as the default, so that the top-level
  // `module` operation prints as-is.
  SmallVector<StringRef> defaultDialectStack{"builtin"};

  /// The number of spaces used for indenting nested operations.
  const static unsigned indentWidth = 2;

  // This is the current indentation level for nested structures.
  unsigned currentIndent = 0;
};
} // namespace
3239
3240void OperationPrinter::printTopLevelOperation(Operation *op) {
3241 // Output the aliases at the top level that can't be deferred.
3242 state.getAliasState().printNonDeferredAliases(p&: *this, newLine);
3243
3244 // Print the module.
3245 printFullOpWithIndentAndLoc(op);
3246 os << newLine;
3247
3248 // Output the aliases at the top level that can be deferred.
3249 state.getAliasState().printDeferredAliases(p&: *this, newLine);
3250
3251 // Output any file level metadata.
3252 printFileMetadataDictionary(op);
3253}
3254
3255void OperationPrinter::printFileMetadataDictionary(Operation *op) {
3256 bool sawMetadataEntry = false;
3257 auto checkAddMetadataDict = [&] {
3258 if (!std::exchange(obj&: sawMetadataEntry, new_val: true))
3259 os << newLine << "{-#" << newLine;
3260 };
3261
3262 // Add the various types of metadata.
3263 printResourceFileMetadata(checkAddMetadataDict, op);
3264
3265 // If the file dictionary exists, close it.
3266 if (sawMetadataEntry)
3267 os << newLine << "#-}" << newLine;
3268}
3269
3270void OperationPrinter::printResourceFileMetadata(
3271 function_ref<void()> checkAddMetadataDict, Operation *op) {
3272 // Functor used to add data entries to the file metadata dictionary.
3273 bool hadResource = false;
3274 bool needResourceComma = false;
3275 bool needEntryComma = false;
3276 auto processProvider = [&](StringRef dictName, StringRef name, auto &provider,
3277 auto &&...providerArgs) {
3278 bool hadEntry = false;
3279 auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
3280 checkAddMetadataDict();
3281
3282 auto printFormatting = [&]() {
3283 // Emit the top-level resource entry if we haven't yet.
3284 if (!std::exchange(obj&: hadResource, new_val: true)) {
3285 if (needResourceComma)
3286 os << "," << newLine;
3287 os << " " << dictName << "_resources: {" << newLine;
3288 }
3289 // Emit the parent resource entry if we haven't yet.
3290 if (!std::exchange(obj&: hadEntry, new_val: true)) {
3291 if (needEntryComma)
3292 os << "," << newLine;
3293 os << " " << name << ": {" << newLine;
3294 } else {
3295 os << "," << newLine;
3296 }
3297 };
3298
3299 std::optional<uint64_t> charLimit =
3300 printerFlags.getLargeResourceStringLimit();
3301 if (charLimit.has_value()) {
3302 std::string resourceStr;
3303 llvm::raw_string_ostream ss(resourceStr);
3304 valueFn(ss);
3305
3306 // Only print entry if it's string is small enough
3307 if (resourceStr.size() > charLimit.value())
3308 return;
3309
3310 printFormatting();
3311 os << " " << key << ": " << resourceStr;
3312 } else {
3313 printFormatting();
3314 os << " " << key << ": ";
3315 valueFn(os);
3316 }
3317 };
3318 ResourceBuilder entryBuilder(printFn);
3319 provider.buildResources(op, providerArgs..., entryBuilder);
3320
3321 needEntryComma |= hadEntry;
3322 if (hadEntry)
3323 os << newLine << " }";
3324 };
3325
3326 // Print the `dialect_resources` section if we have any dialects with
3327 // resources.
3328 for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) {
3329 auto &dialectResources = state.getDialectResources();
3330 StringRef name = interface.getDialect()->getNamespace();
3331 auto it = dialectResources.find(Val: interface.getDialect());
3332 if (it != dialectResources.end())
3333 processProvider("dialect", name, interface, it->second);
3334 else
3335 processProvider("dialect", name, interface,
3336 SetVector<AsmDialectResourceHandle>());
3337 }
3338 if (hadResource)
3339 os << newLine << " }";
3340
3341 // Print the `external_resources` section if we have any external clients with
3342 // resources.
3343 needEntryComma = false;
3344 needResourceComma = hadResource;
3345 hadResource = false;
3346 for (const auto &printer : state.getResourcePrinters())
3347 processProvider("external", printer.getName(), printer);
3348 if (hadResource)
3349 os << newLine << " }";
3350}
3351
3352/// Print a block argument in the usual format of:
3353/// %ssaName : type {attr1=42} loc("here")
3354/// where location printing is controlled by the standard internal option.
3355/// You may pass omitType=true to not print a type, and pass an empty
3356/// attribute list if you don't care for attributes.
3357void OperationPrinter::printRegionArgument(BlockArgument arg,
3358 ArrayRef<NamedAttribute> argAttrs,
3359 bool omitType) {
3360 printOperand(value: arg);
3361 if (!omitType) {
3362 os << ": ";
3363 printType(type: arg.getType());
3364 }
3365 printOptionalAttrDict(attrs: argAttrs);
3366 // TODO: We should allow location aliases on block arguments.
3367 printTrailingLocation(loc: arg.getLoc(), /*allowAlias*/ false);
3368}
3369
3370void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) {
3371 // Track the location of this operation.
3372 state.registerOperationLocation(op, line: newLine.curLine, col: currentIndent);
3373
3374 os.indent(NumSpaces: currentIndent);
3375 printFullOp(op);
3376 printTrailingLocation(loc: op->getLoc());
3377 if (printerFlags.shouldPrintValueUsers())
3378 printUsersComment(op);
3379}
3380
3381void OperationPrinter::printFullOp(Operation *op) {
3382 if (size_t numResults = op->getNumResults()) {
3383 auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
3384 printValueID(value: op->getResult(idx: resultNo), /*printResultNo=*/false);
3385 if (resultCount > 1)
3386 os << ':' << resultCount;
3387 };
3388
3389 // Check to see if this operation has multiple result groups.
3390 ArrayRef<int> resultGroups = state.getSSANameState().getOpResultGroups(op);
3391 if (!resultGroups.empty()) {
3392 // Interleave the groups excluding the last one, this one will be handled
3393 // separately.
3394 interleaveComma(c: llvm::seq<int>(Begin: 0, End: resultGroups.size() - 1), eachFn: [&](int i) {
3395 printResultGroup(resultGroups[i],
3396 resultGroups[i + 1] - resultGroups[i]);
3397 });
3398 os << ", ";
3399 printResultGroup(resultGroups.back(), numResults - resultGroups.back());
3400
3401 } else {
3402 printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
3403 }
3404
3405 os << " = ";
3406 }
3407
3408 printCustomOrGenericOp(op);
3409}
3410
3411void OperationPrinter::printUsersComment(Operation *op) {
3412 unsigned numResults = op->getNumResults();
3413 if (!numResults && op->getNumOperands()) {
3414 os << " // id: ";
3415 printOperationID(op);
3416 } else if (numResults && op->use_empty()) {
3417 os << " // unused";
3418 } else if (numResults && !op->use_empty()) {
3419 // Print "user" if the operation has one result used to compute one other
3420 // result, or is used in one operation with no result.
3421 unsigned usedInNResults = 0;
3422 unsigned usedInNOperations = 0;
3423 SmallPtrSet<Operation *, 1> userSet;
3424 for (Operation *user : op->getUsers()) {
3425 if (userSet.insert(Ptr: user).second) {
3426 ++usedInNOperations;
3427 usedInNResults += user->getNumResults();
3428 }
3429 }
3430
3431 // We already know that users is not empty.
3432 bool exactlyOneUniqueUse =
3433 usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1;
3434 os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": ";
3435 bool shouldPrintBrackets = numResults > 1;
3436 auto printOpResult = [&](OpResult opResult) {
3437 if (shouldPrintBrackets)
3438 os << "(";
3439 printValueUsers(value: opResult);
3440 if (shouldPrintBrackets)
3441 os << ")";
3442 };
3443
3444 interleaveComma(c: op->getResults(), eachFn: printOpResult);
3445 }
3446}
3447
3448void OperationPrinter::printUsersComment(BlockArgument arg) {
3449 os << "// ";
3450 printValueID(value: arg);
3451 if (arg.use_empty()) {
3452 os << " is unused";
3453 } else {
3454 os << " is used by ";
3455 printValueUsers(value: arg);
3456 }
3457 os << newLine;
3458}
3459
3460void OperationPrinter::printValueUsers(Value value) {
3461 if (value.use_empty())
3462 os << "unused";
3463
3464 // One value might be used as the operand of an operation more than once.
3465 // Only print the operations results once in that case.
3466 SmallPtrSet<Operation *, 1> userSet;
3467 for (auto [index, user] : enumerate(First: value.getUsers())) {
3468 if (userSet.insert(Ptr: user).second)
3469 printUserIDs(user, prefixComma: index);
3470 }
3471}
3472
3473void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
3474 if (prefixComma)
3475 os << ", ";
3476
3477 if (!user->getNumResults()) {
3478 printOperationID(op: user);
3479 } else {
3480 interleaveComma(c: user->getResults(),
3481 eachFn: [this](Value result) { printValueID(value: result); });
3482 }
3483}
3484
3485void OperationPrinter::printCustomOrGenericOp(Operation *op) {
3486 // If requested, always print the generic form.
3487 if (!printerFlags.shouldPrintGenericOpForm()) {
3488 // Check to see if this is a known operation. If so, use the registered
3489 // custom printer hook.
3490 if (auto opInfo = op->getRegisteredInfo()) {
3491 opInfo->printAssembly(op, p&: *this, defaultDialect: defaultDialectStack.back());
3492 return;
3493 }
3494 // Otherwise try to dispatch to the dialect, if available.
3495 if (Dialect *dialect = op->getDialect()) {
3496 if (auto opPrinter = dialect->getOperationPrinter(op)) {
3497 // Print the op name first.
3498 StringRef name = op->getName().getStringRef();
3499 // Only drop the default dialect prefix when it cannot lead to
3500 // ambiguities.
3501 if (name.count(C: '.') == 1)
3502 name.consume_front(Prefix: (defaultDialectStack.back() + ".").str());
3503 os << name;
3504
3505 // Print the rest of the op now.
3506 opPrinter(op, *this);
3507 return;
3508 }
3509 }
3510 }
3511
3512 // Otherwise print with the generic assembly form.
3513 printGenericOp(op, /*printOpName=*/true);
3514}
3515
3516void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
3517 if (printOpName)
3518 printEscapedString(str: op->getName().getStringRef());
3519 os << '(';
3520 interleaveComma(c: op->getOperands(), eachFn: [&](Value value) { printValueID(value); });
3521 os << ')';
3522
3523 // For terminators, print the list of successors and their operands.
3524 if (op->getNumSuccessors() != 0) {
3525 os << '[';
3526 interleaveComma(c: op->getSuccessors(),
3527 eachFn: [&](Block *successor) { printBlockName(block: successor); });
3528 os << ']';
3529 }
3530
3531 // Print the properties.
3532 if (Attribute prop = op->getPropertiesAsAttribute()) {
3533 os << " <";
3534 Impl::printAttribute(attr: prop);
3535 os << '>';
3536 }
3537
3538 // Print regions.
3539 if (op->getNumRegions() != 0) {
3540 os << " (";
3541 interleaveComma(c: op->getRegions(), eachFn: [&](Region &region) {
3542 printRegion(region, /*printEntryBlockArgs=*/true,
3543 /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
3544 });
3545 os << ')';
3546 }
3547
3548 printOptionalAttrDict(attrs: op->getPropertiesStorage()
3549 ? llvm::to_vector(op->getDiscardableAttrs())
3550 : op->getAttrs());
3551
3552 // Print the type signature of the operation.
3553 os << " : ";
3554 printFunctionalType(op);
3555}
3556
3557void OperationPrinter::printBlockName(Block *block) {
3558 os << state.getSSANameState().getBlockInfo(block).name;
3559}
3560
3561void OperationPrinter::print(Block *block, bool printBlockArgs,
3562 bool printBlockTerminator) {
3563 // Print the block label and argument list if requested.
3564 if (printBlockArgs) {
3565 os.indent(NumSpaces: currentIndent);
3566 printBlockName(block);
3567
3568 // Print the argument list if non-empty.
3569 if (!block->args_empty()) {
3570 os << '(';
3571 interleaveComma(c: block->getArguments(), eachFn: [&](BlockArgument arg) {
3572 printValueID(value: arg);
3573 os << ": ";
3574 printType(type: arg.getType());
3575 // TODO: We should allow location aliases on block arguments.
3576 printTrailingLocation(loc: arg.getLoc(), /*allowAlias*/ false);
3577 });
3578 os << ')';
3579 }
3580 os << ':';
3581
3582 // Print out some context information about the predecessors of this block.
3583 if (!block->getParent()) {
3584 os << " // block is not in a region!";
3585 } else if (block->hasNoPredecessors()) {
3586 if (!block->isEntryBlock())
3587 os << " // no predecessors";
3588 } else if (auto *pred = block->getSinglePredecessor()) {
3589 os << " // pred: ";
3590 printBlockName(block: pred);
3591 } else {
3592 // We want to print the predecessors in a stable order, not in
3593 // whatever order the use-list is in, so gather and sort them.
3594 SmallVector<BlockInfo, 4> predIDs;
3595 for (auto *pred : block->getPredecessors())
3596 predIDs.push_back(Elt: state.getSSANameState().getBlockInfo(block: pred));
3597 llvm::sort(C&: predIDs, Comp: [](BlockInfo lhs, BlockInfo rhs) {
3598 return lhs.ordering < rhs.ordering;
3599 });
3600
3601 os << " // " << predIDs.size() << " preds: ";
3602
3603 interleaveComma(c: predIDs, eachFn: [&](BlockInfo pred) { os << pred.name; });
3604 }
3605 os << newLine;
3606 }
3607
3608 currentIndent += indentWidth;
3609
3610 if (printerFlags.shouldPrintValueUsers()) {
3611 for (BlockArgument arg : block->getArguments()) {
3612 os.indent(NumSpaces: currentIndent);
3613 printUsersComment(arg);
3614 }
3615 }
3616
3617 bool hasTerminator =
3618 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
3619 auto range = llvm::make_range(
3620 x: block->begin(),
3621 y: std::prev(x: block->end(),
3622 n: (!hasTerminator || printBlockTerminator) ? 0 : 1));
3623 for (auto &op : range) {
3624 printFullOpWithIndentAndLoc(op: &op);
3625 os << newLine;
3626 }
3627 currentIndent -= indentWidth;
3628}
3629
3630void OperationPrinter::printValueID(Value value, bool printResultNo,
3631 raw_ostream *streamOverride) const {
3632 state.getSSANameState().printValueID(value, printResultNo,
3633 stream&: streamOverride ? *streamOverride : os);
3634}
3635
3636void OperationPrinter::printOperationID(Operation *op,
3637 raw_ostream *streamOverride) const {
3638 state.getSSANameState().printOperationID(op, stream&: streamOverride ? *streamOverride
3639 : os);
3640}
3641
3642void OperationPrinter::printSuccessor(Block *successor) {
3643 printBlockName(block: successor);
3644}
3645
3646void OperationPrinter::printSuccessorAndUseList(Block *successor,
3647 ValueRange succOperands) {
3648 printBlockName(block: successor);
3649 if (succOperands.empty())
3650 return;
3651
3652 os << '(';
3653 interleaveComma(c: succOperands,
3654 eachFn: [this](Value operand) { printValueID(value: operand); });
3655 os << " : ";
3656 interleaveComma(c: succOperands,
3657 eachFn: [this](Value operand) { printType(type: operand.getType()); });
3658 os << ')';
3659}
3660
3661void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
3662 bool printBlockTerminators,
3663 bool printEmptyBlock) {
3664 if (printerFlags.shouldSkipRegions()) {
3665 os << "{...}";
3666 return;
3667 }
3668 os << "{" << newLine;
3669 if (!region.empty()) {
3670 auto restoreDefaultDialect =
3671 llvm::make_scope_exit(F: [&]() { defaultDialectStack.pop_back(); });
3672 if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
3673 defaultDialectStack.push_back(Elt: iface.getDefaultDialect());
3674 else
3675 defaultDialectStack.push_back(Elt: "");
3676
3677 auto *entryBlock = &region.front();
3678 // Force printing the block header if printEmptyBlock is set and the block
3679 // is empty or if printEntryBlockArgs is set and there are arguments to
3680 // print.
3681 bool shouldAlwaysPrintBlockHeader =
3682 (printEmptyBlock && entryBlock->empty()) ||
3683 (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
3684 print(block: entryBlock, printBlockArgs: shouldAlwaysPrintBlockHeader, printBlockTerminator: printBlockTerminators);
3685 for (auto &b : llvm::drop_begin(RangeOrContainer&: region.getBlocks(), N: 1))
3686 print(block: &b);
3687 }
3688 os.indent(NumSpaces: currentIndent) << "}";
3689}
3690
3691void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
3692 ValueRange operands) {
3693 if (!mapAttr) {
3694 os << "<<NULL AFFINE MAP>>";
3695 return;
3696 }
3697 AffineMap map = mapAttr.getValue();
3698 unsigned numDims = map.getNumDims();
3699 auto printValueName = [&](unsigned pos, bool isSymbol) {
3700 unsigned index = isSymbol ? numDims + pos : pos;
3701 assert(index < operands.size());
3702 if (isSymbol)
3703 os << "symbol(";
3704 printValueID(value: operands[index]);
3705 if (isSymbol)
3706 os << ')';
3707 };
3708
3709 interleaveComma(c: map.getResults(), eachFn: [&](AffineExpr expr) {
3710 printAffineExpr(expr, printValueName);
3711 });
3712}
3713
3714void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
3715 ValueRange dimOperands,
3716 ValueRange symOperands) {
3717 auto printValueName = [&](unsigned pos, bool isSymbol) {
3718 if (!isSymbol)
3719 return printValueID(value: dimOperands[pos]);
3720 os << "symbol(";
3721 printValueID(value: symOperands[pos]);
3722 os << ')';
3723 };
3724 printAffineExpr(expr, printValueName);
3725}
3726
3727//===----------------------------------------------------------------------===//
3728// print and dump methods
3729//===----------------------------------------------------------------------===//
3730
3731void Attribute::print(raw_ostream &os, bool elideType) const {
3732 if (!*this) {
3733 os << "<<NULL ATTRIBUTE>>";
3734 return;
3735 }
3736
3737 AsmState state(getContext());
3738 print(os, state, elideType);
3739}
3740void Attribute::print(raw_ostream &os, AsmState &state, bool elideType) const {
3741 using AttrTypeElision = AsmPrinter::Impl::AttrTypeElision;
3742 AsmPrinter::Impl(os, state.getImpl())
3743 .printAttribute(attr: *this, typeElision: elideType ? AttrTypeElision::Must
3744 : AttrTypeElision::Never);
3745}
3746
3747void Attribute::dump() const {
3748 print(os&: llvm::errs());
3749 llvm::errs() << "\n";
3750}
3751
3752void Attribute::printStripped(raw_ostream &os, AsmState &state) const {
3753 if (!*this) {
3754 os << "<<NULL ATTRIBUTE>>";
3755 return;
3756 }
3757
3758 AsmPrinter::Impl subPrinter(os, state.getImpl());
3759 if (succeeded(result: subPrinter.printAlias(attr: *this)))
3760 return;
3761
3762 auto &dialect = this->getDialect();
3763 uint64_t posPrior = os.tell();
3764 DialectAsmPrinter printer(subPrinter);
3765 dialect.printAttribute(*this, printer);
3766 if (posPrior != os.tell())
3767 return;
3768
3769 // Fallback to printing with prefix if the above failed to write anything
3770 // to the output stream.
3771 print(os, state);
3772}
3773void Attribute::printStripped(raw_ostream &os) const {
3774 if (!*this) {
3775 os << "<<NULL ATTRIBUTE>>";
3776 return;
3777 }
3778
3779 AsmState state(getContext());
3780 printStripped(os, state);
3781}
3782
3783void Type::print(raw_ostream &os) const {
3784 if (!*this) {
3785 os << "<<NULL TYPE>>";
3786 return;
3787 }
3788
3789 AsmState state(getContext());
3790 print(os, state);
3791}
3792void Type::print(raw_ostream &os, AsmState &state) const {
3793 AsmPrinter::Impl(os, state.getImpl()).printType(type: *this);
3794}
3795
3796void Type::dump() const {
3797 print(os&: llvm::errs());
3798 llvm::errs() << "\n";
3799}
3800
3801void AffineMap::dump() const {
3802 print(os&: llvm::errs());
3803 llvm::errs() << "\n";
3804}
3805
3806void IntegerSet::dump() const {
3807 print(os&: llvm::errs());
3808 llvm::errs() << "\n";
3809}
3810
3811void AffineExpr::print(raw_ostream &os) const {
3812 if (!expr) {
3813 os << "<<NULL AFFINE EXPR>>";
3814 return;
3815 }
3816 AsmState state(getContext());
3817 AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(expr: *this);
3818}
3819
3820void AffineExpr::dump() const {
3821 print(os&: llvm::errs());
3822 llvm::errs() << "\n";
3823}
3824
3825void AffineMap::print(raw_ostream &os) const {
3826 if (!map) {
3827 os << "<<NULL AFFINE MAP>>";
3828 return;
3829 }
3830 AsmState state(getContext());
3831 AsmPrinter::Impl(os, state.getImpl()).printAffineMap(map: *this);
3832}
3833
3834void IntegerSet::print(raw_ostream &os) const {
3835 AsmState state(getContext());
3836 AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(set: *this);
3837}
3838
3839void Value::print(raw_ostream &os) const { print(os, flags: OpPrintingFlags()); }
3840void Value::print(raw_ostream &os, const OpPrintingFlags &flags) const {
3841 if (!impl) {
3842 os << "<<NULL VALUE>>";
3843 return;
3844 }
3845
3846 if (auto *op = getDefiningOp())
3847 return op->print(os, flags);
3848 // TODO: Improve BlockArgument print'ing.
3849 BlockArgument arg = llvm::cast<BlockArgument>(Val: *this);
3850 os << "<block argument> of type '" << arg.getType()
3851 << "' at index: " << arg.getArgNumber();
3852}
3853void Value::print(raw_ostream &os, AsmState &state) const {
3854 if (!impl) {
3855 os << "<<NULL VALUE>>";
3856 return;
3857 }
3858
3859 if (auto *op = getDefiningOp())
3860 return op->print(os, state);
3861
3862 // TODO: Improve BlockArgument print'ing.
3863 BlockArgument arg = llvm::cast<BlockArgument>(Val: *this);
3864 os << "<block argument> of type '" << arg.getType()
3865 << "' at index: " << arg.getArgNumber();
3866}
3867
3868void Value::dump() const {
3869 print(os&: llvm::errs());
3870 llvm::errs() << "\n";
3871}
3872
3873void Value::printAsOperand(raw_ostream &os, AsmState &state) const {
3874 // TODO: This doesn't necessarily capture all potential cases.
3875 // Currently, region arguments can be shadowed when printing the main
3876 // operation. If the IR hasn't been printed, this will produce the old SSA
3877 // name and not the shadowed name.
3878 state.getImpl().getSSANameState().printValueID(value: *this, /*printResultNo=*/true,
3879 stream&: os);
3880}
3881
3882static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
3883 do {
3884 // If we are printing local scope, stop at the first operation that is
3885 // isolated from above.
3886 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
3887 break;
3888
3889 // Otherwise, traverse up to the next parent.
3890 Operation *parentOp = op->getParentOp();
3891 if (!parentOp)
3892 break;
3893 op = parentOp;
3894 } while (true);
3895 return op;
3896}
3897
3898void Value::printAsOperand(raw_ostream &os,
3899 const OpPrintingFlags &flags) const {
3900 Operation *op;
3901 if (auto result = llvm::dyn_cast<OpResult>(Val: *this)) {
3902 op = result.getOwner();
3903 } else {
3904 op = llvm::cast<BlockArgument>(Val: *this).getOwner()->getParentOp();
3905 if (!op) {
3906 os << "<<UNKNOWN SSA VALUE>>";
3907 return;
3908 }
3909 }
3910 op = findParent(op, shouldUseLocalScope: flags.shouldUseLocalScope());
3911 AsmState state(op, flags);
3912 printAsOperand(os, state);
3913}
3914
3915void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
3916 // Find the operation to number from based upon the provided flags.
3917 Operation *op = findParent(op: this, shouldUseLocalScope: printerFlags.shouldUseLocalScope());
3918 AsmState state(op, printerFlags);
3919 print(os, state);
3920}
3921void Operation::print(raw_ostream &os, AsmState &state) {
3922 OperationPrinter printer(os, state.getImpl());
3923 if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) {
3924 state.getImpl().initializeAliases(op: this);
3925 printer.printTopLevelOperation(op: this);
3926 } else {
3927 printer.printFullOpWithIndentAndLoc(op: this);
3928 }
3929}
3930
3931void Operation::dump() {
3932 print(os&: llvm::errs(), printerFlags: OpPrintingFlags().useLocalScope());
3933 llvm::errs() << "\n";
3934}
3935
3936void Block::print(raw_ostream &os) {
3937 Operation *parentOp = getParentOp();
3938 if (!parentOp) {
3939 os << "<<UNLINKED BLOCK>>\n";
3940 return;
3941 }
3942 // Get the top-level op.
3943 while (auto *nextOp = parentOp->getParentOp())
3944 parentOp = nextOp;
3945
3946 AsmState state(parentOp);
3947 print(os, state);
3948}
3949void Block::print(raw_ostream &os, AsmState &state) {
3950 OperationPrinter(os, state.getImpl()).print(block: this);
3951}
3952
3953void Block::dump() { print(os&: llvm::errs()); }
3954
3955/// Print out the name of the block without printing its body.
3956void Block::printAsOperand(raw_ostream &os, bool printType) {
3957 Operation *parentOp = getParentOp();
3958 if (!parentOp) {
3959 os << "<<UNLINKED BLOCK>>\n";
3960 return;
3961 }
3962 AsmState state(parentOp);
3963 printAsOperand(os, state);
3964}
3965void Block::printAsOperand(raw_ostream &os, AsmState &state) {
3966 OperationPrinter printer(os, state.getImpl());
3967 printer.printBlockName(block: this);
3968}
3969
3970//===--------------------------------------------------------------------===//
3971// Custom printers
3972//===--------------------------------------------------------------------===//
3973namespace mlir {
3974
3975void printDimensionList(OpAsmPrinter &printer, Operation *op,
3976 ArrayRef<int64_t> dimensions) {
3977 if (dimensions.empty())
3978 printer << "[";
3979 printer.printDimensionList(shape: dimensions);
3980 if (dimensions.empty())
3981 printer << "]";
3982}
3983
3984ParseResult parseDimensionList(OpAsmParser &parser,
3985 DenseI64ArrayAttr &dimensions) {
3986 // Empty list case denoted by "[]".
3987 if (succeeded(result: parser.parseOptionalLSquare())) {
3988 if (failed(result: parser.parseRSquare())) {
3989 return parser.emitError(loc: parser.getCurrentLocation())
3990 << "Failed parsing dimension list.";
3991 }
3992 dimensions =
3993 DenseI64ArrayAttr::get(parser.getContext(), ArrayRef<int64_t>());
3994 return success();
3995 }
3996
3997 // Non-empty list case.
3998 SmallVector<int64_t> shapeArr;
3999 if (failed(result: parser.parseDimensionList(dimensions&: shapeArr, allowDynamic: true, withTrailingX: false))) {
4000 return parser.emitError(loc: parser.getCurrentLocation())
4001 << "Failed parsing dimension list.";
4002 }
4003 if (shapeArr.empty()) {
4004 return parser.emitError(loc: parser.getCurrentLocation())
4005 << "Failed parsing dimension list. Did you mean an empty list? It "
4006 "must be denoted by \"[]\".";
4007 }
4008 dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
4009 return success();
4010}
4011
4012} // namespace mlir
4013

source code of mlir/lib/IR/AsmPrinter.cpp