1 | //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements an applicator that applies pattern rewrites based upon a |
10 | // user defined cost model. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Rewrite/PatternApplicator.h" |
15 | #include "ByteCode.h" |
16 | #include "llvm/Support/Debug.h" |
17 | |
18 | #define DEBUG_TYPE "pattern-application" |
19 | |
20 | using namespace mlir; |
21 | using namespace mlir::detail; |
22 | |
23 | PatternApplicator::PatternApplicator( |
24 | const FrozenRewritePatternSet &frozenPatternList) |
25 | : frozenPatternList(frozenPatternList) { |
26 | if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { |
27 | mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>(); |
28 | bytecode->initializeMutableState(state&: *mutableByteCodeState); |
29 | } |
30 | } |
31 | PatternApplicator::~PatternApplicator() = default; |
32 | |
33 | #ifndef NDEBUG |
34 | /// Log a message for a pattern that is impossible to match. |
35 | static void logImpossibleToMatch(const Pattern &pattern) { |
36 | llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind() |
37 | << "' because it is impossible to match or cannot lead " |
38 | "to legal IR (by cost model)\n" ; |
39 | } |
40 | |
41 | /// Log IR after pattern application. |
42 | static Operation *getDumpRootOp(Operation *op) { |
43 | Operation *isolatedParent = |
44 | op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>(); |
45 | if (isolatedParent) |
46 | return isolatedParent; |
47 | return op; |
48 | } |
49 | static void logSucessfulPatternApplication(Operation *op) { |
50 | llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n" ; |
51 | op->dump(); |
52 | llvm::dbgs() << "\n\n" ; |
53 | } |
54 | #endif |
55 | |
56 | void PatternApplicator::applyCostModel(CostModel model) { |
57 | // Apply the cost model to the bytecode patterns first, and then the native |
58 | // patterns. |
59 | if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { |
60 | for (const auto &it : llvm::enumerate(First: bytecode->getPatterns())) |
61 | mutableByteCodeState->updatePatternBenefit(patternIndex: it.index(), benefit: model(it.value())); |
62 | } |
63 | |
64 | // Copy over the patterns so that we can sort by benefit based on the cost |
65 | // model. Patterns that are already impossible to match are ignored. |
66 | patterns.clear(); |
67 | for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) { |
68 | for (const RewritePattern *pattern : it.second) { |
69 | if (pattern->getBenefit().isImpossibleToMatch()) |
70 | LLVM_DEBUG(logImpossibleToMatch(*pattern)); |
71 | else |
72 | patterns[it.first].push_back(Elt: pattern); |
73 | } |
74 | } |
75 | anyOpPatterns.clear(); |
76 | for (const RewritePattern &pattern : |
77 | frozenPatternList.getMatchAnyOpNativePatterns()) { |
78 | if (pattern.getBenefit().isImpossibleToMatch()) |
79 | LLVM_DEBUG(logImpossibleToMatch(pattern)); |
80 | else |
81 | anyOpPatterns.push_back(Elt: &pattern); |
82 | } |
83 | |
84 | // Sort the patterns using the provided cost model. |
85 | llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits; |
86 | auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) { |
87 | return benefits[lhs] > benefits[rhs]; |
88 | }; |
89 | auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) { |
90 | // Special case for one pattern in the list, which is the most common case. |
91 | if (list.size() == 1) { |
92 | if (model(*list.front()).isImpossibleToMatch()) { |
93 | LLVM_DEBUG(logImpossibleToMatch(*list.front())); |
94 | list.clear(); |
95 | } |
96 | return; |
97 | } |
98 | |
99 | // Collect the dynamic benefits for the current pattern list. |
100 | benefits.clear(); |
101 | for (const Pattern *pat : list) |
102 | benefits.try_emplace(Key: pat, Args: model(*pat)); |
103 | |
104 | // Sort patterns with highest benefit first, and remove those that are |
105 | // impossible to match. |
106 | std::stable_sort(first: list.begin(), last: list.end(), comp: cmp); |
107 | while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) { |
108 | LLVM_DEBUG(logImpossibleToMatch(*list.back())); |
109 | list.pop_back(); |
110 | } |
111 | }; |
112 | for (auto &it : patterns) |
113 | processPatternList(it.second); |
114 | processPatternList(anyOpPatterns); |
115 | } |
116 | |
117 | void PatternApplicator::walkAllPatterns( |
118 | function_ref<void(const Pattern &)> walk) { |
119 | for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) |
120 | for (const auto &pattern : it.second) |
121 | walk(*pattern); |
122 | for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns()) |
123 | walk(it); |
124 | if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { |
125 | for (const Pattern &it : bytecode->getPatterns()) |
126 | walk(it); |
127 | } |
128 | } |
129 | |
130 | LogicalResult PatternApplicator::matchAndRewrite( |
131 | Operation *op, PatternRewriter &rewriter, |
132 | function_ref<bool(const Pattern &)> canApply, |
133 | function_ref<void(const Pattern &)> onFailure, |
134 | function_ref<LogicalResult(const Pattern &)> onSuccess) { |
135 | // Before checking native patterns, first match against the bytecode. This |
136 | // won't automatically perform any rewrites so there is no need to worry about |
137 | // conflicts. |
138 | SmallVector<PDLByteCode::MatchResult, 4> pdlMatches; |
139 | const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode(); |
140 | if (bytecode) |
141 | bytecode->match(op, rewriter, matches&: pdlMatches, state&: *mutableByteCodeState); |
142 | |
143 | // Check to see if there are patterns matching this specific operation type. |
144 | MutableArrayRef<const RewritePattern *> opPatterns; |
145 | auto patternIt = patterns.find(Val: op->getName()); |
146 | if (patternIt != patterns.end()) |
147 | opPatterns = patternIt->second; |
148 | |
149 | // Process the patterns for that match the specific operation type, and any |
150 | // operation type in an interleaved fashion. |
151 | unsigned opIt = 0, opE = opPatterns.size(); |
152 | unsigned anyIt = 0, anyE = anyOpPatterns.size(); |
153 | unsigned pdlIt = 0, pdlE = pdlMatches.size(); |
154 | LogicalResult result = failure(); |
155 | do { |
156 | // Find the next pattern with the highest benefit. |
157 | const Pattern *bestPattern = nullptr; |
158 | unsigned *bestPatternIt = &opIt; |
159 | |
160 | /// Operation specific patterns. |
161 | if (opIt < opE) |
162 | bestPattern = opPatterns[opIt]; |
163 | /// Operation agnostic patterns. |
164 | if (anyIt < anyE && |
165 | (!bestPattern || |
166 | bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) { |
167 | bestPatternIt = &anyIt; |
168 | bestPattern = anyOpPatterns[anyIt]; |
169 | } |
170 | |
171 | const PDLByteCode::MatchResult *pdlMatch = nullptr; |
172 | /// PDL patterns. |
173 | if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() < |
174 | pdlMatches[pdlIt].benefit)) { |
175 | bestPatternIt = &pdlIt; |
176 | pdlMatch = &pdlMatches[pdlIt]; |
177 | bestPattern = pdlMatch->pattern; |
178 | } |
179 | |
180 | if (!bestPattern) |
181 | break; |
182 | |
183 | // Update the pattern iterator on failure so that this pattern isn't |
184 | // attempted again. |
185 | ++(*bestPatternIt); |
186 | |
187 | // Check that the pattern can be applied. |
188 | if (canApply && !canApply(*bestPattern)) |
189 | continue; |
190 | |
191 | // Try to match and rewrite this pattern. The patterns are sorted by |
192 | // benefit, so if we match we can immediately rewrite. For PDL patterns, the |
193 | // match has already been performed, we just need to rewrite. |
194 | bool matched = false; |
195 | op->getContext()->executeAction<ApplyPatternAction>( |
196 | actionFn: [&]() { |
197 | rewriter.setInsertionPoint(op); |
198 | #ifndef NDEBUG |
199 | // Operation `op` may be invalidated after applying the rewrite |
200 | // pattern. |
201 | Operation *dumpRootOp = getDumpRootOp(op); |
202 | #endif |
203 | if (pdlMatch) { |
204 | result = |
205 | bytecode->rewrite(rewriter, match: *pdlMatch, state&: *mutableByteCodeState); |
206 | } else { |
207 | LLVM_DEBUG(llvm::dbgs() << "Trying to match \"" |
208 | << bestPattern->getDebugName() << "\"\n" ); |
209 | |
210 | const auto *pattern = |
211 | static_cast<const RewritePattern *>(bestPattern); |
212 | result = pattern->matchAndRewrite(op, rewriter); |
213 | |
214 | LLVM_DEBUG(llvm::dbgs() |
215 | << "\"" << bestPattern->getDebugName() << "\" result " |
216 | << succeeded(result) << "\n" ); |
217 | } |
218 | |
219 | // Process the result of the pattern application. |
220 | if (succeeded(result) && onSuccess && failed(result: onSuccess(*bestPattern))) |
221 | result = failure(); |
222 | if (succeeded(result)) { |
223 | LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp)); |
224 | matched = true; |
225 | return; |
226 | } |
227 | |
228 | // Perform any necessary cleanups. |
229 | if (onFailure) |
230 | onFailure(*bestPattern); |
231 | }, |
232 | irUnits: {op}, args: *bestPattern); |
233 | if (matched) |
234 | break; |
235 | } while (true); |
236 | |
237 | if (mutableByteCodeState) |
238 | mutableByteCodeState->cleanupAfterMatchAndRewrite(); |
239 | return result; |
240 | } |
241 | |