1//===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines utilities for walking and visiting operations.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_VISITORS_H
14#define MLIR_IR_VISITORS_H
15
16#include "mlir/Support/LLVM.h"
17#include "mlir/Support/LogicalResult.h"
18#include "llvm/ADT/STLExtras.h"
19
20namespace mlir {
21class Diagnostic;
22class InFlightDiagnostic;
23class Operation;
24class Block;
25class Region;
26
27/// A utility result that is used to signal how to proceed with an ongoing walk:
28/// * Interrupt: the walk will be interrupted and no more operations, regions
29/// or blocks will be visited.
30/// * Advance: the walk will continue.
31/// * Skip: the walk of the current operation, region or block and their
32/// nested elements that haven't been visited already will be skipped and will
33/// continue with the next operation, region or block.
34class WalkResult {
35 enum ResultEnum { Interrupt, Advance, Skip } result;
36
37public:
38 WalkResult(ResultEnum result = Advance) : result(result) {}
39
40 /// Allow LogicalResult to interrupt the walk on failure.
41 WalkResult(LogicalResult result)
42 : result(failed(result) ? Interrupt : Advance) {}
43
44 /// Allow diagnostics to interrupt the walk.
45 WalkResult(Diagnostic &&) : result(Interrupt) {}
46 WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
47
48 bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
49 bool operator!=(const WalkResult &rhs) const { return result != rhs.result; }
50
51 static WalkResult interrupt() { return {Interrupt}; }
52 static WalkResult advance() { return {Advance}; }
53 static WalkResult skip() { return {Skip}; }
54
55 /// Returns true if the walk was interrupted.
56 bool wasInterrupted() const { return result == Interrupt; }
57
58 /// Returns true if the walk was skipped.
59 bool wasSkipped() const { return result == Skip; }
60};
61
/// Order in which the region, block and operation walk utilities report an
/// enclosing entity relative to the entities nested inside it.
enum class WalkOrder {
  /// Report the enclosing entity before anything nested inside it.
  PreOrder,
  /// Report the enclosing entity after everything nested inside it.
  PostOrder
};
64
65/// This iterator enumerates the elements in "forward" order.
66struct ForwardIterator {
67 /// Make operations iterable: return the list of regions.
68 static MutableArrayRef<Region> makeIterable(Operation &range);
69
70 /// Regions and block are already iterable.
71 template <typename T>
72 static constexpr T &makeIterable(T &range) {
73 return range;
74 }
75};
76
77/// A utility class to encode the current walk stage for "generic" walkers.
78/// When walking an operation, we can either choose a Pre/Post order walker
79/// which invokes the callback on an operation before/after all its attached
80/// regions have been visited, or choose a "generic" walker where the callback
81/// is invoked on the operation N+1 times where N is the number of regions
82/// attached to that operation. The `WalkStage` class below encodes the current
83/// stage of the walk, i.e., which regions have already been visited, and the
84/// callback accepts an additional argument for the current stage. Such
85/// generic walkers that accept stage-aware callbacks are only applicable when
86/// the callback operates on an operation (i.e., not applicable for callbacks
87/// on Blocks or Regions).
88class WalkStage {
89public:
90 explicit WalkStage(Operation *op);
91
92 /// Return true if parent operation is being visited before all regions.
93 bool isBeforeAllRegions() const { return nextRegion == 0; }
94 /// Returns true if parent operation is being visited just before visiting
95 /// region number `region`.
96 bool isBeforeRegion(int region) const { return nextRegion == region; }
97 /// Returns true if parent operation is being visited just after visiting
98 /// region number `region`.
99 bool isAfterRegion(int region) const { return nextRegion == region + 1; }
100 /// Return true if parent operation is being visited after all regions.
101 bool isAfterAllRegions() const { return nextRegion == numRegions; }
102 /// Advance the walk stage.
103 void advance() { nextRegion++; }
104 /// Returns the next region that will be visited.
105 int getNextRegion() const { return nextRegion; }
106
107private:
108 const int numRegions;
109 int nextRegion;
110};
111
112namespace detail {
/// Overload set used only inside `decltype` to recover the type of the first
/// parameter of a callable. These functions are declared but never defined.
///
/// Plain function pointers.
template <typename ResultT, typename FirstT, typename... TailT>
FirstT first_argument_type(ResultT (*)(FirstT, TailT...));
/// Non-const member functions (e.g. `operator()` of a `mutable` lambda).
template <typename ResultT, typename ClassT, typename FirstT,
          typename... TailT>
FirstT first_argument_type(ResultT (ClassT::*)(FirstT, TailT...));
/// Const member functions (e.g. `operator()` of an ordinary lambda).
template <typename ResultT, typename ClassT, typename FirstT,
          typename... TailT>
FirstT first_argument_type(ResultT (ClassT::*)(FirstT, TailT...) const);
/// Function objects: delegate to the member-function overloads above through
/// the object's `operator()`.
template <typename CallableT>
decltype(first_argument_type(&CallableT::operator()))
first_argument_type(CallableT);

/// The type of the first parameter of the callable type `T`.
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
126
127/// Walk all of the regions, blocks, or operations nested under (and including)
128/// the given operation. The order in which regions, blocks and operations at
129/// the same nesting level are visited (e.g., lexicographical or reverse
130/// lexicographical order) is determined by 'Iterator'. The walk order for
131/// enclosing regions, blocks and operations with respect to their nested ones
132/// is specified by 'order'. These methods are invoked for void-returning
133/// callbacks. A callback on a block or operation is allowed to erase that block
134/// or operation only if the walk is in post-order. See non-void method for
135/// pre-order erasure.
136template <typename Iterator>
137void walk(Operation *op, function_ref<void(Region *)> callback,
138 WalkOrder order) {
139 // We don't use early increment for regions because they can't be erased from
140 // a callback.
141 for (auto &region : Iterator::makeIterable(*op)) {
142 if (order == WalkOrder::PreOrder)
143 callback(&region);
144 for (auto &block : Iterator::makeIterable(region)) {
145 for (auto &nestedOp : Iterator::makeIterable(block))
146 walk<Iterator>(&nestedOp, callback, order);
147 }
148 if (order == WalkOrder::PostOrder)
149 callback(&region);
150 }
151}
152
153template <typename Iterator>
154void walk(Operation *op, function_ref<void(Block *)> callback,
155 WalkOrder order) {
156 for (auto &region : Iterator::makeIterable(*op)) {
157 // Early increment here in the case where the block is erased.
158 for (auto &block :
159 llvm::make_early_inc_range(Iterator::makeIterable(region))) {
160 if (order == WalkOrder::PreOrder)
161 callback(&block);
162 for (auto &nestedOp : Iterator::makeIterable(block))
163 walk<Iterator>(&nestedOp, callback, order);
164 if (order == WalkOrder::PostOrder)
165 callback(&block);
166 }
167 }
168}
169
170template <typename Iterator>
171void walk(Operation *op, function_ref<void(Operation *)> callback,
172 WalkOrder order) {
173 if (order == WalkOrder::PreOrder)
174 callback(op);
175
176 // TODO: This walk should be iterative over the operations.
177 for (auto &region : Iterator::makeIterable(*op)) {
178 for (auto &block : Iterator::makeIterable(region)) {
179 // Early increment here in the case where the operation is erased.
180 for (auto &nestedOp :
181 llvm::make_early_inc_range(Iterator::makeIterable(block)))
182 walk<Iterator>(&nestedOp, callback, order);
183 }
184 }
185
186 if (order == WalkOrder::PostOrder)
187 callback(op);
188}
189
190/// Walk all of the regions, blocks, or operations nested under (and including)
191/// the given operation. The order in which regions, blocks and operations at
192/// the same nesting level are visited (e.g., lexicographical or reverse
193/// lexicographical order) is determined by 'Iterator'. The walk order for
194/// enclosing regions, blocks and operations with respect to their nested ones
195/// is specified by 'order'. This method is invoked for skippable or
196/// interruptible callbacks. A callback on a block or operation is allowed to
197/// erase that block or operation if either:
198/// * the walk is in post-order, or
199/// * the walk is in pre-order and the walk is skipped after the erasure.
200template <typename Iterator>
201WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback,
202 WalkOrder order) {
203 // We don't use early increment for regions because they can't be erased from
204 // a callback.
205 for (auto &region : Iterator::makeIterable(*op)) {
206 if (order == WalkOrder::PreOrder) {
207 WalkResult result = callback(&region);
208 if (result.wasSkipped())
209 continue;
210 if (result.wasInterrupted())
211 return WalkResult::interrupt();
212 }
213 for (auto &block : Iterator::makeIterable(region)) {
214 for (auto &nestedOp : Iterator::makeIterable(block))
215 if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
216 return WalkResult::interrupt();
217 }
218 if (order == WalkOrder::PostOrder) {
219 if (callback(&region).wasInterrupted())
220 return WalkResult::interrupt();
221 // We don't check if this region was skipped because its walk already
222 // finished and the walk will continue with the next region.
223 }
224 }
225 return WalkResult::advance();
226}
227
228template <typename Iterator>
229WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback,
230 WalkOrder order) {
231 for (auto &region : Iterator::makeIterable(*op)) {
232 // Early increment here in the case where the block is erased.
233 for (auto &block :
234 llvm::make_early_inc_range(Iterator::makeIterable(region))) {
235 if (order == WalkOrder::PreOrder) {
236 WalkResult result = callback(&block);
237 if (result.wasSkipped())
238 continue;
239 if (result.wasInterrupted())
240 return WalkResult::interrupt();
241 }
242 for (auto &nestedOp : Iterator::makeIterable(block))
243 if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
244 return WalkResult::interrupt();
245 if (order == WalkOrder::PostOrder) {
246 if (callback(&block).wasInterrupted())
247 return WalkResult::interrupt();
248 // We don't check if this block was skipped because its walk already
249 // finished and the walk will continue with the next block.
250 }
251 }
252 }
253 return WalkResult::advance();
254}
255
256template <typename Iterator>
257WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
258 WalkOrder order) {
259 if (order == WalkOrder::PreOrder) {
260 WalkResult result = callback(op);
261 // If skipped, caller will continue the walk on the next operation.
262 if (result.wasSkipped())
263 return WalkResult::advance();
264 if (result.wasInterrupted())
265 return WalkResult::interrupt();
266 }
267
268 // TODO: This walk should be iterative over the operations.
269 for (auto &region : Iterator::makeIterable(*op)) {
270 for (auto &block : Iterator::makeIterable(region)) {
271 // Early increment here in the case where the operation is erased.
272 for (auto &nestedOp :
273 llvm::make_early_inc_range(Iterator::makeIterable(block))) {
274 if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
275 return WalkResult::interrupt();
276 }
277 }
278 }
279
280 if (order == WalkOrder::PostOrder)
281 return callback(op);
282 return WalkResult::advance();
283}
284
// Below is a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes (Operation/Block/etc.) over these
// methods. They are also templated to allow for statically dispatching based
// upon the type of the callback function.
289
290/// Walk all of the regions, blocks, or operations nested under (and including)
291/// the given operation. The order in which regions, blocks and operations at
292/// the same nesting level are visited (e.g., lexicographical or reverse
293/// lexicographical order) is determined by 'Iterator'. The walk order for
294/// enclosing regions, blocks and operations with respect to their nested ones
295/// is specified by 'Order' (post-order by default). A callback on a block or
296/// operation is allowed to erase that block or operation if either:
297/// * the walk is in post-order, or
298/// * the walk is in pre-order and the walk is skipped after the erasure.
299/// This method is selected for callbacks that operate on Region*, Block*, and
300/// Operation*.
301///
302/// Example:
303/// op->walk([](Region *r) { ... });
304/// op->walk([](Block *b) { ... });
305/// op->walk([](Operation *op) { ... });
306template <
307 WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
308 typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
309 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
310std::enable_if_t<llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value,
311 RetT>
312walk(Operation *op, FuncTy &&callback) {
313 return detail::walk<Iterator>(op, function_ref<RetT(ArgT)>(callback), Order);
314}
315
316/// Walk all of the operations of type 'ArgT' nested under and including the
317/// given operation. The order in which regions, blocks and operations at
318/// the same nesting are visited (e.g., lexicographical or reverse
319/// lexicographical order) is determined by 'Iterator'. The walk order for
320/// enclosing regions, blocks and operations with respect to their nested ones
321/// is specified by 'order' (post-order by default). This method is selected for
322/// void-returning callbacks that operate on a specific derived operation type.
323/// A callback on an operation is allowed to erase that operation only if the
324/// walk is in post-order. See non-void method for pre-order erasure.
325///
326/// Example:
327/// op->walk([](ReturnOp op) { ... });
328template <
329 WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
330 typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
331 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
332std::enable_if_t<
333 !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
334 std::is_same<RetT, void>::value,
335 RetT>
336walk(Operation *op, FuncTy &&callback) {
337 auto wrapperFn = [&](Operation *op) {
338 if (auto derivedOp = dyn_cast<ArgT>(op))
339 callback(derivedOp);
340 };
341 return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn),
342 Order);
343}
344
345/// Walk all of the operations of type 'ArgT' nested under and including the
346/// given operation. The order in which regions, blocks and operations at
347/// the same nesting are visited (e.g., lexicographical or reverse
348/// lexicographical order) is determined by 'Iterator'. The walk order for
349/// enclosing regions, blocks and operations with respect to their nested ones
350/// is specified by 'Order' (post-order by default). This method is selected for
351/// WalkReturn returning skippable or interruptible callbacks that operate on a
352/// specific derived operation type. A callback on an operation is allowed to
353/// erase that operation if either:
354/// * the walk is in post-order, or
355/// * the walk is in pre-order and the walk is skipped after the erasure.
356///
357/// Example:
358/// op->walk([](ReturnOp op) {
359/// if (some_invariant)
360/// return WalkResult::skip();
361/// if (another_invariant)
362/// return WalkResult::interrupt();
363/// return WalkResult::advance();
364/// });
365template <
366 WalkOrder Order = WalkOrder::PostOrder, typename Iterator = ForwardIterator,
367 typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
368 typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
369std::enable_if_t<
370 !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
371 std::is_same<RetT, WalkResult>::value,
372 RetT>
373walk(Operation *op, FuncTy &&callback) {
374 auto wrapperFn = [&](Operation *op) {
375 if (auto derivedOp = dyn_cast<ArgT>(op))
376 return callback(derivedOp);
377 return WalkResult::advance();
378 };
379 return detail::walk<Iterator>(op, function_ref<RetT(Operation *)>(wrapperFn),
380 Order);
381}
382
383/// Generic walkers with stage aware callbacks.
384
385/// Walk all the operations nested under (and including) the given operation,
386/// with the callback being invoked on each operation N+1 times, where N is the
387/// number of regions attached to the operation. The `stage` input to the
388/// callback indicates the current walk stage. This method is invoked for void
389/// returning callbacks.
390void walk(Operation *op,
391 function_ref<void(Operation *, const WalkStage &stage)> callback);
392
393/// Walk all the operations nested under (and including) the given operation,
394/// with the callback being invoked on each operation N+1 times, where N is the
395/// number of regions attached to the operation. The `stage` input to the
396/// callback indicates the current walk stage. This method is invoked for
397/// skippable or interruptible callbacks.
398WalkResult
399walk(Operation *op,
400 function_ref<WalkResult(Operation *, const WalkStage &stage)> callback);
401
402/// Walk all of the operations nested under and including the given operation.
403/// This method is selected for stage-aware callbacks that operate on
404/// Operation*.
405///
406/// Example:
407/// op->walk([](Operation *op, const WalkStage &stage) { ... });
408template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
409 typename RetT = decltype(std::declval<FuncTy>()(
410 std::declval<ArgT>(), std::declval<const WalkStage &>()))>
411std::enable_if_t<std::is_same<ArgT, Operation *>::value, RetT>
412walk(Operation *op, FuncTy &&callback) {
413 return detail::walk(op,
414 function_ref<RetT(ArgT, const WalkStage &)>(callback));
415}
416
417/// Walk all of the operations of type 'ArgT' nested under and including the
418/// given operation. This method is selected for void returning callbacks that
419/// operate on a specific derived operation type.
420///
421/// Example:
422/// op->walk([](ReturnOp op, const WalkStage &stage) { ... });
423template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
424 typename RetT = decltype(std::declval<FuncTy>()(
425 std::declval<ArgT>(), std::declval<const WalkStage &>()))>
426std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
427 std::is_same<RetT, void>::value,
428 RetT>
429walk(Operation *op, FuncTy &&callback) {
430 auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
431 if (auto derivedOp = dyn_cast<ArgT>(op))
432 callback(derivedOp, stage);
433 };
434 return detail::walk(
435 op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
436}
437
438/// Walk all of the operations of type 'ArgT' nested under and including the
439/// given operation. This method is selected for WalkReturn returning
440/// interruptible callbacks that operate on a specific derived operation type.
441///
442/// Example:
443/// op->walk(op, [](ReturnOp op, const WalkStage &stage) {
444/// if (some_invariant)
445/// return WalkResult::interrupt();
446/// return WalkResult::advance();
447/// });
448template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
449 typename RetT = decltype(std::declval<FuncTy>()(
450 std::declval<ArgT>(), std::declval<const WalkStage &>()))>
451std::enable_if_t<!std::is_same<ArgT, Operation *>::value &&
452 std::is_same<RetT, WalkResult>::value,
453 RetT>
454walk(Operation *op, FuncTy &&callback) {
455 auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
456 if (auto derivedOp = dyn_cast<ArgT>(op))
457 return callback(derivedOp, stage);
458 return WalkResult::advance();
459 };
460 return detail::walk(
461 op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
462}
463
/// The return type produced by the templated `walk` overloads above when they
/// are invoked with a callback of type `CallbackT`.
template <typename CallbackT>
using walkResultType =
    decltype(walk(nullptr, std::declval<CallbackT>()));
467} // namespace detail
468
469} // namespace mlir
470
471#endif
472

source code of mlir/include/mlir/IR/Visitors.h