1 | //===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines utilities for walking and visiting operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_VISITORS_H |
14 | #define MLIR_IR_VISITORS_H |
15 | |
16 | #include "mlir/Support/LLVM.h" |
17 | #include "mlir/Support/LogicalResult.h" |
18 | #include "llvm/ADT/STLExtras.h" |
19 | |
20 | namespace mlir { |
21 | class Diagnostic; |
22 | class InFlightDiagnostic; |
23 | class Operation; |
24 | class Block; |
25 | class Region; |
26 | |
27 | /// A utility result that is used to signal how to proceed with an ongoing walk: |
28 | /// * Interrupt: the walk will be interrupted and no more operations, regions |
29 | /// or blocks will be visited. |
30 | /// * Advance: the walk will continue. |
31 | /// * Skip: the walk of the current operation, region or block and their |
32 | /// nested elements that haven't been visited already will be skipped and will |
33 | /// continue with the next operation, region or block. |
34 | class WalkResult { |
35 | enum ResultEnum { Interrupt, Advance, Skip } result; |
36 | |
37 | public: |
38 | WalkResult(ResultEnum result = Advance) : result(result) {} |
39 | |
40 | /// Allow LogicalResult to interrupt the walk on failure. |
41 | WalkResult(LogicalResult result) |
42 | : result(failed(result) ? Interrupt : Advance) {} |
43 | |
44 | /// Allow diagnostics to interrupt the walk. |
45 | WalkResult(Diagnostic &&) : result(Interrupt) {} |
46 | WalkResult(InFlightDiagnostic &&) : result(Interrupt) {} |
47 | |
48 | bool operator==(const WalkResult &rhs) const { return result == rhs.result; } |
49 | bool operator!=(const WalkResult &rhs) const { return result != rhs.result; } |
50 | |
51 | static WalkResult interrupt() { return {Interrupt}; } |
52 | static WalkResult advance() { return {Advance}; } |
53 | static WalkResult skip() { return {Skip}; } |
54 | |
55 | /// Returns true if the walk was interrupted. |
56 | bool wasInterrupted() const { return result == Interrupt; } |
57 | |
58 | /// Returns true if the walk was skipped. |
59 | bool wasSkipped() const { return result == Skip; } |
60 | }; |
61 | |
/// Traversal order for the region, block and operation walk utilities: with
/// `PreOrder` the callback runs on an element before its nested elements are
/// walked, with `PostOrder` it runs after them.
enum class WalkOrder {
  PreOrder,
  PostOrder,
};
64 | |
65 | /// This iterator enumerates the elements in "forward" order. |
66 | struct ForwardIterator { |
67 | /// Make operations iterable: return the list of regions. |
68 | static MutableArrayRef<Region> makeIterable(Operation &range); |
69 | |
70 | /// Regions and block are already iterable. |
71 | template <typename T> |
72 | static constexpr T &makeIterable(T &range) { |
73 | return range; |
74 | } |
75 | }; |
76 | |
77 | /// A utility class to encode the current walk stage for "generic" walkers. |
78 | /// When walking an operation, we can either choose a Pre/Post order walker |
79 | /// which invokes the callback on an operation before/after all its attached |
80 | /// regions have been visited, or choose a "generic" walker where the callback |
81 | /// is invoked on the operation N+1 times where N is the number of regions |
82 | /// attached to that operation. The `WalkStage` class below encodes the current |
83 | /// stage of the walk, i.e., which regions have already been visited, and the |
84 | /// callback accepts an additional argument for the current stage. Such |
85 | /// generic walkers that accept stage-aware callbacks are only applicable when |
86 | /// the callback operates on an operation (i.e., not applicable for callbacks |
87 | /// on Blocks or Regions). |
88 | class WalkStage { |
89 | public: |
90 | explicit WalkStage(Operation *op); |
91 | |
92 | /// Return true if parent operation is being visited before all regions. |
93 | bool isBeforeAllRegions() const { return nextRegion == 0; } |
94 | /// Returns true if parent operation is being visited just before visiting |
95 | /// region number `region`. |
96 | bool isBeforeRegion(int region) const { return nextRegion == region; } |
97 | /// Returns true if parent operation is being visited just after visiting |
98 | /// region number `region`. |
99 | bool isAfterRegion(int region) const { return nextRegion == region + 1; } |
100 | /// Return true if parent operation is being visited after all regions. |
101 | bool isAfterAllRegions() const { return nextRegion == numRegions; } |
102 | /// Advance the walk stage. |
103 | void advance() { nextRegion++; } |
104 | /// Returns the next region that will be visited. |
105 | int getNextRegion() const { return nextRegion; } |
106 | |
107 | private: |
108 | const int numRegions; |
109 | int nextRegion; |
110 | }; |
111 | |
112 | namespace detail { |
/// Declaration-only overload set (meant for unevaluated contexts) that yields
/// the type of the first parameter of a callable. It covers plain function
/// pointers, pointers to non-const and const member functions, and any class
/// type exposing a single `operator()` (e.g. a lambda).
template <typename Ret, typename Arg, typename... Rest>
auto first_argument_type(Ret (*)(Arg, Rest...)) -> Arg;
template <typename Ret, typename Class, typename Arg, typename... Rest>
auto first_argument_type(Ret (Class::*)(Arg, Rest...)) -> Arg;
template <typename Ret, typename Class, typename Arg, typename... Rest>
auto first_argument_type(Ret (Class::*)(Arg, Rest...) const) -> Arg;
template <typename Callable>
auto first_argument_type(Callable)
    -> decltype(first_argument_type(&Callable::operator()));

/// The type of the first argument accepted by the callable type 'T'.
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
126 | |
127 | /// Walk all of the regions, blocks, or operations nested under (and including) |
128 | /// the given operation. The order in which regions, blocks and operations at |
129 | /// the same nesting level are visited (e.g., lexicographical or reverse |
130 | /// lexicographical order) is determined by 'Iterator'. The walk order for |
131 | /// enclosing regions, blocks and operations with respect to their nested ones |
132 | /// is specified by 'order'. These methods are invoked for void-returning |
133 | /// callbacks. A callback on a block or operation is allowed to erase that block |
134 | /// or operation only if the walk is in post-order. See non-void method for |
135 | /// pre-order erasure. |
136 | template <typename Iterator> |
137 | void walk(Operation *op, function_ref<void(Region *)> callback, |
138 | WalkOrder order) { |
139 | // We don't use early increment for regions because they can't be erased from |
140 | // a callback. |
141 | for (auto ®ion : Iterator::makeIterable(*op)) { |
142 | if (order == WalkOrder::PreOrder) |
143 | callback(®ion); |
144 | for (auto &block : Iterator::makeIterable(region)) { |
145 | for (auto &nestedOp : Iterator::makeIterable(block)) |
146 | walk<Iterator>(&nestedOp, callback, order); |
147 | } |
148 | if (order == WalkOrder::PostOrder) |
149 | callback(®ion); |
150 | } |
151 | } |
152 | |
153 | template <typename Iterator> |
154 | void walk(Operation *op, function_ref<void(Block *)> callback, |
155 | WalkOrder order) { |
156 | for (auto ®ion : Iterator::makeIterable(*op)) { |
157 | // Early increment here in the case where the block is erased. |
158 | for (auto &block : |
159 | llvm::make_early_inc_range(Iterator::makeIterable(region))) { |
160 | if (order == WalkOrder::PreOrder) |
161 | callback(&block); |
162 | for (auto &nestedOp : Iterator::makeIterable(block)) |
163 | walk<Iterator>(&nestedOp, callback, order); |
164 | if (order == WalkOrder::PostOrder) |
165 | callback(&block); |
166 | } |
167 | } |
168 | } |
169 | |
170 | template <typename Iterator> |
171 | void walk(Operation *op, function_ref<void(Operation *)> callback, |
172 | WalkOrder order) { |
173 | if (order == WalkOrder::PreOrder) |
174 | callback(op); |
175 | |
176 | // TODO: This walk should be iterative over the operations. |
177 | for (auto ®ion : Iterator::makeIterable(*op)) { |
178 | for (auto &block : Iterator::makeIterable(region)) { |
179 | // Early increment here in the case where the operation is erased. |
180 | for (auto &nestedOp : |
181 | llvm::make_early_inc_range(Iterator::makeIterable(block))) |
182 | walk<Iterator>(&nestedOp, callback, order); |
183 | } |
184 | } |
185 | |
186 | if (order == WalkOrder::PostOrder) |
187 | callback(op); |
188 | } |
189 | |
190 | /// Walk all of the regions, blocks, or operations nested under (and including) |
191 | /// the given operation. The order in which regions, blocks and operations at |
192 | /// the same nesting level are visited (e.g., lexicographical or reverse |
193 | /// lexicographical order) is determined by 'Iterator'. The walk order for |
194 | /// enclosing regions, blocks and operations with respect to their nested ones |
195 | /// is specified by 'order'. This method is invoked for skippable or |
196 | /// interruptible callbacks. A callback on a block or operation is allowed to |
197 | /// erase that block or operation if either: |
198 | /// * the walk is in post-order, or |
199 | /// * the walk is in pre-order and the walk is skipped after the erasure. |
200 | template <typename Iterator> |
201 | WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback, |
202 | WalkOrder order) { |
203 | // We don't use early increment for regions because they can't be erased from |
204 | // a callback. |
205 | for (auto ®ion : Iterator::makeIterable(*op)) { |
206 | if (order == WalkOrder::PreOrder) { |
207 | WalkResult result = callback(®ion); |
208 | if (result.wasSkipped()) |
209 | continue; |
210 | if (result.wasInterrupted()) |
211 | return WalkResult::interrupt(); |
212 | } |
213 | for (auto &block : Iterator::makeIterable(region)) { |
214 | for (auto &nestedOp : Iterator::makeIterable(block)) |
215 | if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted()) |
216 | return WalkResult::interrupt(); |
217 | } |
218 | if (order == WalkOrder::PostOrder) { |
219 | if (callback(®ion).wasInterrupted()) |
220 | return WalkResult::interrupt(); |
221 | // We don't check if this region was skipped because its walk already |
222 | // finished and the walk will continue with the next region. |
223 | } |
224 | } |
225 | return WalkResult::advance(); |
226 | } |
227 | |
228 | template <typename Iterator> |
229 | WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback, |
230 | WalkOrder order) { |
231 | for (auto ®ion : Iterator::makeIterable(*op)) { |
232 | // Early increment here in the case where the block is erased. |
233 | for (auto &block : |
234 | llvm::make_early_inc_range(Iterator::makeIterable(region))) { |
235 | if (order == WalkOrder::PreOrder) { |
236 | WalkResult result = callback(&block); |
237 | if (result.wasSkipped()) |
238 | continue; |
239 | if (result.wasInterrupted()) |
240 | return WalkResult::interrupt(); |
241 | } |
242 | for (auto &nestedOp : Iterator::makeIterable(block)) |
243 | if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted()) |
244 | return WalkResult::interrupt(); |
245 | if (order == WalkOrder::PostOrder) { |
246 | if (callback(&block).wasInterrupted()) |
247 | return WalkResult::interrupt(); |
248 | // We don't check if this block was skipped because its walk already |
249 | // finished and the walk will continue with the next block. |
250 | } |
251 | } |
252 | } |
253 | return WalkResult::advance(); |
254 | } |
255 | |
256 | template <typename Iterator> |
257 | WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback, |
258 | WalkOrder order) { |
259 | if (order == WalkOrder::PreOrder) { |
260 | WalkResult result = callback(op); |
261 | // If skipped, caller will continue the walk on the next operation. |
262 | if (result.wasSkipped()) |
263 | return WalkResult::advance(); |
264 | if (result.wasInterrupted()) |
265 | return WalkResult::interrupt(); |
266 | } |
267 | |
268 | // TODO: This walk should be iterative over the operations. |
269 | for (auto ®ion : Iterator::makeIterable(*op)) { |
270 | for (auto &block : Iterator::makeIterable(region)) { |
271 | // Early increment here in the case where the operation is erased. |
272 | for (auto &nestedOp : |
273 | llvm::make_early_inc_range(Iterator::makeIterable(block))) { |
274 | if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted()) |
275 | return WalkResult::interrupt(); |
276 | } |
277 | } |
278 | } |
279 | |
280 | if (order == WalkOrder::PostOrder) |
281 | return callback(op); |
282 | return WalkResult::advance(); |
283 | } |
284 | |
// Below is a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes (Operation/Block/etc.) over
// these methods. They are also templated to allow for statically dispatching
// based upon the type of the callback function.
289 | |
290 | /// Walk all of the regions, blocks, or operations nested under (and including) |
291 | /// the given operation. The order in which regions, blocks and operations at |
292 | /// the same nesting level are visited (e.g., lexicographical or reverse |
293 | /// lexicographical order) is determined by 'Iterator'. The walk order for |
294 | /// enclosing regions, blocks and operations with respect to their nested ones |
295 | /// is specified by 'Order' (post-order by default). A callback on a block or |
296 | /// operation is allowed to erase that block or operation if either: |
297 | /// * the walk is in post-order, or |
298 | /// * the walk is in pre-order and the walk is skipped after the erasure. |
299 | /// This method is selected for callbacks that operate on Region*, Block*, and |
300 | /// Operation*. |
301 | /// |
302 | /// Example: |
303 | /// op->walk([](Region *r) { ... }); |
304 | /// op->walk([](Block *b) { ... }); |
305 | /// op->walk([](Operation *op) { ... }); |
306 | template < |
307 | WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator, |
308 | typename FuncTy, typename ArgT = detail::first_argument<FuncTy>, |
309 | typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))> |
310 | std::enable_if_t<llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value, |
311 | RetT> |
312 | walk(Operation *op, FuncTy &&callback) { |
313 | return detail::walk<Iterator>(op, function_ref<RetT(ArgT)>(callback), Order); |
314 | } |
315 | |
316 | /// Walk all of the operations of type 'ArgT' nested under and including the |
317 | /// given operation. The order in which regions, blocks and operations at |
318 | /// the same nesting are visited (e.g., lexicographical or reverse |
319 | /// lexicographical order) is determined by 'Iterator'. The walk order for |
320 | /// enclosing regions, blocks and operations with respect to their nested ones |
321 | /// is specified by 'order' (post-order by default). This method is selected for |
322 | /// void-returning callbacks that operate on a specific derived operation type. |
323 | /// A callback on an operation is allowed to erase that operation only if the |
324 | /// walk is in post-order. See non-void method for pre-order erasure. |
325 | /// |
326 | /// Example: |
327 | /// op->walk([](ReturnOp op) { ... }); |
328 | template < |
329 | WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator, |
330 | typename FuncTy, typename ArgT = detail::first_argument<FuncTy>, |
331 | typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))> |
332 | std::enable_if_t< |
333 | !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value && |
334 | std::is_same<RetT, void>::value, |
335 | RetT> |
336 | walk(Operation *op, FuncTy &&callback) { |
337 | auto wrapperFn = [&](Operation *op) { |
338 | if (auto derivedOp = dyn_cast<ArgT>(op)) |
339 | callback(derivedOp); |
340 | }; |
341 | return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn), |
342 | Order); |
343 | } |
344 | |
345 | /// Walk all of the operations of type 'ArgT' nested under and including the |
346 | /// given operation. The order in which regions, blocks and operations at |
347 | /// the same nesting are visited (e.g., lexicographical or reverse |
348 | /// lexicographical order) is determined by 'Iterator'. The walk order for |
349 | /// enclosing regions, blocks and operations with respect to their nested ones |
350 | /// is specified by 'Order' (post-order by default). This method is selected for |
351 | /// WalkReturn returning skippable or interruptible callbacks that operate on a |
352 | /// specific derived operation type. A callback on an operation is allowed to |
353 | /// erase that operation if either: |
354 | /// * the walk is in post-order, or |
355 | /// * the walk is in pre-order and the walk is skipped after the erasure. |
356 | /// |
357 | /// Example: |
358 | /// op->walk([](ReturnOp op) { |
359 | /// if (some_invariant) |
360 | /// return WalkResult::skip(); |
361 | /// if (another_invariant) |
362 | /// return WalkResult::interrupt(); |
363 | /// return WalkResult::advance(); |
364 | /// }); |
365 | template < |
366 | WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator, |
367 | typename FuncTy, typename ArgT = detail::first_argument<FuncTy>, |
368 | typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))> |
369 | std::enable_if_t< |
370 | !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value && |
371 | std::is_same<RetT, WalkResult>::value, |
372 | RetT> |
373 | walk(Operation *op, FuncTy &&callback) { |
374 | auto wrapperFn = [&](Operation *op) { |
375 | if (auto derivedOp = dyn_cast<ArgT>(op)) |
376 | return callback(derivedOp); |
377 | return WalkResult::advance(); |
378 | }; |
379 | return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn), |
380 | Order); |
381 | } |
382 | |
383 | /// Generic walkers with stage aware callbacks. |
384 | |
385 | /// Walk all the operations nested under (and including) the given operation, |
386 | /// with the callback being invoked on each operation N+1 times, where N is the |
387 | /// number of regions attached to the operation. The `stage` input to the |
388 | /// callback indicates the current walk stage. This method is invoked for void |
389 | /// returning callbacks. |
390 | void walk(Operation *op, |
391 | function_ref<void(Operation *, const WalkStage &stage)> callback); |
392 | |
393 | /// Walk all the operations nested under (and including) the given operation, |
394 | /// with the callback being invoked on each operation N+1 times, where N is the |
395 | /// number of regions attached to the operation. The `stage` input to the |
396 | /// callback indicates the current walk stage. This method is invoked for |
397 | /// skippable or interruptible callbacks. |
398 | WalkResult |
399 | walk(Operation *op, |
400 | function_ref<WalkResult(Operation *, const WalkStage &stage)> callback); |
401 | |
402 | /// Walk all of the operations nested under and including the given operation. |
403 | /// This method is selected for stage-aware callbacks that operate on |
404 | /// Operation*. |
405 | /// |
406 | /// Example: |
407 | /// op->walk([](Operation *op, const WalkStage &stage) { ... }); |
408 | template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>, |
409 | typename RetT = decltype(std::declval<FuncTy>()( |
410 | std::declval<ArgT>(), std::declval<const WalkStage &>()))> |
411 | std::enable_if_t<std::is_same<ArgT, Operation *>::value, RetT> |
412 | walk(Operation *op, FuncTy &&callback) { |
413 | return detail::walk(op, |
414 | function_ref<RetT(ArgT, const WalkStage &)>(callback)); |
415 | } |
416 | |
417 | /// Walk all of the operations of type 'ArgT' nested under and including the |
418 | /// given operation. This method is selected for void returning callbacks that |
419 | /// operate on a specific derived operation type. |
420 | /// |
421 | /// Example: |
422 | /// op->walk([](ReturnOp op, const WalkStage &stage) { ... }); |
423 | template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>, |
424 | typename RetT = decltype(std::declval<FuncTy>()( |
425 | std::declval<ArgT>(), std::declval<const WalkStage &>()))> |
426 | std::enable_if_t<!std::is_same<ArgT, Operation *>::value && |
427 | std::is_same<RetT, void>::value, |
428 | RetT> |
429 | walk(Operation *op, FuncTy &&callback) { |
430 | auto wrapperFn = [&](Operation *op, const WalkStage &stage) { |
431 | if (auto derivedOp = dyn_cast<ArgT>(op)) |
432 | callback(derivedOp, stage); |
433 | }; |
434 | return detail::walk( |
435 | op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn)); |
436 | } |
437 | |
438 | /// Walk all of the operations of type 'ArgT' nested under and including the |
439 | /// given operation. This method is selected for WalkReturn returning |
440 | /// interruptible callbacks that operate on a specific derived operation type. |
441 | /// |
442 | /// Example: |
443 | /// op->walk(op, [](ReturnOp op, const WalkStage &stage) { |
444 | /// if (some_invariant) |
445 | /// return WalkResult::interrupt(); |
446 | /// return WalkResult::advance(); |
447 | /// }); |
448 | template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>, |
449 | typename RetT = decltype(std::declval<FuncTy>()( |
450 | std::declval<ArgT>(), std::declval<const WalkStage &>()))> |
451 | std::enable_if_t<!std::is_same<ArgT, Operation *>::value && |
452 | std::is_same<RetT, WalkResult>::value, |
453 | RetT> |
454 | walk(Operation *op, FuncTy &&callback) { |
455 | auto wrapperFn = [&](Operation *op, const WalkStage &stage) { |
456 | if (auto derivedOp = dyn_cast<ArgT>(op)) |
457 | return callback(derivedOp, stage); |
458 | return WalkResult::advance(); |
459 | }; |
460 | return detail::walk( |
461 | op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn)); |
462 | } |
463 | |
/// The type returned by the templated `walk` overload that would be selected
/// for a callback of type 'CallableT'.
template <typename CallableT>
using walkResultType = decltype(walk(nullptr, std::declval<CallableT>()));
467 | } // namespace detail |
468 | |
469 | } // namespace mlir |
470 | |
471 | #endif |
472 | |