1//===- CompositePass.cpp - Composite pass code ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// CompositeFixedPointPass repeatedly runs a set of passes until a fixed point
// is reached (the IR stops changing) or an iteration limit is hit.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Transforms/Passes.h"
14
15#include "mlir/Pass/Pass.h"
16#include "mlir/Pass/PassManager.h"
17
18namespace mlir {
19#define GEN_PASS_DEF_COMPOSITEFIXEDPOINTPASS
20#include "mlir/Transforms/Passes.h.inc"
21} // namespace mlir
22
23using namespace mlir;
24
25namespace {
26struct CompositeFixedPointPass final
27 : public impl::CompositeFixedPointPassBase<CompositeFixedPointPass> {
28 using CompositeFixedPointPassBase::CompositeFixedPointPassBase;
29
30 CompositeFixedPointPass(
31 std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
32 int maxIterations) {
33 name = std::move(name_);
34 maxIter = maxIterations;
35 populateFunc(dynamicPM);
36
37 llvm::raw_string_ostream os(pipelineStr);
38 dynamicPM.printAsTextualPipeline(os&: os);
39 }
40
41 LogicalResult initializeOptions(
42 StringRef options,
43 function_ref<LogicalResult(const Twine &)> errorHandler) override {
44 if (failed(CompositeFixedPointPassBase::initializeOptions(options,
45 errorHandler)))
46 return failure();
47
48 if (failed(parsePassPipeline(pipelineStr, dynamicPM)))
49 return errorHandler("Failed to parse composite pass pipeline");
50
51 return success();
52 }
53
54 LogicalResult initialize(MLIRContext *context) override {
55 if (maxIter <= 0)
56 return emitError(UnknownLoc::get(context))
57 << "Invalid maxIterations value: " << maxIter << "\n";
58
59 return success();
60 }
61
62 void getDependentDialects(DialectRegistry &registry) const override {
63 dynamicPM.getDependentDialects(dialects&: registry);
64 }
65
66 void runOnOperation() override {
67 auto op = getOperation();
68 OperationFingerPrint fp(op);
69
70 int currentIter = 0;
71 int maxIterVal = maxIter;
72 while (true) {
73 if (failed(runPipeline(dynamicPM, op)))
74 return signalPassFailure();
75
76 if (currentIter++ >= maxIterVal) {
77 op->emitWarning("Composite pass \"" + llvm::Twine(name) +
78 "\"+ didn't converge in " + llvm::Twine(maxIterVal) +
79 " iterations");
80 break;
81 }
82
83 OperationFingerPrint newFp(op);
84 if (newFp == fp)
85 break;
86
87 fp = newFp;
88 }
89 }
90
91protected:
92 llvm::StringRef getName() const override { return name; }
93
94private:
95 OpPassManager dynamicPM;
96};
97} // namespace
98
99std::unique_ptr<Pass> mlir::createCompositeFixedPointPass(
100 std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
101 int maxIterations) {
102
103 return std::make_unique<CompositeFixedPointPass>(args: std::move(name),
104 args&: populateFunc, args&: maxIterations);
105}
106

source code of mlir/lib/Transforms/CompositePass.cpp