1 | //===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Pattern wrapper class to simplify using TableGen Record defining a MLIR |
10 | // Pattern. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_TABLEGEN_PATTERN_H_ |
15 | #define MLIR_TABLEGEN_PATTERN_H_ |
16 | |
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"

#include <cassert>
#include <optional>
#include <unordered_map>
27 | |
28 | namespace llvm { |
29 | class DagInit; |
30 | class Init; |
31 | class Record; |
32 | } // namespace llvm |
33 | |
34 | namespace mlir { |
35 | namespace tblgen { |
36 | |
37 | // Mapping from TableGen Record to Operator wrapper object. |
38 | // |
39 | // We allocate each wrapper object in heap to make sure the pointer to it is |
40 | // valid throughout the lifetime of this map. This is important because this map |
41 | // is shared among multiple patterns to avoid creating the wrapper object for |
42 | // the same op again and again. But this map will continuously grow. |
43 | using RecordOperatorMap = |
44 | DenseMap<const llvm::Record *, std::unique_ptr<Operator>>; |
45 | |
46 | class Pattern; |
47 | |
48 | // Wrapper class providing helper methods for accessing TableGen DAG leaves |
49 | // used inside Patterns. This class is lightweight and designed to be used like |
50 | // values. |
51 | // |
52 | // A TableGen DAG construct is of the syntax |
53 | // `(operator, arg0, arg1, ...)`. |
54 | // |
55 | // This class provides getters to retrieve `arg*` as tblgen:: wrapper objects |
56 | // for handy helper methods. It only works on `arg*`s that are not nested DAG |
57 | // constructs. |
58 | class DagLeaf { |
59 | public: |
60 | explicit DagLeaf(const llvm::Init *def) : def(def) {} |
61 | |
62 | // Returns true if this DAG leaf is not specified in the pattern. That is, it |
63 | // places no further constraints/transforms and just carries over the original |
64 | // value. |
65 | bool isUnspecified() const; |
66 | |
67 | // Returns true if this DAG leaf is matching an operand. That is, it specifies |
68 | // a type constraint. |
69 | bool isOperandMatcher() const; |
70 | |
71 | // Returns true if this DAG leaf is matching an attribute. That is, it |
72 | // specifies an attribute constraint. |
73 | bool isAttrMatcher() const; |
74 | |
75 | // Returns true if this DAG leaf is wrapping native code call. |
76 | bool isNativeCodeCall() const; |
77 | |
78 | // Returns true if this DAG leaf is specifying a constant attribute. |
79 | bool isConstantAttr() const; |
80 | |
81 | // Returns true if this DAG leaf is specifying an enum attribute case. |
82 | bool isEnumAttrCase() const; |
83 | |
84 | // Returns true if this DAG leaf is specifying a string attribute. |
85 | bool isStringAttr() const; |
86 | |
87 | // Returns this DAG leaf as a constraint. Asserts if fails. |
88 | Constraint getAsConstraint() const; |
89 | |
90 | // Returns this DAG leaf as an constant attribute. Asserts if fails. |
91 | ConstantAttr getAsConstantAttr() const; |
92 | |
93 | // Returns this DAG leaf as an enum attribute case. |
94 | // Precondition: isEnumAttrCase() |
95 | EnumAttrCase getAsEnumAttrCase() const; |
96 | |
97 | // Returns the matching condition template inside this DAG leaf. Assumes the |
98 | // leaf is an operand/attribute matcher and asserts otherwise. |
99 | std::string getConditionTemplate() const; |
100 | |
101 | // Returns the native code call template inside this DAG leaf. |
102 | // Precondition: isNativeCodeCall() |
103 | StringRef getNativeCodeTemplate() const; |
104 | |
105 | // Returns the number of values will be returned by the native helper |
106 | // function. |
107 | // Precondition: isNativeCodeCall() |
108 | int getNumReturnsOfNativeCode() const; |
109 | |
110 | // Returns the string associated with the leaf. |
111 | // Precondition: isStringAttr() |
112 | std::string getStringAttr() const; |
113 | |
114 | void print(raw_ostream &os) const; |
115 | |
116 | private: |
117 | friend llvm::DenseMapInfo<DagLeaf>; |
118 | const void *getAsOpaquePointer() const { return def; } |
119 | |
120 | // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and |
121 | // also a subclass of the given `superclass`. |
122 | bool isSubClassOf(StringRef superclass) const; |
123 | |
124 | const llvm::Init *def; |
125 | }; |
126 | |
127 | // Wrapper class providing helper methods for accessing TableGen DAG constructs |
128 | // used inside Patterns. This class is lightweight and designed to be used like |
129 | // values. |
130 | // |
131 | // A TableGen DAG construct is of the syntax |
132 | // `(operator, arg0, arg1, ...)`. |
133 | // |
134 | // When used inside Patterns, `operator` corresponds to some dialect op, or |
135 | // a known list of verbs that defines special transformation actions. This |
136 | // `arg*` can be a nested DAG construct. This class provides getters to |
137 | // retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper |
138 | // methods. |
139 | // |
140 | // A null DagNode contains a nullptr and converts to false implicitly. |
141 | class DagNode { |
142 | public: |
143 | explicit DagNode(const llvm::DagInit *node) : node(node) {} |
144 | |
145 | // Implicit bool converter that returns true if this DagNode is not a null |
146 | // DagNode. |
147 | operator bool() const { return node != nullptr; } |
148 | |
149 | // Returns the symbol bound to this DAG node. |
150 | StringRef getSymbol() const; |
151 | |
152 | // Returns the operator wrapper object corresponding to the dialect op matched |
153 | // by this DAG. The operator wrapper will be queried from the given `mapper` |
154 | // and created in it if not existing. |
155 | Operator &getDialectOp(RecordOperatorMap *mapper) const; |
156 | |
157 | // Returns the number of operations recursively involved in the DAG tree |
158 | // rooted from this node. |
159 | int getNumOps() const; |
160 | |
161 | // Returns the number of immediate arguments to this DAG node. |
162 | int getNumArgs() const; |
163 | |
164 | // Returns true if the `index`-th argument is a nested DAG construct. |
165 | bool isNestedDagArg(unsigned index) const; |
166 | |
167 | // Gets the `index`-th argument as a nested DAG construct if possible. Returns |
168 | // null DagNode otherwise. |
169 | DagNode getArgAsNestedDag(unsigned index) const; |
170 | |
171 | // Gets the `index`-th argument as a DAG leaf. |
172 | DagLeaf getArgAsLeaf(unsigned index) const; |
173 | |
174 | // Returns the specified name of the `index`-th argument. |
175 | StringRef getArgName(unsigned index) const; |
176 | |
177 | // Returns true if this DAG construct means to replace with an existing SSA |
178 | // value. |
179 | bool isReplaceWithValue() const; |
180 | |
181 | // Returns whether this DAG represents the location of an op creation. |
182 | bool isLocationDirective() const; |
183 | |
184 | // Returns whether this DAG is a return type specifier. |
185 | bool isReturnTypeDirective() const; |
186 | |
187 | // Returns true if this DAG node is wrapping native code call. |
188 | bool isNativeCodeCall() const; |
189 | |
190 | // Returns whether this DAG is an `either` specifier. |
191 | bool isEither() const; |
192 | |
193 | // Returns whether this DAG is an `variadic` specifier. |
194 | bool isVariadic() const; |
195 | |
196 | // Returns true if this DAG node is an operation. |
197 | bool isOperation() const; |
198 | |
199 | // Returns the native code call template inside this DAG node. |
200 | // Precondition: isNativeCodeCall() |
201 | StringRef getNativeCodeTemplate() const; |
202 | |
203 | // Returns the number of values will be returned by the native helper |
204 | // function. |
205 | // Precondition: isNativeCodeCall() |
206 | int getNumReturnsOfNativeCode() const; |
207 | |
208 | void print(raw_ostream &os) const; |
209 | |
210 | private: |
211 | friend class SymbolInfoMap; |
212 | friend llvm::DenseMapInfo<DagNode>; |
213 | const void *getAsOpaquePointer() const { return node; } |
214 | |
215 | const llvm::DagInit *node; // nullptr means null DagNode |
216 | }; |
217 | |
218 | // A class for maintaining information for symbols bound in patterns and |
219 | // provides methods for resolving them according to specific use cases. |
220 | // |
221 | // Symbols can be bound to |
222 | // |
223 | // * Op arguments and op results in the source pattern and |
224 | // * Op results in result patterns. |
225 | // |
226 | // Symbols can be referenced in result patterns and additional constraints to |
227 | // the pattern. |
228 | // |
229 | // For example, in |
230 | // |
231 | // ``` |
232 | // def : Pattern< |
233 | // (SrcOp:$results1 $arg0, %arg1), |
234 | // [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>; |
235 | // ``` |
236 | // |
237 | // `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to |
238 | // `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build |
239 | // `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`. |
240 | // |
241 | // If a symbol binds to a multi-result op and it does not have the `__N` |
242 | // suffix, the symbol is expanded to represent all results generated by the |
243 | // multi-result op. If the symbol has a `__N` suffix, then it will expand to |
244 | // only the N-th *static* result as declared in ODS, and that can still |
245 | // corresponds to multiple *dynamic* values if the N-th *static* result is |
246 | // variadic. |
247 | // |
248 | // This class keeps track of such symbols and resolves them into their bound |
249 | // values in a suitable way. |
250 | class SymbolInfoMap { |
251 | public: |
252 | explicit SymbolInfoMap(ArrayRef<SMLoc> loc) : loc(loc) {} |
253 | |
254 | // Class for information regarding a symbol. |
255 | class SymbolInfo { |
256 | public: |
257 | // Returns a type string of a variable. |
258 | std::string getVarTypeStr(StringRef name) const; |
259 | |
260 | // Returns a string for defining a variable named as `name` to store the |
261 | // value bound by this symbol. |
262 | std::string getVarDecl(StringRef name) const; |
263 | |
264 | // Returns a string for defining an argument which passes the reference of |
265 | // the variable. |
266 | std::string getArgDecl(StringRef name) const; |
267 | |
268 | // Returns a variable name for the symbol named as `name`. |
269 | std::string getVarName(StringRef name) const; |
270 | |
271 | private: |
272 | // Allow SymbolInfoMap to access private methods. |
273 | friend class SymbolInfoMap; |
274 | |
275 | // Structure to uniquely distinguish different locations of the symbols. |
276 | // |
277 | // * If a symbol is defined as an operand of an operation, `dag` specifies |
278 | // the DAG of the operation, `operandIndexOrNumValues` specifies the |
279 | // operand index, and `variadicSubIndex` must be set to `std::nullopt`. |
280 | // |
281 | // * If a symbol is defined in a `variadic` DAG, `dag` specifies the DAG |
282 | // of the parent operation, `operandIndexOrNumValues` specifies the |
283 | // declared operand index of the variadic operand in the parent |
284 | // operation. |
285 | // |
286 | // - If the symbol is defined as a result of `variadic` DAG, the |
287 | // `variadicSubIndex` must be set to `std::nullopt`, which means that |
288 | // the symbol binds to the full operand range. |
289 | // |
290 | // - If the symbol is defined as a operand, the `variadicSubIndex` must |
291 | // be set to the index within the variadic sub-operand list. |
292 | // |
293 | // * If a symbol is defined in a `either` DAG, `dag` specifies the DAG |
294 | // of the parent operation, `operandIndexOrNumValues` specifies the |
295 | // operand index in the parent operation (not necessary the index in the |
296 | // DAG). |
297 | // |
298 | // * If a symbol is defined as a result, specifies the number of returning |
299 | // value. |
300 | // |
301 | // Example 1: |
302 | // |
303 | // def : Pat<(OpA $input0, $input1), ...>; |
304 | // |
305 | // $input0: (OpA, 0, nullopt) |
306 | // $input1: (OpA, 1, nullopt) |
307 | // |
308 | // Example 2: |
309 | // |
310 | // def : Pat<(OpB (variadic:$input0 $input0a, $input0b), |
311 | // (variadic:$input1 $input1a, $input1b, $input1c)), |
312 | // ...>; |
313 | // |
314 | // $input0: (OpB, 0, nullopt) |
315 | // $input0a: (OpB, 0, 0) |
316 | // $input0b: (OpB, 0, 1) |
317 | // $input1: (OpB, 1, nullopt) |
318 | // $input1a: (OpB, 1, 0) |
319 | // $input1b: (OpB, 1, 1) |
320 | // $input1c: (OpB, 1, 2) |
321 | // |
322 | // Example 3: |
323 | // |
324 | // def : Pat<(OpC $input0, (either $input1, $input2)), ...>; |
325 | // |
326 | // $input0: (OpC, 0, nullopt) |
327 | // $input1: (OpC, 1, nullopt) |
328 | // $input2: (OpC, 2, nullopt) |
329 | // |
330 | // Example 4: |
331 | // |
332 | // def ThreeResultOp : TEST_Op<...> { |
333 | // let results = (outs |
334 | // AnyType:$result1, |
335 | // AnyType:$result2, |
336 | // AnyType:$result3 |
337 | // ); |
338 | // } |
339 | // |
340 | // def : Pat<..., |
341 | // (ThreeResultOp:$result ...)>; |
342 | // |
343 | // $result: (nullptr, 3, nullopt) |
344 | // |
345 | struct DagAndConstant { |
346 | // DagNode and DagLeaf are accessed by value which means it can't be used |
347 | // as identifier here. Use an opaque pointer type instead. |
348 | const void *dag; |
349 | int operandIndexOrNumValues; |
350 | std::optional<int> variadicSubIndex; |
351 | |
352 | DagAndConstant(const void *dag, int operandIndexOrNumValues, |
353 | std::optional<int> variadicSubIndex) |
354 | : dag(dag), operandIndexOrNumValues(operandIndexOrNumValues), |
355 | variadicSubIndex(variadicSubIndex) {} |
356 | |
357 | bool operator==(const DagAndConstant &rhs) const { |
358 | return dag == rhs.dag && |
359 | operandIndexOrNumValues == rhs.operandIndexOrNumValues && |
360 | variadicSubIndex == rhs.variadicSubIndex; |
361 | } |
362 | }; |
363 | |
364 | // What kind of entity this symbol represents: |
365 | // * Attr: op attribute |
366 | // * Operand: op operand |
367 | // * Result: op result |
368 | // * Value: a value not attached to an op (e.g., from NativeCodeCall) |
369 | // * MultipleValues: a pack of values not attached to an op (e.g., from |
370 | // NativeCodeCall). This kind supports indexing. |
371 | enum class Kind : uint8_t { Attr, Operand, Result, Value, MultipleValues }; |
372 | |
373 | // Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr` |
374 | // and `Operand` so should be std::nullopt for `Result` and `Value` kind. |
375 | SymbolInfo(const Operator *op, Kind kind, |
376 | std::optional<DagAndConstant> dagAndConstant); |
377 | |
378 | // Static methods for creating SymbolInfo. |
379 | static SymbolInfo getAttr(const Operator *op, int index) { |
380 | return SymbolInfo(op, Kind::Attr, |
381 | DagAndConstant(nullptr, index, std::nullopt)); |
382 | } |
383 | static SymbolInfo getAttr() { |
384 | return SymbolInfo(nullptr, Kind::Attr, std::nullopt); |
385 | } |
386 | static SymbolInfo |
387 | getOperand(DagNode node, const Operator *op, int operandIndex, |
388 | std::optional<int> variadicSubIndex = std::nullopt) { |
389 | return SymbolInfo(op, Kind::Operand, |
390 | DagAndConstant(node.getAsOpaquePointer(), operandIndex, |
391 | variadicSubIndex)); |
392 | } |
393 | static SymbolInfo getResult(const Operator *op) { |
394 | return SymbolInfo(op, Kind::Result, std::nullopt); |
395 | } |
396 | static SymbolInfo getValue() { |
397 | return SymbolInfo(nullptr, Kind::Value, std::nullopt); |
398 | } |
399 | static SymbolInfo getMultipleValues(int numValues) { |
400 | return SymbolInfo(nullptr, Kind::MultipleValues, |
401 | DagAndConstant(nullptr, numValues, std::nullopt)); |
402 | } |
403 | |
404 | // Returns the number of static values this symbol corresponds to. |
405 | // A static value is an operand/result declared in ODS. Normally a symbol |
406 | // only represents one static value, but symbols bound to op results can |
407 | // represent more than one if the op is a multi-result op. |
408 | int getStaticValueCount() const; |
409 | |
410 | // Returns a string containing the C++ expression for referencing this |
411 | // symbol as a value (if this symbol represents one static value) or a value |
412 | // range (if this symbol represents multiple static values). `name` is the |
413 | // name of the C++ variable that this symbol bounds to. `index` should only |
414 | // be used for indexing results. `fmt` is used to format each value. |
415 | // `separator` is used to separate values if this is a value range. |
416 | std::string getValueAndRangeUse(StringRef name, int index, const char *fmt, |
417 | const char *separator) const; |
418 | |
419 | // Returns a string containing the C++ expression for referencing this |
420 | // symbol as a value range regardless of how many static values this symbol |
421 | // represents. `name` is the name of the C++ variable that this symbol |
422 | // bounds to. `index` should only be used for indexing results. `fmt` is |
423 | // used to format each value. `separator` is used to separate values in the |
424 | // range. |
425 | std::string getAllRangeUse(StringRef name, int index, const char *fmt, |
426 | const char *separator) const; |
427 | |
428 | // The argument index (for `Attr` and `Operand` only) |
429 | int getArgIndex() const { return dagAndConstant->operandIndexOrNumValues; } |
430 | |
431 | // The number of values in the MultipleValue |
432 | int getSize() const { return dagAndConstant->operandIndexOrNumValues; } |
433 | |
434 | // The variadic sub-operands index (for variadic `Operand` only) |
435 | std::optional<int> getVariadicSubIndex() const { |
436 | return dagAndConstant->variadicSubIndex; |
437 | } |
438 | |
439 | const Operator *op; // The op where the bound entity belongs |
440 | Kind kind; // The kind of the bound entity |
441 | |
442 | // The tuple of DagNode pointer and two constant values (for `Attr`, |
443 | // `Operand` and the size of MultipleValue symbol). Note that operands may |
444 | // be bound to the same symbol, use the DagNode and index to distinguish |
445 | // them. For `Attr` and MultipleValue, the Dag part will be nullptr. |
446 | std::optional<DagAndConstant> dagAndConstant; |
447 | |
448 | // Alternative name for the symbol. It is used in case the name |
449 | // is not unique. Applicable for `Operand` only. |
450 | std::optional<std::string> alternativeName; |
451 | }; |
452 | |
453 | using BaseT = std::unordered_multimap<std::string, SymbolInfo>; |
454 | |
455 | // Iterators for accessing all symbols. |
456 | using iterator = BaseT::iterator; |
457 | iterator begin() { return symbolInfoMap.begin(); } |
458 | iterator end() { return symbolInfoMap.end(); } |
459 | |
460 | // Const iterators for accessing all symbols. |
461 | using const_iterator = BaseT::const_iterator; |
462 | const_iterator begin() const { return symbolInfoMap.begin(); } |
463 | const_iterator end() const { return symbolInfoMap.end(); } |
464 | |
465 | // Binds the given `symbol` to the `argIndex`-th argument to the given `op`. |
466 | // Returns false if `symbol` is already bound and symbols are not operands. |
467 | bool bindOpArgument(DagNode node, StringRef symbol, const Operator &op, |
468 | int argIndex, |
469 | std::optional<int> variadicSubIndex = std::nullopt); |
470 | |
471 | // Binds the given `symbol` to the results the given `op`. Returns false if |
472 | // `symbol` is already bound. |
473 | bool bindOpResult(StringRef symbol, const Operator &op); |
474 | |
475 | // A helper function for dispatching target value binding functions. |
476 | bool bindValues(StringRef symbol, int numValues = 1); |
477 | |
478 | // Registers the given `symbol` as bound to the Value(s). Returns false if |
479 | // `symbol` is already bound. |
480 | bool bindValue(StringRef symbol); |
481 | |
482 | // Registers the given `symbol` as bound to a MultipleValue. Return false if |
483 | // `symbol` is already bound. |
484 | bool bindMultipleValues(StringRef symbol, int numValues); |
485 | |
486 | // Registers the given `symbol` as bound to an attr. Returns false if `symbol` |
487 | // is already bound. |
488 | bool bindAttr(StringRef symbol); |
489 | |
490 | // Returns true if the given `symbol` is bound. |
491 | bool contains(StringRef symbol) const; |
492 | |
493 | // Returns an iterator to the information of the given symbol named as `key`. |
494 | const_iterator find(StringRef key) const; |
495 | |
496 | // Returns an iterator to the information of the given symbol named as `key`, |
497 | // with index `argIndex` for operator `op`. |
498 | const_iterator findBoundSymbol(StringRef key, DagNode node, |
499 | const Operator &op, int argIndex, |
500 | std::optional<int> variadicSubIndex) const; |
501 | const_iterator findBoundSymbol(StringRef key, |
502 | const SymbolInfo &symbolInfo) const; |
503 | |
504 | // Returns the bounds of a range that includes all the elements which |
505 | // bind to the `key`. |
506 | std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key); |
507 | |
508 | // Returns number of times symbol named as `key` was used. |
509 | int count(StringRef key) const; |
510 | |
511 | // Returns the number of static values of the given `symbol` corresponds to. |
512 | // A static value is an operand/result declared in ODS. Normally a symbol only |
513 | // represents one static value, but symbols bound to op results can represent |
514 | // more than one if the op is a multi-result op. |
515 | int getStaticValueCount(StringRef symbol) const; |
516 | |
517 | // Returns a string containing the C++ expression for referencing this |
518 | // symbol as a value (if this symbol represents one static value) or a value |
519 | // range (if this symbol represents multiple static values). `fmt` is used to |
520 | // format each value. `separator` is used to separate values if `symbol` |
521 | // represents a value range. |
522 | std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}" , |
523 | const char *separator = ", " ) const; |
524 | |
525 | // Returns a string containing the C++ expression for referencing this |
526 | // symbol as a value range regardless of how many static values this symbol |
527 | // represents. `fmt` is used to format each value. `separator` is used to |
528 | // separate values in the range. |
529 | std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}" , |
530 | const char *separator = ", " ) const; |
531 | |
532 | // Assign alternative unique names to Operands that have equal names. |
533 | void assignUniqueAlternativeNames(); |
534 | |
535 | // Splits the given `symbol` into a value pack name and an index. Returns the |
536 | // value pack name and writes the index to `index` on success. Returns |
537 | // `symbol` itself if it does not contain an index. |
538 | // |
539 | // We can use `name__N` to access the `N`-th value in the value pack bound to |
540 | // `name`. `name` is typically the results of an multi-result op. |
541 | static StringRef getValuePackName(StringRef symbol, int *index = nullptr); |
542 | |
543 | private: |
544 | BaseT symbolInfoMap; |
545 | |
546 | // Pattern instantiation location. This is intended to be used as parameter |
547 | // to PrintFatalError() to report errors. |
548 | ArrayRef<SMLoc> loc; |
549 | }; |
550 | |
551 | // Wrapper class providing helper methods for accessing MLIR Pattern defined |
552 | // in TableGen. This class should closely reflect what is defined as class |
553 | // `Pattern` in TableGen. This class contains maps so it is not intended to be |
554 | // used as values. |
555 | class Pattern { |
556 | public: |
557 | explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper); |
558 | |
559 | // Returns the source pattern to match. |
560 | DagNode getSourcePattern() const; |
561 | |
562 | // Returns the number of result patterns generated by applying this rewrite |
563 | // rule. |
564 | int getNumResultPatterns() const; |
565 | |
566 | // Returns the DAG tree root node of the `index`-th result pattern. |
567 | DagNode getResultPattern(unsigned index) const; |
568 | |
569 | // Collects all symbols bound in the source pattern into `infoMap`. |
570 | void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap); |
571 | |
572 | // Collects all symbols bound in result patterns into `infoMap`. |
573 | void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap); |
574 | |
575 | // Returns the op that the root node of the source pattern matches. |
576 | const Operator &getSourceRootOp(); |
577 | |
578 | // Returns the operator wrapper object corresponding to the given `node`'s DAG |
579 | // operator. |
580 | Operator &getDialectOp(DagNode node); |
581 | |
582 | // Returns the constraints. |
583 | std::vector<AppliedConstraint> getConstraints() const; |
584 | |
585 | // Returns the number of supplemental auxiliary patterns generated by applying |
586 | // this rewrite rule. |
587 | int getNumSupplementalPatterns() const; |
588 | |
589 | // Returns the DAG tree root node of the `index`-th supplemental result |
590 | // pattern. |
591 | DagNode getSupplementalPattern(unsigned index) const; |
592 | |
593 | // Returns the benefit score of the pattern. |
594 | int getBenefit() const; |
595 | |
596 | using IdentifierLine = std::pair<StringRef, unsigned>; |
597 | |
598 | // Returns the file location of the pattern (buffer identifier + line number |
599 | // pair). |
600 | std::vector<IdentifierLine> getLocation() const; |
601 | |
602 | // Recursively collects all bound symbols inside the DAG tree rooted |
603 | // at `tree` and updates the given `infoMap`. |
604 | void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, |
605 | bool isSrcPattern); |
606 | |
607 | private: |
608 | // Helper function to verify variable binding. |
609 | void verifyBind(bool result, StringRef symbolName); |
610 | |
611 | // The TableGen definition of this pattern. |
612 | const llvm::Record &def; |
613 | |
614 | // All operators. |
615 | // TODO: we need a proper context manager, like MLIRContext, for managing the |
616 | // lifetime of shared entities. |
617 | RecordOperatorMap *recordOpMap; |
618 | }; |
619 | |
620 | } // namespace tblgen |
621 | } // namespace mlir |
622 | |
623 | namespace llvm { |
624 | template <> |
625 | struct DenseMapInfo<mlir::tblgen::DagNode> { |
626 | static mlir::tblgen::DagNode getEmptyKey() { |
627 | return mlir::tblgen::DagNode( |
628 | llvm::DenseMapInfo<llvm::DagInit *>::getEmptyKey()); |
629 | } |
630 | static mlir::tblgen::DagNode getTombstoneKey() { |
631 | return mlir::tblgen::DagNode( |
632 | llvm::DenseMapInfo<llvm::DagInit *>::getTombstoneKey()); |
633 | } |
634 | static unsigned getHashValue(mlir::tblgen::DagNode node) { |
635 | return llvm::hash_value(ptr: node.getAsOpaquePointer()); |
636 | } |
637 | static bool isEqual(mlir::tblgen::DagNode lhs, mlir::tblgen::DagNode rhs) { |
638 | return lhs.node == rhs.node; |
639 | } |
640 | }; |
641 | |
642 | template <> |
643 | struct DenseMapInfo<mlir::tblgen::DagLeaf> { |
644 | static mlir::tblgen::DagLeaf getEmptyKey() { |
645 | return mlir::tblgen::DagLeaf( |
646 | llvm::DenseMapInfo<llvm::Init *>::getEmptyKey()); |
647 | } |
648 | static mlir::tblgen::DagLeaf getTombstoneKey() { |
649 | return mlir::tblgen::DagLeaf( |
650 | llvm::DenseMapInfo<llvm::Init *>::getTombstoneKey()); |
651 | } |
652 | static unsigned getHashValue(mlir::tblgen::DagLeaf leaf) { |
653 | return llvm::hash_value(ptr: leaf.getAsOpaquePointer()); |
654 | } |
655 | static bool isEqual(mlir::tblgen::DagLeaf lhs, mlir::tblgen::DagLeaf rhs) { |
656 | return lhs.def == rhs.def; |
657 | } |
658 | }; |
659 | } // namespace llvm |
660 | |
661 | #endif // MLIR_TABLEGEN_PATTERN_H_ |
662 | |