1 | //===-- LLJITWithOptimizingIRTransform.cpp -- LLJIT with IR optimization --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // In this example we will use an IR transform to optimize a module as it |
10 | // passes through LLJIT's IRTransformLayer. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/ExecutionEngine/Orc/LLJIT.h" |
15 | #include "llvm/IR/LegacyPassManager.h" |
16 | #include "llvm/Pass.h" |
17 | #include "llvm/Support/InitLLVM.h" |
18 | #include "llvm/Support/TargetSelect.h" |
19 | #include "llvm/Support/raw_ostream.h" |
20 | #include "llvm/Transforms/IPO.h" |
21 | #include "llvm/Transforms/Scalar.h" |
22 | |
23 | #include "../ExampleModules.h" |
24 | |
25 | using namespace llvm; |
26 | using namespace llvm::orc; |
27 | |
28 | ExitOnError ExitOnErr; |
29 | |
30 | // Example IR module. |
31 | // |
32 | // This IR contains a recursive definition of the factorial function: |
33 | // |
34 | // fac(n) | n == 0 = 1 |
35 | // | otherwise = n * fac(n - 1) |
36 | // |
37 | // It also contains an entry function which calls the factorial function with |
38 | // an input value of 5. |
39 | // |
// We expect the IR optimization transform that we build below to turn the
// recursive factorial function into a non-recursive (looping) one. The
// pipeline contains no inlining or constant folding across calls, so entry()
// still calls fac at run time; the call should return 5!, or 120.
43 | |
44 | const llvm::StringRef MainMod = |
45 | R"( |
46 | |
47 | define i32 @fac(i32 %n) { |
48 | entry: |
49 | %tobool = icmp eq i32 %n, 0 |
50 | br i1 %tobool, label %return, label %if.then |
51 | |
52 | if.then: ; preds = %entry |
53 | %arg = add nsw i32 %n, -1 |
54 | %call_result = call i32 @fac(i32 %arg) |
55 | %result = mul nsw i32 %n, %call_result |
56 | br label %return |
57 | |
58 | return: ; preds = %entry, %if.then |
59 | %final_result = phi i32 [ %result, %if.then ], [ 1, %entry ] |
60 | ret i32 %final_result |
61 | } |
62 | |
63 | define i32 @entry() { |
64 | entry: |
65 | %result = call i32 @fac(i32 5) |
66 | ret i32 %result |
67 | } |
68 | |
69 | )" ; |
70 | |
71 | // A function object that creates a simple pass pipeline to apply to each |
72 | // module as it passes through the IRTransformLayer. |
73 | class MyOptimizationTransform { |
74 | public: |
75 | MyOptimizationTransform() : PM(std::make_unique<legacy::PassManager>()) { |
76 | PM->add(P: createTailCallEliminationPass()); |
77 | PM->add(P: createCFGSimplificationPass()); |
78 | } |
79 | |
80 | Expected<ThreadSafeModule> operator()(ThreadSafeModule TSM, |
81 | MaterializationResponsibility &R) { |
82 | TSM.withModuleDo(F: [this](Module &M) { |
83 | dbgs() << "--- BEFORE OPTIMIZATION ---\n" << M << "\n" ; |
84 | PM->run(M); |
85 | dbgs() << "--- AFTER OPTIMIZATION ---\n" << M << "\n" ; |
86 | }); |
87 | return std::move(TSM); |
88 | } |
89 | |
90 | private: |
91 | std::unique_ptr<legacy::PassManager> PM; |
92 | }; |
93 | |
94 | int main(int argc, char *argv[]) { |
95 | // Initialize LLVM. |
96 | InitLLVM X(argc, argv); |
97 | |
98 | InitializeNativeTarget(); |
99 | InitializeNativeTargetAsmPrinter(); |
100 | |
101 | ExitOnErr.setBanner(std::string(argv[0]) + ": " ); |
102 | |
103 | // (1) Create LLJIT instance. |
104 | auto J = ExitOnErr(LLJITBuilder().create()); |
105 | |
106 | // (2) Install transform to optimize modules when they're materialized. |
107 | J->getIRTransformLayer().setTransform(MyOptimizationTransform()); |
108 | |
109 | // (3) Add modules. |
110 | ExitOnErr(J->addIRModule(TSM: ExitOnErr(parseExampleModule(Source: MainMod, Name: "MainMod" )))); |
111 | |
112 | // (4) Look up the JIT'd function and call it. |
113 | auto EntryAddr = ExitOnErr(J->lookup(UnmangledName: "entry" )); |
114 | auto *Entry = EntryAddr.toPtr<int()>(); |
115 | |
116 | int Result = Entry(); |
117 | outs() << "--- Result ---\n" |
118 | << "entry() = " << Result << "\n" ; |
119 | |
120 | return 0; |
121 | } |
122 | |