1//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file provides a simple and efficient mechanism for performing general
10// tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
11// include/llvm/IR/PatternMatch.h.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef MLIR_IR_MATCHERS_H
16#define MLIR_IR_MATCHERS_H
17
18#include "mlir/IR/BuiltinAttributes.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/OpDefinition.h"
21
22namespace mlir {
23
24namespace detail {
25
26/// The matcher that matches a certain kind of Attribute and binds the value
27/// inside the Attribute.
28template <
29 typename AttrClass,
30 // Require AttrClass to be a derived class from Attribute and get its
31 // value type
32 typename ValueType = typename std::enable_if_t<
33 std::is_base_of<Attribute, AttrClass>::value, AttrClass>::ValueType,
34 // Require the ValueType is not void
35 typename = std::enable_if_t<!std::is_void<ValueType>::value>>
36struct attr_value_binder {
37 ValueType *bind_value;
38
39 /// Creates a matcher instance that binds the value to bv if match succeeds.
40 attr_value_binder(ValueType *bv) : bind_value(bv) {}
41
42 bool match(Attribute attr) {
43 if (auto intAttr = llvm::dyn_cast<AttrClass>(attr)) {
44 *bind_value = intAttr.getValue();
45 return true;
46 }
47 return false;
48 }
49};
50
51/// The matcher that matches operations that have the `ConstantLike` trait.
52struct constant_op_matcher {
53 bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
54};
55
56/// The matcher that matches operations that have the specified op name.
57struct NameOpMatcher {
58 NameOpMatcher(StringRef name) : name(name) {}
59 bool match(Operation *op) { return op->getName().getStringRef() == name; }
60
61 StringRef name;
62};
63
64/// The matcher that matches operations that have the specified attribute name.
65struct AttrOpMatcher {
66 AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
67 bool match(Operation *op) { return op->hasAttr(name: attrName); }
68
69 StringRef attrName;
70};
71
72/// The matcher that matches operations that have the `ConstantLike` trait, and
73/// binds the folded attribute value.
74template <typename AttrT>
75struct constant_op_binder {
76 AttrT *bind_value;
77
78 /// Creates a matcher instance that binds the constant attribute value to
79 /// bind_value if match succeeds.
80 constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
81 /// Creates a matcher instance that doesn't bind if match succeeds.
82 constant_op_binder() : bind_value(nullptr) {}
83
84 bool match(Operation *op) {
85 if (!op->hasTrait<OpTrait::ConstantLike>())
86 return false;
87
88 // Fold the constant to an attribute.
89 SmallVector<OpFoldResult, 1> foldedOp;
90 LogicalResult result = op->fold(/*operands=*/operands: std::nullopt, results&: foldedOp);
91 (void)result;
92 assert(succeeded(result) && "expected ConstantLike op to be foldable");
93
94 if (auto attr = llvm::dyn_cast<AttrT>(foldedOp.front().get<Attribute>())) {
95 if (bind_value)
96 *bind_value = attr;
97 return true;
98 }
99 return false;
100 }
101};
102
103/// The matcher that matches operations that have the specified attribute
104/// name, and binds the attribute value.
105template <typename AttrT>
106struct AttrOpBinder {
107 /// Creates a matcher instance that binds the attribute value to
108 /// bind_value if match succeeds.
109 AttrOpBinder(StringRef attrName, AttrT *bindValue)
110 : attrName(attrName), bindValue(bindValue) {}
111 /// Creates a matcher instance that doesn't bind if match succeeds.
112 AttrOpBinder(StringRef attrName) : attrName(attrName), bindValue(nullptr) {}
113
114 bool match(Operation *op) {
115 if (auto attr = op->getAttrOfType<AttrT>(attrName)) {
116 if (bindValue)
117 *bindValue = attr;
118 return true;
119 }
120 return false;
121 }
122 StringRef attrName;
123 AttrT *bindValue;
124};
125
126/// The matcher that matches a constant scalar / vector splat / tensor splat
127/// float Attribute or Operation and binds the constant float value.
128struct constant_float_value_binder {
129 FloatAttr::ValueType *bind_value;
130
131 /// Creates a matcher instance that binds the value to bv if match succeeds.
132 constant_float_value_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
133
134 bool match(Attribute attr) {
135 attr_value_binder<FloatAttr> matcher(bind_value);
136 if (matcher.match(attr))
137 return true;
138
139 if (auto splatAttr = dyn_cast<SplatElementsAttr>(Val&: attr))
140 return matcher.match(splatAttr.getSplatValue<Attribute>());
141
142 return false;
143 }
144
145 bool match(Operation *op) {
146 Attribute attr;
147 if (!constant_op_binder<Attribute>(&attr).match(op))
148 return false;
149
150 Type type = op->getResult(idx: 0).getType();
151 if (isa<FloatType, VectorType, RankedTensorType>(Val: type))
152 return match(attr);
153
154 return false;
155 }
156};
157
158/// The matcher that matches a given target constant scalar / vector splat /
159/// tensor splat float value that fulfills a predicate.
160struct constant_float_predicate_matcher {
161 bool (*predicate)(const APFloat &);
162
163 bool match(Attribute attr) {
164 APFloat value(APFloat::Bogus());
165 return constant_float_value_binder(&value).match(attr) && predicate(value);
166 }
167
168 bool match(Operation *op) {
169 APFloat value(APFloat::Bogus());
170 return constant_float_value_binder(&value).match(op) && predicate(value);
171 }
172};
173
174/// The matcher that matches a constant scalar / vector splat / tensor splat
175/// integer Attribute or Operation and binds the constant integer value.
176struct constant_int_value_binder {
177 IntegerAttr::ValueType *bind_value;
178
179 /// Creates a matcher instance that binds the value to bv if match succeeds.
180 constant_int_value_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
181
182 bool match(Attribute attr) {
183 attr_value_binder<IntegerAttr> matcher(bind_value);
184 if (matcher.match(attr))
185 return true;
186
187 if (auto splatAttr = dyn_cast<SplatElementsAttr>(Val&: attr))
188 return matcher.match(splatAttr.getSplatValue<Attribute>());
189
190 return false;
191 }
192
193 bool match(Operation *op) {
194 Attribute attr;
195 if (!constant_op_binder<Attribute>(&attr).match(op))
196 return false;
197
198 Type type = op->getResult(idx: 0).getType();
199 if (isa<IntegerType, IndexType, VectorType, RankedTensorType>(Val: type))
200 return match(attr);
201
202 return false;
203 }
204};
205
206/// The matcher that matches a given target constant scalar / vector splat /
207/// tensor splat integer value that fulfills a predicate.
208struct constant_int_predicate_matcher {
209 bool (*predicate)(const APInt &);
210
211 bool match(Attribute attr) {
212 APInt value;
213 return constant_int_value_binder(&value).match(attr) && predicate(value);
214 }
215
216 bool match(Operation *op) {
217 APInt value;
218 return constant_int_value_binder(&value).match(op) && predicate(value);
219 }
220};
221
222/// The matcher that matches a certain kind of op.
223template <typename OpClass>
224struct op_matcher {
225 bool match(Operation *op) { return isa<OpClass>(op); }
226};
227
/// Detection helper: names the result type of calling `match` on a `MatcherT`
/// with an argument of type `MatchTarget` (Value, Operation, or Attribute).
/// The alias is ill-formed when no such `match` call exists, which is what
/// makes it usable with `llvm::is_detected`.
template <typename MatcherT, typename MatchTarget>
using has_compatible_matcher_t =
    decltype(std::declval<MatcherT>().match(std::declval<MatchTarget>()));
233
234/// Statically switch to a Value matcher.
235template <typename MatcherClass>
236std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t,
237 MatcherClass, Value>::value,
238 bool>
239matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
240 return matcher.match(op->getOperand(idx));
241}
242
243/// Statically switch to an Operation matcher.
244template <typename MatcherClass>
245std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t,
246 MatcherClass, Operation *>::value,
247 bool>
248matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
249 if (auto *defOp = op->getOperand(idx).getDefiningOp())
250 return matcher.match(defOp);
251 return false;
252}
253
254/// Terminal matcher, always returns true.
255struct AnyValueMatcher {
256 bool match(Value op) const { return true; }
257};
258
259/// Terminal matcher, always returns true.
260struct AnyCapturedValueMatcher {
261 Value *what;
262 AnyCapturedValueMatcher(Value *what) : what(what) {}
263 bool match(Value op) const {
264 *what = op;
265 return true;
266 }
267};
268
269/// Binds to a specific value and matches it.
270struct PatternMatcherValue {
271 PatternMatcherValue(Value val) : value(val) {}
272 bool match(Value val) const { return val == value; }
273 Value value;
274};
275
/// Calls `callback(std::integral_constant<std::size_t, I>{}, std::get<I>(tuple))`
/// once for each index `I` in `Indices`, in pack order.
template <typename TupleT, class CallbackT, std::size_t... Indices>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Indices...>) {
  // Comma fold: one callback invocation per index, evaluated left to right.
  (callback(std::integral_constant<std::size_t, Indices>{},
            std::get<Indices>(tuple)),
   ...);
}
283
284template <typename... Tys, typename CallbackT>
285constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
286 detail::enumerateImpl(tuple, std::forward<CallbackT>(callback),
287 std::make_index_sequence<sizeof...(Tys)>{});
288}
289
290/// RecursivePatternMatcher that composes.
291template <typename OpType, typename... OperandMatchers>
292struct RecursivePatternMatcher {
293 RecursivePatternMatcher(OperandMatchers... matchers)
294 : operandMatchers(matchers...) {}
295 bool match(Operation *op) {
296 if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers))
297 return false;
298 bool res = true;
299 enumerate(operandMatchers, [&](size_t index, auto &matcher) {
300 res &= matchOperandOrValueAtIndex(op, index, matcher);
301 });
302 return res;
303 }
304 std::tuple<OperandMatchers...> operandMatchers;
305};
306
307} // namespace detail
308
309/// Matches a constant foldable operation.
310inline detail::constant_op_matcher m_Constant() {
311 return detail::constant_op_matcher();
312}
313
314/// Matches a named attribute operation.
315inline detail::AttrOpMatcher m_Attr(StringRef attrName) {
316 return detail::AttrOpMatcher(attrName);
317}
318
319/// Matches a named operation.
320inline detail::NameOpMatcher m_Op(StringRef opName) {
321 return detail::NameOpMatcher(opName);
322}
323
324/// Matches a value from a constant foldable operation and writes the value to
325/// bind_value.
326template <typename AttrT>
327inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
328 return detail::constant_op_binder<AttrT>(bind_value);
329}
330
331/// Matches a named attribute operation and writes the value to bind_value.
332template <typename AttrT>
333inline detail::AttrOpBinder<AttrT> m_Attr(StringRef attrName,
334 AttrT *bindValue) {
335 return detail::AttrOpBinder<AttrT>(attrName, bindValue);
336}
337
338/// Matches a constant scalar / vector splat / tensor splat float (both positive
339/// and negative) zero.
340inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {
341 return {.predicate: [](const APFloat &value) { return value.isZero(); }};
342}
343
344/// Matches a constant scalar / vector splat / tensor splat float positive zero.
345inline detail::constant_float_predicate_matcher m_PosZeroFloat() {
346 return {.predicate: [](const APFloat &value) { return value.isPosZero(); }};
347}
348
349/// Matches a constant scalar / vector splat / tensor splat float negative zero.
350inline detail::constant_float_predicate_matcher m_NegZeroFloat() {
351 return {.predicate: [](const APFloat &value) { return value.isNegZero(); }};
352}
353
354/// Matches a constant scalar / vector splat / tensor splat float ones.
355inline detail::constant_float_predicate_matcher m_OneFloat() {
356 return {.predicate: [](const APFloat &value) {
357 return APFloat(value.getSemantics(), 1) == value;
358 }};
359}
360
361/// Matches a constant scalar / vector splat / tensor splat float positive
362/// infinity.
363inline detail::constant_float_predicate_matcher m_PosInfFloat() {
364 return {.predicate: [](const APFloat &value) {
365 return !value.isNegative() && value.isInfinity();
366 }};
367}
368
369/// Matches a constant scalar / vector splat / tensor splat float negative
370/// infinity.
371inline detail::constant_float_predicate_matcher m_NegInfFloat() {
372 return {.predicate: [](const APFloat &value) {
373 return value.isNegative() && value.isInfinity();
374 }};
375}
376
377/// Matches a constant scalar / vector splat / tensor splat integer zero.
378inline detail::constant_int_predicate_matcher m_Zero() {
379 return {.predicate: [](const APInt &value) { return 0 == value; }};
380}
381
382/// Matches a constant scalar / vector splat / tensor splat integer that is any
383/// non-zero value.
384inline detail::constant_int_predicate_matcher m_NonZero() {
385 return {.predicate: [](const APInt &value) { return 0 != value; }};
386}
387
388/// Matches a constant scalar / vector splat / tensor splat integer one.
389inline detail::constant_int_predicate_matcher m_One() {
390 return {.predicate: [](const APInt &value) { return 1 == value; }};
391}
392
393/// Matches the given OpClass.
394template <typename OpClass>
395inline detail::op_matcher<OpClass> m_Op() {
396 return detail::op_matcher<OpClass>();
397}
398
399/// Entry point for matching a pattern over a Value.
400template <typename Pattern>
401inline bool matchPattern(Value value, const Pattern &pattern) {
402 assert(value);
403 // TODO: handle other cases
404 if (auto *op = value.getDefiningOp())
405 return const_cast<Pattern &>(pattern).match(op);
406 return false;
407}
408
409/// Entry point for matching a pattern over an Operation.
410template <typename Pattern>
411inline bool matchPattern(Operation *op, const Pattern &pattern) {
412 assert(op);
413 return const_cast<Pattern &>(pattern).match(op);
414}
415
416/// Entry point for matching a pattern over an Attribute. Returns `false`
417/// when `attr` is null.
418template <typename Pattern>
419inline bool matchPattern(Attribute attr, const Pattern &pattern) {
420 static_assert(llvm::is_detected<detail::has_compatible_matcher_t, Pattern,
421 Attribute>::value,
422 "Pattern does not support matching Attributes");
423 if (!attr)
424 return false;
425 return const_cast<Pattern &>(pattern).match(attr);
426}
427
428/// Matches a constant holding a scalar/vector/tensor float (splat) and
429/// writes the float value to bind_value.
430inline detail::constant_float_value_binder
431m_ConstantFloat(FloatAttr::ValueType *bind_value) {
432 return detail::constant_float_value_binder(bind_value);
433}
434
435/// Matches a constant holding a scalar/vector/tensor integer (splat) and
436/// writes the integer value to bind_value.
437inline detail::constant_int_value_binder
438m_ConstantInt(IntegerAttr::ValueType *bind_value) {
439 return detail::constant_int_value_binder(bind_value);
440}
441
442template <typename OpType, typename... Matchers>
443auto m_Op(Matchers... matchers) {
444 return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...);
445}
446
447namespace matchers {
448inline auto m_Any() { return detail::AnyValueMatcher(); }
449inline auto m_Any(Value *val) { return detail::AnyCapturedValueMatcher(val); }
450inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); }
451} // namespace matchers
452
453} // namespace mlir
454
455#endif // MLIR_IR_MATCHERS_H
456

source code of mlir/include/mlir/IR/Matchers.h