1 | //==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | /// \file |
9 | /// Contains matchers for matching SelectionDAG nodes and values. |
10 | /// |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef LLVM_CODEGEN_SDPATTERNMATCH_H |
14 | #define LLVM_CODEGEN_SDPATTERNMATCH_H |
15 | |
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetLowering.h"

#include <type_traits>
#include <utility>
21 | |
22 | namespace llvm { |
23 | namespace SDPatternMatch { |
24 | |
25 | /// MatchContext can repurpose existing patterns to behave differently under |
26 | /// a certain context. For instance, `m_Opc(ISD::ADD)` matches plain ADD nodes |
27 | /// in normal circumstances, but matches VP_ADD nodes under a custom |
28 | /// VPMatchContext. This design is meant to facilitate code / pattern reusing. |
29 | class BasicMatchContext { |
30 | const SelectionDAG *DAG; |
31 | const TargetLowering *TLI; |
32 | |
33 | public: |
34 | explicit BasicMatchContext(const SelectionDAG *DAG) |
35 | : DAG(DAG), TLI(DAG ? &DAG->getTargetLoweringInfo() : nullptr) {} |
36 | |
37 | explicit BasicMatchContext(const TargetLowering *TLI) |
38 | : DAG(nullptr), TLI(TLI) {} |
39 | |
40 | // A valid MatchContext has to implement the following functions. |
41 | |
42 | const SelectionDAG *getDAG() const { return DAG; } |
43 | |
44 | const TargetLowering *getTLI() const { return TLI; } |
45 | |
46 | /// Return true if N effectively has opcode Opcode. |
47 | bool match(SDValue N, unsigned Opcode) const { |
48 | return N->getOpcode() == Opcode; |
49 | } |
50 | }; |
51 | |
52 | template <typename Pattern, typename MatchContext> |
53 | [[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx, |
54 | Pattern &&P) { |
55 | return P.match(Ctx, N); |
56 | } |
57 | |
58 | template <typename Pattern, typename MatchContext> |
59 | [[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx, |
60 | Pattern &&P) { |
61 | return sd_context_match(SDValue(N, 0), Ctx, P); |
62 | } |
63 | |
64 | template <typename Pattern> |
65 | [[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) { |
66 | return sd_context_match(N, BasicMatchContext(DAG), P); |
67 | } |
68 | |
69 | template <typename Pattern> |
70 | [[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) { |
71 | return sd_context_match(N, BasicMatchContext(DAG), P); |
72 | } |
73 | |
74 | template <typename Pattern> |
75 | [[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) { |
76 | return sd_match(N, nullptr, P); |
77 | } |
78 | |
79 | template <typename Pattern> |
80 | [[nodiscard]] bool sd_match(SDValue N, Pattern &&P) { |
81 | return sd_match(N, nullptr, P); |
82 | } |
83 | |
84 | // === Utilities === |
85 | struct Value_match { |
86 | SDValue MatchVal; |
87 | |
88 | Value_match() = default; |
89 | |
90 | explicit Value_match(SDValue Match) : MatchVal(Match) {} |
91 | |
92 | template <typename MatchContext> bool match(const MatchContext &, SDValue N) { |
93 | if (MatchVal) |
94 | return MatchVal == N; |
95 | return N.getNode(); |
96 | } |
97 | }; |
98 | |
99 | /// Match any valid SDValue. |
100 | inline Value_match m_Value() { return Value_match(); } |
101 | |
102 | inline Value_match m_Specific(SDValue N) { |
103 | assert(N); |
104 | return Value_match(N); |
105 | } |
106 | |
107 | struct DeferredValue_match { |
108 | SDValue &MatchVal; |
109 | |
110 | explicit DeferredValue_match(SDValue &Match) : MatchVal(Match) {} |
111 | |
112 | template <typename MatchContext> bool match(const MatchContext &, SDValue N) { |
113 | return N == MatchVal; |
114 | } |
115 | }; |
116 | |
117 | /// Similar to m_Specific, but the specific value to match is determined by |
118 | /// another sub-pattern in the same sd_match() expression. For instance, |
119 | /// We cannot match `(add V, V)` with `m_Add(m_Value(X), m_Specific(X))` since |
120 | /// `X` is not initialized at the time it got copied into `m_Specific`. Instead, |
121 | /// we should use `m_Add(m_Value(X), m_Deferred(X))`. |
122 | inline DeferredValue_match m_Deferred(SDValue &V) { |
123 | return DeferredValue_match(V); |
124 | } |
125 | |
126 | struct Opcode_match { |
127 | unsigned Opcode; |
128 | |
129 | explicit Opcode_match(unsigned Opc) : Opcode(Opc) {} |
130 | |
131 | template <typename MatchContext> |
132 | bool match(const MatchContext &Ctx, SDValue N) { |
133 | return Ctx.match(N, Opcode); |
134 | } |
135 | }; |
136 | |
137 | inline Opcode_match m_Opc(unsigned Opcode) { return Opcode_match(Opcode); } |
138 | |
139 | template <unsigned NumUses, typename Pattern> struct NUses_match { |
140 | Pattern P; |
141 | |
142 | explicit NUses_match(const Pattern &P) : P(P) {} |
143 | |
144 | template <typename MatchContext> |
145 | bool match(const MatchContext &Ctx, SDValue N) { |
146 | // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces |
147 | // multiple results, hence we check the subsequent pattern here before |
148 | // checking the number of value users. |
149 | return P.match(Ctx, N) && N->hasNUsesOfValue(NUses: NumUses, Value: N.getResNo()); |
150 | } |
151 | }; |
152 | |
153 | template <typename Pattern> |
154 | inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) { |
155 | return NUses_match<1, Pattern>(P); |
156 | } |
157 | template <unsigned N, typename Pattern> |
158 | inline NUses_match<N, Pattern> m_NUses(const Pattern &P) { |
159 | return NUses_match<N, Pattern>(P); |
160 | } |
161 | |
162 | inline NUses_match<1, Value_match> m_OneUse() { |
163 | return NUses_match<1, Value_match>(m_Value()); |
164 | } |
165 | template <unsigned N> inline NUses_match<N, Value_match> m_NUses() { |
166 | return NUses_match<N, Value_match>(m_Value()); |
167 | } |
168 | |
169 | struct Value_bind { |
170 | SDValue &BindVal; |
171 | |
172 | explicit Value_bind(SDValue &N) : BindVal(N) {} |
173 | |
174 | template <typename MatchContext> bool match(const MatchContext &, SDValue N) { |
175 | BindVal = N; |
176 | return true; |
177 | } |
178 | }; |
179 | |
180 | inline Value_bind m_Value(SDValue &N) { return Value_bind(N); } |
181 | |
182 | template <typename Pattern, typename PredFuncT> struct TLI_pred_match { |
183 | Pattern P; |
184 | PredFuncT PredFunc; |
185 | |
186 | TLI_pred_match(const PredFuncT &Pred, const Pattern &P) |
187 | : P(P), PredFunc(Pred) {} |
188 | |
189 | template <typename MatchContext> |
190 | bool match(const MatchContext &Ctx, SDValue N) { |
191 | assert(Ctx.getTLI() && "TargetLowering is required for this pattern." ); |
192 | return PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N); |
193 | } |
194 | }; |
195 | |
196 | // Explicit deduction guide. |
197 | template <typename PredFuncT, typename Pattern> |
198 | TLI_pred_match(const PredFuncT &Pred, const Pattern &P) |
199 | -> TLI_pred_match<Pattern, PredFuncT>; |
200 | |
201 | /// Match legal SDNodes based on the information provided by TargetLowering. |
202 | template <typename Pattern> inline auto m_LegalOp(const Pattern &P) { |
203 | return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) { |
204 | return TLI.isOperationLegal(Op: N->getOpcode(), |
205 | VT: N.getValueType()); |
206 | }, |
207 | P}; |
208 | } |
209 | |
210 | /// Switch to a different MatchContext for subsequent patterns. |
211 | template <typename NewMatchContext, typename Pattern> struct SwitchContext { |
212 | const NewMatchContext &Ctx; |
213 | Pattern P; |
214 | |
215 | template <typename OrigMatchContext> |
216 | bool match(const OrigMatchContext &, SDValue N) { |
217 | return P.match(Ctx, N); |
218 | } |
219 | }; |
220 | |
221 | template <typename MatchContext, typename Pattern> |
222 | inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx, |
223 | Pattern &&P) { |
224 | return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)}; |
225 | } |
226 | |
227 | // === Value type === |
228 | struct ValueType_bind { |
229 | EVT &BindVT; |
230 | |
231 | explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {} |
232 | |
233 | template <typename MatchContext> bool match(const MatchContext &, SDValue N) { |
234 | BindVT = N.getValueType(); |
235 | return true; |
236 | } |
237 | }; |
238 | |
239 | /// Retreive the ValueType of the current SDValue. |
240 | inline ValueType_bind m_VT(EVT &VT) { return ValueType_bind(VT); } |
241 | |
242 | template <typename Pattern, typename PredFuncT> struct ValueType_match { |
243 | PredFuncT PredFunc; |
244 | Pattern P; |
245 | |
246 | ValueType_match(const PredFuncT &Pred, const Pattern &P) |
247 | : PredFunc(Pred), P(P) {} |
248 | |
249 | template <typename MatchContext> |
250 | bool match(const MatchContext &Ctx, SDValue N) { |
251 | return PredFunc(N.getValueType()) && P.match(Ctx, N); |
252 | } |
253 | }; |
254 | |
255 | // Explicit deduction guide. |
256 | template <typename PredFuncT, typename Pattern> |
257 | ValueType_match(const PredFuncT &Pred, const Pattern &P) |
258 | -> ValueType_match<Pattern, PredFuncT>; |
259 | |
260 | /// Match a specific ValueType. |
261 | template <typename Pattern> |
262 | inline auto m_SpecificVT(EVT RefVT, const Pattern &P) { |
263 | return ValueType_match{[=](EVT VT) { return VT == RefVT; }, P}; |
264 | } |
265 | inline auto m_SpecificVT(EVT RefVT) { |
266 | return ValueType_match{[=](EVT VT) { return VT == RefVT; }, m_Value()}; |
267 | } |
268 | |
269 | inline auto m_Glue() { return m_SpecificVT(MVT::Glue); } |
270 | inline auto m_OtherVT() { return m_SpecificVT(MVT::Other); } |
271 | |
272 | /// Match any integer ValueTypes. |
273 | template <typename Pattern> inline auto m_IntegerVT(const Pattern &P) { |
274 | return ValueType_match{[](EVT VT) { return VT.isInteger(); }, P}; |
275 | } |
276 | inline auto m_IntegerVT() { |
277 | return ValueType_match{[](EVT VT) { return VT.isInteger(); }, m_Value()}; |
278 | } |
279 | |
280 | /// Match any floating point ValueTypes. |
281 | template <typename Pattern> inline auto m_FloatingPointVT(const Pattern &P) { |
282 | return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, P}; |
283 | } |
284 | inline auto m_FloatingPointVT() { |
285 | return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, |
286 | m_Value()}; |
287 | } |
288 | |
289 | /// Match any vector ValueTypes. |
290 | template <typename Pattern> inline auto m_VectorVT(const Pattern &P) { |
291 | return ValueType_match{[](EVT VT) { return VT.isVector(); }, P}; |
292 | } |
293 | inline auto m_VectorVT() { |
294 | return ValueType_match{[](EVT VT) { return VT.isVector(); }, m_Value()}; |
295 | } |
296 | |
297 | /// Match fixed-length vector ValueTypes. |
298 | template <typename Pattern> inline auto m_FixedVectorVT(const Pattern &P) { |
299 | return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, P}; |
300 | } |
301 | inline auto m_FixedVectorVT() { |
302 | return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, |
303 | m_Value()}; |
304 | } |
305 | |
306 | /// Match scalable vector ValueTypes. |
307 | template <typename Pattern> inline auto m_ScalableVectorVT(const Pattern &P) { |
308 | return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, P}; |
309 | } |
310 | inline auto m_ScalableVectorVT() { |
311 | return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, |
312 | m_Value()}; |
313 | } |
314 | |
315 | /// Match legal ValueTypes based on the information provided by TargetLowering. |
316 | template <typename Pattern> inline auto m_LegalType(const Pattern &P) { |
317 | return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) { |
318 | return TLI.isTypeLegal(VT: N.getValueType()); |
319 | }, |
320 | P}; |
321 | } |
322 | |
323 | // === Patterns combinators === |
324 | template <typename... Preds> struct And { |
325 | template <typename MatchContext> bool match(const MatchContext &, SDValue N) { |
326 | return true; |
327 | } |
328 | }; |
329 | |
330 | template <typename Pred, typename... Preds> |
331 | struct And<Pred, Preds...> : And<Preds...> { |
332 | Pred P; |
333 | And(Pred &&p, Preds &&...preds) |
334 | : And<Preds...>(std::forward<Preds>(preds)...), P(std::forward<Pred>(p)) { |
335 | } |
336 | |
337 | template <typename MatchContext> |
338 | bool match(const MatchContext &Ctx, SDValue N) { |
339 | return P.match(Ctx, N) && And<Preds...>::match(Ctx, N); |
340 | } |
341 | }; |
342 | |
343 | template <typename... Preds> struct Or { |
344 | template <typename MatchContext> bool match(const MatchContext &, SDValue N) { |
345 | return false; |
346 | } |
347 | }; |
348 | |
349 | template <typename Pred, typename... Preds> |
350 | struct Or<Pred, Preds...> : Or<Preds...> { |
351 | Pred P; |
352 | Or(Pred &&p, Preds &&...preds) |
353 | : Or<Preds...>(std::forward<Preds>(preds)...), P(std::forward<Pred>(p)) {} |
354 | |
355 | template <typename MatchContext> |
356 | bool match(const MatchContext &Ctx, SDValue N) { |
357 | return P.match(Ctx, N) || Or<Preds...>::match(Ctx, N); |
358 | } |
359 | }; |
360 | |
361 | template <typename... Preds> And<Preds...> m_AllOf(Preds &&...preds) { |
362 | return And<Preds...>(std::forward<Preds>(preds)...); |
363 | } |
364 | |
365 | template <typename... Preds> Or<Preds...> m_AnyOf(Preds &&...preds) { |
366 | return Or<Preds...>(std::forward<Preds>(preds)...); |
367 | } |
368 | |
369 | // === Generic node matching === |
370 | template <unsigned OpIdx, typename... OpndPreds> struct Operands_match { |
371 | template <typename MatchContext> |
372 | bool match(const MatchContext &Ctx, SDValue N) { |
373 | // Returns false if there are more operands than predicates; |
374 | return N->getNumOperands() == OpIdx; |
375 | } |
376 | }; |
377 | |
378 | template <unsigned OpIdx, typename OpndPred, typename... OpndPreds> |
379 | struct Operands_match<OpIdx, OpndPred, OpndPreds...> |
380 | : Operands_match<OpIdx + 1, OpndPreds...> { |
381 | OpndPred P; |
382 | |
383 | Operands_match(OpndPred &&p, OpndPreds &&...preds) |
384 | : Operands_match<OpIdx + 1, OpndPreds...>( |
385 | std::forward<OpndPreds>(preds)...), |
386 | P(std::forward<OpndPred>(p)) {} |
387 | |
388 | template <typename MatchContext> |
389 | bool match(const MatchContext &Ctx, SDValue N) { |
390 | if (OpIdx < N->getNumOperands()) |
391 | return P.match(Ctx, N->getOperand(Num: OpIdx)) && |
392 | Operands_match<OpIdx + 1, OpndPreds...>::match(Ctx, N); |
393 | |
394 | // This is the case where there are more predicates than operands. |
395 | return false; |
396 | } |
397 | }; |
398 | |
399 | template <typename... OpndPreds> |
400 | auto m_Node(unsigned Opcode, OpndPreds &&...preds) { |
401 | return m_AllOf(m_Opc(Opcode), Operands_match<0, OpndPreds...>( |
402 | std::forward<OpndPreds>(preds)...)); |
403 | } |
404 | |
405 | /// Provide number of operands that are not chain or glue, as well as the first |
406 | /// index of such operand. |
407 | template <bool ExcludeChain> struct EffectiveOperands { |
408 | unsigned Size = 0; |
409 | unsigned FirstIndex = 0; |
410 | |
411 | explicit EffectiveOperands(SDValue N) { |
412 | const unsigned TotalNumOps = N->getNumOperands(); |
413 | FirstIndex = TotalNumOps; |
414 | for (unsigned I = 0; I < TotalNumOps; ++I) { |
415 | // Count the number of non-chain and non-glue nodes (we ignore chain |
416 | // and glue by default) and retreive the operand index offset. |
417 | EVT VT = N->getOperand(Num: I).getValueType(); |
418 | if (VT != MVT::Glue && VT != MVT::Other) { |
419 | ++Size; |
420 | if (FirstIndex == TotalNumOps) |
421 | FirstIndex = I; |
422 | } |
423 | } |
424 | } |
425 | }; |
426 | |
427 | template <> struct EffectiveOperands<false> { |
428 | unsigned Size = 0; |
429 | unsigned FirstIndex = 0; |
430 | |
431 | explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {} |
432 | }; |
433 | |
434 | // === Binary operations === |
435 | template <typename LHS_P, typename RHS_P, bool Commutable = false, |
436 | bool ExcludeChain = false> |
437 | struct BinaryOpc_match { |
438 | unsigned Opcode; |
439 | LHS_P LHS; |
440 | RHS_P RHS; |
441 | |
442 | BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R) |
443 | : Opcode(Opc), LHS(L), RHS(R) {} |
444 | |
445 | template <typename MatchContext> |
446 | bool match(const MatchContext &Ctx, SDValue N) { |
447 | if (sd_context_match(N, Ctx, m_Opc(Opcode))) { |
448 | EffectiveOperands<ExcludeChain> EO(N); |
449 | assert(EO.Size == 2); |
450 | return (LHS.match(Ctx, N->getOperand(Num: EO.FirstIndex)) && |
451 | RHS.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1))) || |
452 | (Commutable && LHS.match(Ctx, N->getOperand(Num: EO.FirstIndex + 1)) && |
453 | RHS.match(Ctx, N->getOperand(Num: EO.FirstIndex))); |
454 | } |
455 | |
456 | return false; |
457 | } |
458 | }; |
459 | |
460 | template <typename LHS, typename RHS> |
461 | inline BinaryOpc_match<LHS, RHS, false> m_BinOp(unsigned Opc, const LHS &L, |
462 | const RHS &R) { |
463 | return BinaryOpc_match<LHS, RHS, false>(Opc, L, R); |
464 | } |
465 | template <typename LHS, typename RHS> |
466 | inline BinaryOpc_match<LHS, RHS, true> m_c_BinOp(unsigned Opc, const LHS &L, |
467 | const RHS &R) { |
468 | return BinaryOpc_match<LHS, RHS, true>(Opc, L, R); |
469 | } |
470 | |
471 | template <typename LHS, typename RHS> |
472 | inline BinaryOpc_match<LHS, RHS, false, true> |
473 | m_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) { |
474 | return BinaryOpc_match<LHS, RHS, false, true>(Opc, L, R); |
475 | } |
476 | template <typename LHS, typename RHS> |
477 | inline BinaryOpc_match<LHS, RHS, true, true> |
478 | m_c_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) { |
479 | return BinaryOpc_match<LHS, RHS, true, true>(Opc, L, R); |
480 | } |
481 | |
482 | // Common binary operations |
483 | template <typename LHS, typename RHS> |
484 | inline BinaryOpc_match<LHS, RHS, true> m_Add(const LHS &L, const RHS &R) { |
485 | return BinaryOpc_match<LHS, RHS, true>(ISD::ADD, L, R); |
486 | } |
487 | |
488 | template <typename LHS, typename RHS> |
489 | inline BinaryOpc_match<LHS, RHS, false> m_Sub(const LHS &L, const RHS &R) { |
490 | return BinaryOpc_match<LHS, RHS, false>(ISD::SUB, L, R); |
491 | } |
492 | |
493 | template <typename LHS, typename RHS> |
494 | inline BinaryOpc_match<LHS, RHS, true> m_Mul(const LHS &L, const RHS &R) { |
495 | return BinaryOpc_match<LHS, RHS, true>(ISD::MUL, L, R); |
496 | } |
497 | |
498 | template <typename LHS, typename RHS> |
499 | inline BinaryOpc_match<LHS, RHS, true> m_And(const LHS &L, const RHS &R) { |
500 | return BinaryOpc_match<LHS, RHS, true>(ISD::AND, L, R); |
501 | } |
502 | |
503 | template <typename LHS, typename RHS> |
504 | inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) { |
505 | return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R); |
506 | } |
507 | |
508 | template <typename LHS, typename RHS> |
509 | inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) { |
510 | return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R); |
511 | } |
512 | |
513 | template <typename LHS, typename RHS> |
514 | inline BinaryOpc_match<LHS, RHS, true> m_SMin(const LHS &L, const RHS &R) { |
515 | return BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R); |
516 | } |
517 | |
518 | template <typename LHS, typename RHS> |
519 | inline BinaryOpc_match<LHS, RHS, true> m_SMax(const LHS &L, const RHS &R) { |
520 | return BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R); |
521 | } |
522 | |
523 | template <typename LHS, typename RHS> |
524 | inline BinaryOpc_match<LHS, RHS, true> m_UMin(const LHS &L, const RHS &R) { |
525 | return BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R); |
526 | } |
527 | |
528 | template <typename LHS, typename RHS> |
529 | inline BinaryOpc_match<LHS, RHS, true> m_UMax(const LHS &L, const RHS &R) { |
530 | return BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R); |
531 | } |
532 | |
533 | template <typename LHS, typename RHS> |
534 | inline BinaryOpc_match<LHS, RHS, false> m_UDiv(const LHS &L, const RHS &R) { |
535 | return BinaryOpc_match<LHS, RHS, false>(ISD::UDIV, L, R); |
536 | } |
537 | template <typename LHS, typename RHS> |
538 | inline BinaryOpc_match<LHS, RHS, false> m_SDiv(const LHS &L, const RHS &R) { |
539 | return BinaryOpc_match<LHS, RHS, false>(ISD::SDIV, L, R); |
540 | } |
541 | |
542 | template <typename LHS, typename RHS> |
543 | inline BinaryOpc_match<LHS, RHS, false> m_URem(const LHS &L, const RHS &R) { |
544 | return BinaryOpc_match<LHS, RHS, false>(ISD::UREM, L, R); |
545 | } |
546 | template <typename LHS, typename RHS> |
547 | inline BinaryOpc_match<LHS, RHS, false> m_SRem(const LHS &L, const RHS &R) { |
548 | return BinaryOpc_match<LHS, RHS, false>(ISD::SREM, L, R); |
549 | } |
550 | |
551 | template <typename LHS, typename RHS> |
552 | inline BinaryOpc_match<LHS, RHS, false> m_Shl(const LHS &L, const RHS &R) { |
553 | return BinaryOpc_match<LHS, RHS, false>(ISD::SHL, L, R); |
554 | } |
555 | |
556 | template <typename LHS, typename RHS> |
557 | inline BinaryOpc_match<LHS, RHS, false> m_Sra(const LHS &L, const RHS &R) { |
558 | return BinaryOpc_match<LHS, RHS, false>(ISD::SRA, L, R); |
559 | } |
560 | template <typename LHS, typename RHS> |
561 | inline BinaryOpc_match<LHS, RHS, false> m_Srl(const LHS &L, const RHS &R) { |
562 | return BinaryOpc_match<LHS, RHS, false>(ISD::SRL, L, R); |
563 | } |
564 | |
565 | template <typename LHS, typename RHS> |
566 | inline BinaryOpc_match<LHS, RHS, true> m_FAdd(const LHS &L, const RHS &R) { |
567 | return BinaryOpc_match<LHS, RHS, true>(ISD::FADD, L, R); |
568 | } |
569 | |
570 | template <typename LHS, typename RHS> |
571 | inline BinaryOpc_match<LHS, RHS, false> m_FSub(const LHS &L, const RHS &R) { |
572 | return BinaryOpc_match<LHS, RHS, false>(ISD::FSUB, L, R); |
573 | } |
574 | |
575 | template <typename LHS, typename RHS> |
576 | inline BinaryOpc_match<LHS, RHS, true> m_FMul(const LHS &L, const RHS &R) { |
577 | return BinaryOpc_match<LHS, RHS, true>(ISD::FMUL, L, R); |
578 | } |
579 | |
580 | template <typename LHS, typename RHS> |
581 | inline BinaryOpc_match<LHS, RHS, false> m_FDiv(const LHS &L, const RHS &R) { |
582 | return BinaryOpc_match<LHS, RHS, false>(ISD::FDIV, L, R); |
583 | } |
584 | |
585 | template <typename LHS, typename RHS> |
586 | inline BinaryOpc_match<LHS, RHS, false> m_FRem(const LHS &L, const RHS &R) { |
587 | return BinaryOpc_match<LHS, RHS, false>(ISD::FREM, L, R); |
588 | } |
589 | |
590 | // === Unary operations === |
591 | template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match { |
592 | unsigned Opcode; |
593 | Opnd_P Opnd; |
594 | |
595 | UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {} |
596 | |
597 | template <typename MatchContext> |
598 | bool match(const MatchContext &Ctx, SDValue N) { |
599 | if (sd_context_match(N, Ctx, m_Opc(Opcode))) { |
600 | EffectiveOperands<ExcludeChain> EO(N); |
601 | assert(EO.Size == 1); |
602 | return Opnd.match(Ctx, N->getOperand(Num: EO.FirstIndex)); |
603 | } |
604 | |
605 | return false; |
606 | } |
607 | }; |
608 | |
609 | template <typename Opnd> |
610 | inline UnaryOpc_match<Opnd> m_UnaryOp(unsigned Opc, const Opnd &Op) { |
611 | return UnaryOpc_match<Opnd>(Opc, Op); |
612 | } |
613 | template <typename Opnd> |
614 | inline UnaryOpc_match<Opnd, true> m_ChainedUnaryOp(unsigned Opc, |
615 | const Opnd &Op) { |
616 | return UnaryOpc_match<Opnd, true>(Opc, Op); |
617 | } |
618 | |
619 | template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) { |
620 | return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op); |
621 | } |
622 | |
623 | template <typename Opnd> inline UnaryOpc_match<Opnd> m_SExt(const Opnd &Op) { |
624 | return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op); |
625 | } |
626 | |
627 | template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) { |
628 | return UnaryOpc_match<Opnd>(ISD::ANY_EXTEND, Op); |
629 | } |
630 | |
631 | template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) { |
632 | return UnaryOpc_match<Opnd>(ISD::TRUNCATE, Op); |
633 | } |
634 | |
635 | /// Match a zext or identity |
636 | /// Allows to peek through optional extensions |
637 | template <typename Opnd> |
638 | inline Or<UnaryOpc_match<Opnd>, Opnd> m_ZExtOrSelf(Opnd &&Op) { |
639 | return Or<UnaryOpc_match<Opnd>, Opnd>(m_ZExt(std::forward<Opnd>(Op)), |
640 | std::forward<Opnd>(Op)); |
641 | } |
642 | |
643 | /// Match a sext or identity |
644 | /// Allows to peek through optional extensions |
645 | template <typename Opnd> |
646 | inline Or<UnaryOpc_match<Opnd>, Opnd> m_SExtOrSelf(Opnd &&Op) { |
647 | return Or<UnaryOpc_match<Opnd>, Opnd>(m_SExt(std::forward<Opnd>(Op)), |
648 | std::forward<Opnd>(Op)); |
649 | } |
650 | |
651 | /// Match a aext or identity |
652 | /// Allows to peek through optional extensions |
653 | template <typename Opnd> |
654 | inline Or<UnaryOpc_match<Opnd>, Opnd> m_AExtOrSelf(Opnd &&Op) { |
655 | return Or<UnaryOpc_match<Opnd>, Opnd>(m_AnyExt(std::forward<Opnd>(Op)), |
656 | std::forward<Opnd>(Op)); |
657 | } |
658 | |
659 | /// Match a trunc or identity |
660 | /// Allows to peek through optional truncations |
661 | template <typename Opnd> |
662 | inline Or<UnaryOpc_match<Opnd>, Opnd> m_TruncOrSelf(Opnd &&Op) { |
663 | return Or<UnaryOpc_match<Opnd>, Opnd>(m_Trunc(std::forward<Opnd>(Op)), |
664 | std::forward<Opnd>(Op)); |
665 | } |
666 | |
667 | // === Constants === |
668 | struct ConstantInt_match { |
669 | APInt *BindVal; |
670 | |
671 | explicit ConstantInt_match(APInt *V) : BindVal(V) {} |
672 | |
673 | template <typename MatchContext> bool match(const MatchContext &, SDValue N) { |
674 | // The logics here are similar to that in |
675 | // SelectionDAG::isConstantIntBuildVectorOrConstantInt, but the latter also |
676 | // treats GlobalAddressSDNode as a constant, which is difficult to turn into |
677 | // APInt. |
678 | if (auto *C = dyn_cast_or_null<ConstantSDNode>(Val: N.getNode())) { |
679 | if (BindVal) |
680 | *BindVal = C->getAPIntValue(); |
681 | return true; |
682 | } |
683 | |
684 | APInt Discard; |
685 | return ISD::isConstantSplatVector(N: N.getNode(), |
686 | SplatValue&: BindVal ? *BindVal : Discard); |
687 | } |
688 | }; |
689 | /// Match any interger constants or splat of an integer constant. |
690 | inline ConstantInt_match m_ConstInt() { return ConstantInt_match(nullptr); } |
691 | /// Match any interger constants or splat of an integer constant; return the |
692 | /// specific constant or constant splat value. |
693 | inline ConstantInt_match m_ConstInt(APInt &V) { return ConstantInt_match(&V); } |
694 | |
695 | struct SpecificInt_match { |
696 | APInt IntVal; |
697 | |
698 | explicit SpecificInt_match(APInt APV) : IntVal(std::move(APV)) {} |
699 | |
700 | template <typename MatchContext> |
701 | bool match(const MatchContext &Ctx, SDValue N) { |
702 | APInt ConstInt; |
703 | if (sd_context_match(N, Ctx, m_ConstInt(V&: ConstInt))) |
704 | return APInt::isSameValue(I1: IntVal, I2: ConstInt); |
705 | return false; |
706 | } |
707 | }; |
708 | |
709 | /// Match a specific integer constant or constant splat value. |
710 | inline SpecificInt_match m_SpecificInt(APInt V) { |
711 | return SpecificInt_match(std::move(V)); |
712 | } |
713 | inline SpecificInt_match m_SpecificInt(uint64_t V) { |
714 | return SpecificInt_match(APInt(64, V)); |
715 | } |
716 | |
717 | inline SpecificInt_match m_Zero() { return m_SpecificInt(V: 0U); } |
718 | inline SpecificInt_match m_One() { return m_SpecificInt(V: 1U); } |
719 | inline SpecificInt_match m_AllOnes() { return m_SpecificInt(V: ~0U); } |
720 | |
721 | /// Match true boolean value based on the information provided by |
722 | /// TargetLowering. |
723 | inline auto m_True() { |
724 | return TLI_pred_match{ |
725 | [](const TargetLowering &TLI, SDValue N) { |
726 | APInt ConstVal; |
727 | if (sd_match(N, P: m_ConstInt(V&: ConstVal))) |
728 | switch (TLI.getBooleanContents(Type: N.getValueType())) { |
729 | case TargetLowering::ZeroOrOneBooleanContent: |
730 | return ConstVal.isOne(); |
731 | case TargetLowering::ZeroOrNegativeOneBooleanContent: |
732 | return ConstVal.isAllOnes(); |
733 | case TargetLowering::UndefinedBooleanContent: |
734 | return (ConstVal & 0x01) == 1; |
735 | } |
736 | |
737 | return false; |
738 | }, |
739 | m_Value()}; |
740 | } |
741 | /// Match false boolean value based on the information provided by |
742 | /// TargetLowering. |
743 | inline auto m_False() { |
744 | return TLI_pred_match{ |
745 | [](const TargetLowering &TLI, SDValue N) { |
746 | APInt ConstVal; |
747 | if (sd_match(N, P: m_ConstInt(V&: ConstVal))) |
748 | switch (TLI.getBooleanContents(Type: N.getValueType())) { |
749 | case TargetLowering::ZeroOrOneBooleanContent: |
750 | case TargetLowering::ZeroOrNegativeOneBooleanContent: |
751 | return ConstVal.isZero(); |
752 | case TargetLowering::UndefinedBooleanContent: |
753 | return (ConstVal & 0x01) == 0; |
754 | } |
755 | |
756 | return false; |
757 | }, |
758 | m_Value()}; |
759 | } |
760 | |
761 | /// Match a negate as a sub(0, v) |
762 | template <typename ValTy> |
763 | inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) { |
764 | return m_Sub(m_Zero(), V); |
765 | } |
766 | |
767 | /// Match a Not as a xor(v, -1) or xor(-1, v) |
768 | template <typename ValTy> |
769 | inline BinaryOpc_match<ValTy, SpecificInt_match, true> m_Not(const ValTy &V) { |
770 | return m_Xor(V, m_AllOnes()); |
771 | } |
772 | |
773 | } // namespace SDPatternMatch |
774 | } // namespace llvm |
775 | #endif |
776 | |