1//===- RegionUtils.h - Region-related transformation utilities --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_TRANSFORMS_REGIONUTILS_H_
10#define MLIR_TRANSFORMS_REGIONUTILS_H_
11
12#include "mlir/IR/Region.h"
13#include "mlir/IR/Value.h"
14
15#include "llvm/ADT/SetVector.h"
16
17namespace mlir {
18class RewriterBase;
19
20/// Check if all values in the provided range are defined above the `limit`
21/// region. That is, if they are defined in a region that is a proper ancestor
22/// of `limit`.
23template <typename Range>
24bool areValuesDefinedAbove(Range values, Region &limit) {
25 for (Value v : values)
26 if (!v.getParentRegion()->isProperAncestor(other: &limit))
27 return false;
28 return true;
29}
30
31/// Replace all uses of `orig` within the given region with `replacement`.
32void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region);
33
34/// Calls `callback` for each use of a value within `region` or its descendants
35/// that was defined at the ancestors of the `limit`.
36void visitUsedValuesDefinedAbove(Region &region, Region &limit,
37 function_ref<void(OpOperand *)> callback);
38
39/// Calls `callback` for each use of a value within any of the regions provided
40/// that was defined in one of the ancestors.
41void visitUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
42 function_ref<void(OpOperand *)> callback);
43
44/// Fill `values` with a list of values defined at the ancestors of the `limit`
45/// region and used within `region` or its descendants.
46void getUsedValuesDefinedAbove(Region &region, Region &limit,
47 SetVector<Value> &values);
48
49/// Fill `values` with a list of values used within any of the regions provided
50/// but defined in one of the ancestors.
51void getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
52 SetVector<Value> &values);
53
54/// Make a region isolated from above
55/// - Capture the values that are defined above the region and used within it.
56/// - Append to the entry block arguments that represent the captured values
57/// (one per captured value).
58/// - Replace all uses within the region of the captured values with the
59/// newly added arguments.
60/// - `cloneOperationIntoRegion` is a callback that allows caller to specify
61/// if the operation defining an `OpOperand` needs to be cloned into the
62/// region. Then the operands of this operation become part of the captured
63/// values set (unless the operations that define the operands themeselves
64/// are to be cloned). The cloned operations are added to the entry block
65/// of the region.
66/// Return the set of captured values for the operation.
67SmallVector<Value> makeRegionIsolatedFromAbove(
68 RewriterBase &rewriter, Region &region,
69 llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion =
70 [](Operation *) { return false; });
71
72/// Run a set of structural simplifications over the given regions. This
73/// includes transformations like unreachable block elimination, dead argument
74/// elimination, as well as some other DCE. This function returns success if any
75/// of the regions were simplified, failure otherwise. The provided rewriter is
76/// used to notify callers of operation and block deletion.
77LogicalResult simplifyRegions(RewriterBase &rewriter,
78 MutableArrayRef<Region> regions);
79
80/// Erase the unreachable blocks within the provided regions. Returns success
81/// if any blocks were erased, failure otherwise.
82LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
83 MutableArrayRef<Region> regions);
84
85/// This function returns success if any operations or arguments were deleted,
86/// failure otherwise.
87LogicalResult runRegionDCE(RewriterBase &rewriter,
88 MutableArrayRef<Region> regions);
89
90/// Get a topologically sorted list of blocks of the given region.
91SetVector<Block *> getTopologicallySortedBlocks(Region &region);
92
93} // namespace mlir
94
95#endif // MLIR_TRANSFORMS_REGIONUTILS_H_
96

source code of mlir/include/mlir/Transforms/RegionUtils.h