//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides a simple and efficient mechanism for performing general
// tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
// include/llvm/IR/PatternMatch.h.
//
//===----------------------------------------------------------------------===//
14 | |
15 | #ifndef MLIR_IR_MATCHERS_H |
16 | #define MLIR_IR_MATCHERS_H |
17 | |
18 | #include "mlir/IR/BuiltinAttributes.h" |
19 | #include "mlir/IR/BuiltinTypes.h" |
20 | #include "mlir/IR/OpDefinition.h" |
21 | |
22 | namespace mlir { |
23 | |
24 | namespace detail { |
25 | |
26 | /// The matcher that matches a certain kind of Attribute and binds the value |
27 | /// inside the Attribute. |
28 | template < |
29 | typename AttrClass, |
30 | // Require AttrClass to be a derived class from Attribute and get its |
31 | // value type |
32 | typename ValueType = typename std::enable_if_t< |
33 | std::is_base_of<Attribute, AttrClass>::value, AttrClass>::ValueType, |
34 | // Require the ValueType is not void |
35 | typename = std::enable_if_t<!std::is_void<ValueType>::value>> |
36 | struct attr_value_binder { |
37 | ValueType *bind_value; |
38 | |
39 | /// Creates a matcher instance that binds the value to bv if match succeeds. |
40 | attr_value_binder(ValueType *bv) : bind_value(bv) {} |
41 | |
42 | bool match(Attribute attr) { |
43 | if (auto intAttr = llvm::dyn_cast<AttrClass>(attr)) { |
44 | *bind_value = intAttr.getValue(); |
45 | return true; |
46 | } |
47 | return false; |
48 | } |
49 | }; |
50 | |
51 | /// The matcher that matches operations that have the `ConstantLike` trait. |
52 | struct constant_op_matcher { |
53 | bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); } |
54 | }; |
55 | |
56 | /// The matcher that matches operations that have the specified op name. |
57 | struct NameOpMatcher { |
58 | NameOpMatcher(StringRef name) : name(name) {} |
59 | bool match(Operation *op) { return op->getName().getStringRef() == name; } |
60 | |
61 | StringRef name; |
62 | }; |
63 | |
64 | /// The matcher that matches operations that have the specified attribute name. |
65 | struct AttrOpMatcher { |
66 | AttrOpMatcher(StringRef attrName) : attrName(attrName) {} |
67 | bool match(Operation *op) { return op->hasAttr(name: attrName); } |
68 | |
69 | StringRef attrName; |
70 | }; |
71 | |
72 | /// The matcher that matches operations that have the `ConstantLike` trait, and |
73 | /// binds the folded attribute value. |
74 | template <typename AttrT> |
75 | struct constant_op_binder { |
76 | AttrT *bind_value; |
77 | |
78 | /// Creates a matcher instance that binds the constant attribute value to |
79 | /// bind_value if match succeeds. |
80 | constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {} |
81 | /// Creates a matcher instance that doesn't bind if match succeeds. |
82 | constant_op_binder() : bind_value(nullptr) {} |
83 | |
84 | bool match(Operation *op) { |
85 | if (!op->hasTrait<OpTrait::ConstantLike>()) |
86 | return false; |
87 | |
88 | // Fold the constant to an attribute. |
89 | SmallVector<OpFoldResult, 1> foldedOp; |
90 | LogicalResult result = op->fold(/*operands=*/operands: std::nullopt, results&: foldedOp); |
91 | (void)result; |
92 | assert(succeeded(result) && "expected ConstantLike op to be foldable" ); |
93 | |
94 | if (auto attr = llvm::dyn_cast<AttrT>(foldedOp.front().get<Attribute>())) { |
95 | if (bind_value) |
96 | *bind_value = attr; |
97 | return true; |
98 | } |
99 | return false; |
100 | } |
101 | }; |
102 | |
103 | /// The matcher that matches operations that have the specified attribute |
104 | /// name, and binds the attribute value. |
105 | template <typename AttrT> |
106 | struct AttrOpBinder { |
107 | /// Creates a matcher instance that binds the attribute value to |
108 | /// bind_value if match succeeds. |
109 | AttrOpBinder(StringRef attrName, AttrT *bindValue) |
110 | : attrName(attrName), bindValue(bindValue) {} |
111 | /// Creates a matcher instance that doesn't bind if match succeeds. |
112 | AttrOpBinder(StringRef attrName) : attrName(attrName), bindValue(nullptr) {} |
113 | |
114 | bool match(Operation *op) { |
115 | if (auto attr = op->getAttrOfType<AttrT>(attrName)) { |
116 | if (bindValue) |
117 | *bindValue = attr; |
118 | return true; |
119 | } |
120 | return false; |
121 | } |
122 | StringRef attrName; |
123 | AttrT *bindValue; |
124 | }; |
125 | |
126 | /// The matcher that matches a constant scalar / vector splat / tensor splat |
127 | /// float Attribute or Operation and binds the constant float value. |
128 | struct constant_float_value_binder { |
129 | FloatAttr::ValueType *bind_value; |
130 | |
131 | /// Creates a matcher instance that binds the value to bv if match succeeds. |
132 | constant_float_value_binder(FloatAttr::ValueType *bv) : bind_value(bv) {} |
133 | |
134 | bool match(Attribute attr) { |
135 | attr_value_binder<FloatAttr> matcher(bind_value); |
136 | if (matcher.match(attr)) |
137 | return true; |
138 | |
139 | if (auto splatAttr = dyn_cast<SplatElementsAttr>(Val&: attr)) |
140 | return matcher.match(splatAttr.getSplatValue<Attribute>()); |
141 | |
142 | return false; |
143 | } |
144 | |
145 | bool match(Operation *op) { |
146 | Attribute attr; |
147 | if (!constant_op_binder<Attribute>(&attr).match(op)) |
148 | return false; |
149 | |
150 | Type type = op->getResult(idx: 0).getType(); |
151 | if (isa<FloatType, VectorType, RankedTensorType>(Val: type)) |
152 | return match(attr); |
153 | |
154 | return false; |
155 | } |
156 | }; |
157 | |
158 | /// The matcher that matches a given target constant scalar / vector splat / |
159 | /// tensor splat float value that fulfills a predicate. |
160 | struct constant_float_predicate_matcher { |
161 | bool (*predicate)(const APFloat &); |
162 | |
163 | bool match(Attribute attr) { |
164 | APFloat value(APFloat::Bogus()); |
165 | return constant_float_value_binder(&value).match(attr) && predicate(value); |
166 | } |
167 | |
168 | bool match(Operation *op) { |
169 | APFloat value(APFloat::Bogus()); |
170 | return constant_float_value_binder(&value).match(op) && predicate(value); |
171 | } |
172 | }; |
173 | |
174 | /// The matcher that matches a constant scalar / vector splat / tensor splat |
175 | /// integer Attribute or Operation and binds the constant integer value. |
176 | struct constant_int_value_binder { |
177 | IntegerAttr::ValueType *bind_value; |
178 | |
179 | /// Creates a matcher instance that binds the value to bv if match succeeds. |
180 | constant_int_value_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {} |
181 | |
182 | bool match(Attribute attr) { |
183 | attr_value_binder<IntegerAttr> matcher(bind_value); |
184 | if (matcher.match(attr)) |
185 | return true; |
186 | |
187 | if (auto splatAttr = dyn_cast<SplatElementsAttr>(Val&: attr)) |
188 | return matcher.match(splatAttr.getSplatValue<Attribute>()); |
189 | |
190 | return false; |
191 | } |
192 | |
193 | bool match(Operation *op) { |
194 | Attribute attr; |
195 | if (!constant_op_binder<Attribute>(&attr).match(op)) |
196 | return false; |
197 | |
198 | Type type = op->getResult(idx: 0).getType(); |
199 | if (isa<IntegerType, IndexType, VectorType, RankedTensorType>(Val: type)) |
200 | return match(attr); |
201 | |
202 | return false; |
203 | } |
204 | }; |
205 | |
206 | /// The matcher that matches a given target constant scalar / vector splat / |
207 | /// tensor splat integer value that fulfills a predicate. |
208 | struct constant_int_predicate_matcher { |
209 | bool (*predicate)(const APInt &); |
210 | |
211 | bool match(Attribute attr) { |
212 | APInt value; |
213 | return constant_int_value_binder(&value).match(attr) && predicate(value); |
214 | } |
215 | |
216 | bool match(Operation *op) { |
217 | APInt value; |
218 | return constant_int_value_binder(&value).match(op) && predicate(value); |
219 | } |
220 | }; |
221 | |
222 | /// The matcher that matches a certain kind of op. |
223 | template <typename OpClass> |
224 | struct op_matcher { |
225 | bool match(Operation *op) { return isa<OpClass>(op); } |
226 | }; |
227 | |
/// Detection alias: names the result type of calling `match` on a `MatcherT`
/// with an argument of type `TargetT` (Value, Operation *, or Attribute). The
/// alias is ill-formed when no such call is valid, which makes it usable as
/// the probe of a detection idiom.
template <typename MatcherT, typename TargetT>
using has_compatible_matcher_t =
    decltype(std::declval<MatcherT>().match(std::declval<TargetT>()));
233 | |
234 | /// Statically switch to a Value matcher. |
235 | template <typename MatcherClass> |
236 | std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t, |
237 | MatcherClass, Value>::value, |
238 | bool> |
239 | matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { |
240 | return matcher.match(op->getOperand(idx)); |
241 | } |
242 | |
243 | /// Statically switch to an Operation matcher. |
244 | template <typename MatcherClass> |
245 | std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t, |
246 | MatcherClass, Operation *>::value, |
247 | bool> |
248 | matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { |
249 | if (auto *defOp = op->getOperand(idx).getDefiningOp()) |
250 | return matcher.match(defOp); |
251 | return false; |
252 | } |
253 | |
254 | /// Terminal matcher, always returns true. |
255 | struct AnyValueMatcher { |
256 | bool match(Value op) const { return true; } |
257 | }; |
258 | |
259 | /// Terminal matcher, always returns true. |
260 | struct AnyCapturedValueMatcher { |
261 | Value *what; |
262 | AnyCapturedValueMatcher(Value *what) : what(what) {} |
263 | bool match(Value op) const { |
264 | *what = op; |
265 | return true; |
266 | } |
267 | }; |
268 | |
269 | /// Binds to a specific value and matches it. |
270 | struct PatternMatcherValue { |
271 | PatternMatcherValue(Value val) : value(val) {} |
272 | bool match(Value val) const { return val == value; } |
273 | Value value; |
274 | }; |
275 | |
/// Invokes `callback(index, element)` once for each index in `Is`, in order.
/// `index` is a `std::integral_constant<std::size_t, I>` and `element` is
/// `std::get<I>(tuple)`.
template <typename TupleT, class CallbackT, std::size_t... Is>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Is...>) {
  // Comma fold: the calls are evaluated left to right, one per index.
  (callback(std::integral_constant<std::size_t, Is>(), std::get<Is>(tuple)),
   ...);
}
283 | |
284 | template <typename... Tys, typename CallbackT> |
285 | constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) { |
286 | detail::enumerateImpl(tuple, std::forward<CallbackT>(callback), |
287 | std::make_index_sequence<sizeof...(Tys)>{}); |
288 | } |
289 | |
290 | /// RecursivePatternMatcher that composes. |
291 | template <typename OpType, typename... OperandMatchers> |
292 | struct RecursivePatternMatcher { |
293 | RecursivePatternMatcher(OperandMatchers... matchers) |
294 | : operandMatchers(matchers...) {} |
295 | bool match(Operation *op) { |
296 | if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers)) |
297 | return false; |
298 | bool res = true; |
299 | enumerate(operandMatchers, [&](size_t index, auto &matcher) { |
300 | res &= matchOperandOrValueAtIndex(op, index, matcher); |
301 | }); |
302 | return res; |
303 | } |
304 | std::tuple<OperandMatchers...> operandMatchers; |
305 | }; |
306 | |
307 | } // namespace detail |
308 | |
309 | /// Matches a constant foldable operation. |
310 | inline detail::constant_op_matcher m_Constant() { |
311 | return detail::constant_op_matcher(); |
312 | } |
313 | |
314 | /// Matches a named attribute operation. |
315 | inline detail::AttrOpMatcher m_Attr(StringRef attrName) { |
316 | return detail::AttrOpMatcher(attrName); |
317 | } |
318 | |
319 | /// Matches a named operation. |
320 | inline detail::NameOpMatcher m_Op(StringRef opName) { |
321 | return detail::NameOpMatcher(opName); |
322 | } |
323 | |
324 | /// Matches a value from a constant foldable operation and writes the value to |
325 | /// bind_value. |
326 | template <typename AttrT> |
327 | inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) { |
328 | return detail::constant_op_binder<AttrT>(bind_value); |
329 | } |
330 | |
331 | /// Matches a named attribute operation and writes the value to bind_value. |
332 | template <typename AttrT> |
333 | inline detail::AttrOpBinder<AttrT> m_Attr(StringRef attrName, |
334 | AttrT *bindValue) { |
335 | return detail::AttrOpBinder<AttrT>(attrName, bindValue); |
336 | } |
337 | |
338 | /// Matches a constant scalar / vector splat / tensor splat float (both positive |
339 | /// and negative) zero. |
340 | inline detail::constant_float_predicate_matcher m_AnyZeroFloat() { |
341 | return {.predicate: [](const APFloat &value) { return value.isZero(); }}; |
342 | } |
343 | |
344 | /// Matches a constant scalar / vector splat / tensor splat float positive zero. |
345 | inline detail::constant_float_predicate_matcher m_PosZeroFloat() { |
346 | return {.predicate: [](const APFloat &value) { return value.isPosZero(); }}; |
347 | } |
348 | |
349 | /// Matches a constant scalar / vector splat / tensor splat float negative zero. |
350 | inline detail::constant_float_predicate_matcher m_NegZeroFloat() { |
351 | return {.predicate: [](const APFloat &value) { return value.isNegZero(); }}; |
352 | } |
353 | |
354 | /// Matches a constant scalar / vector splat / tensor splat float ones. |
355 | inline detail::constant_float_predicate_matcher m_OneFloat() { |
356 | return {.predicate: [](const APFloat &value) { |
357 | return APFloat(value.getSemantics(), 1) == value; |
358 | }}; |
359 | } |
360 | |
361 | /// Matches a constant scalar / vector splat / tensor splat float positive |
362 | /// infinity. |
363 | inline detail::constant_float_predicate_matcher m_PosInfFloat() { |
364 | return {.predicate: [](const APFloat &value) { |
365 | return !value.isNegative() && value.isInfinity(); |
366 | }}; |
367 | } |
368 | |
369 | /// Matches a constant scalar / vector splat / tensor splat float negative |
370 | /// infinity. |
371 | inline detail::constant_float_predicate_matcher m_NegInfFloat() { |
372 | return {.predicate: [](const APFloat &value) { |
373 | return value.isNegative() && value.isInfinity(); |
374 | }}; |
375 | } |
376 | |
377 | /// Matches a constant scalar / vector splat / tensor splat integer zero. |
378 | inline detail::constant_int_predicate_matcher m_Zero() { |
379 | return {.predicate: [](const APInt &value) { return 0 == value; }}; |
380 | } |
381 | |
382 | /// Matches a constant scalar / vector splat / tensor splat integer that is any |
383 | /// non-zero value. |
384 | inline detail::constant_int_predicate_matcher m_NonZero() { |
385 | return {.predicate: [](const APInt &value) { return 0 != value; }}; |
386 | } |
387 | |
388 | /// Matches a constant scalar / vector splat / tensor splat integer one. |
389 | inline detail::constant_int_predicate_matcher m_One() { |
390 | return {.predicate: [](const APInt &value) { return 1 == value; }}; |
391 | } |
392 | |
393 | /// Matches the given OpClass. |
394 | template <typename OpClass> |
395 | inline detail::op_matcher<OpClass> m_Op() { |
396 | return detail::op_matcher<OpClass>(); |
397 | } |
398 | |
399 | /// Entry point for matching a pattern over a Value. |
400 | template <typename Pattern> |
401 | inline bool matchPattern(Value value, const Pattern &pattern) { |
402 | assert(value); |
403 | // TODO: handle other cases |
404 | if (auto *op = value.getDefiningOp()) |
405 | return const_cast<Pattern &>(pattern).match(op); |
406 | return false; |
407 | } |
408 | |
409 | /// Entry point for matching a pattern over an Operation. |
410 | template <typename Pattern> |
411 | inline bool matchPattern(Operation *op, const Pattern &pattern) { |
412 | assert(op); |
413 | return const_cast<Pattern &>(pattern).match(op); |
414 | } |
415 | |
416 | /// Entry point for matching a pattern over an Attribute. Returns `false` |
417 | /// when `attr` is null. |
418 | template <typename Pattern> |
419 | inline bool matchPattern(Attribute attr, const Pattern &pattern) { |
420 | static_assert(llvm::is_detected<detail::has_compatible_matcher_t, Pattern, |
421 | Attribute>::value, |
422 | "Pattern does not support matching Attributes" ); |
423 | if (!attr) |
424 | return false; |
425 | return const_cast<Pattern &>(pattern).match(attr); |
426 | } |
427 | |
428 | /// Matches a constant holding a scalar/vector/tensor float (splat) and |
429 | /// writes the float value to bind_value. |
430 | inline detail::constant_float_value_binder |
431 | m_ConstantFloat(FloatAttr::ValueType *bind_value) { |
432 | return detail::constant_float_value_binder(bind_value); |
433 | } |
434 | |
435 | /// Matches a constant holding a scalar/vector/tensor integer (splat) and |
436 | /// writes the integer value to bind_value. |
437 | inline detail::constant_int_value_binder |
438 | m_ConstantInt(IntegerAttr::ValueType *bind_value) { |
439 | return detail::constant_int_value_binder(bind_value); |
440 | } |
441 | |
442 | template <typename OpType, typename... Matchers> |
443 | auto m_Op(Matchers... matchers) { |
444 | return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...); |
445 | } |
446 | |
447 | namespace matchers { |
448 | inline auto m_Any() { return detail::AnyValueMatcher(); } |
449 | inline auto m_Any(Value *val) { return detail::AnyCapturedValueMatcher(val); } |
450 | inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); } |
451 | } // namespace matchers |
452 | |
453 | } // namespace mlir |
454 | |
455 | #endif // MLIR_IR_MATCHERS_H |
456 | |