1//===-- LLJITWithOptimizingIRTransform.cpp -- LLJIT with IR optimization --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// In this example we will use an IR transform to optimize a module as it
10// passes through LLJIT's IRTransformLayer.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/ExecutionEngine/Orc/LLJIT.h"
15#include "llvm/IR/LegacyPassManager.h"
16#include "llvm/Pass.h"
17#include "llvm/Support/InitLLVM.h"
18#include "llvm/Support/TargetSelect.h"
19#include "llvm/Support/raw_ostream.h"
20#include "llvm/Transforms/IPO.h"
21#include "llvm/Transforms/Scalar.h"
22
23#include "../ExampleModules.h"
24
25using namespace llvm;
26using namespace llvm::orc;
27
28ExitOnError ExitOnErr;
29
30// Example IR module.
31//
32// This IR contains a recursive definition of the factorial function:
33//
// fac(n) | n == 0    = 1
//        | otherwise = n * fac(n - 1)
36//
37// It also contains an entry function which calls the factorial function with
38// an input value of 5.
39//
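// The IR optimization transform built below runs tail call elimination
// followed by CFG simplification over each module. Tail call elimination is
// expected to turn fac's self-recursion into a loop. The pipeline contains no
// inliner, so entry is not expected to be folded to a constant; it still calls
// fac, and the call returns 5!, or 120.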
43
44const llvm::StringRef MainMod =
45 R"(
46
47 define i32 @fac(i32 %n) {
48 entry:
49 %tobool = icmp eq i32 %n, 0
50 br i1 %tobool, label %return, label %if.then
51
52 if.then: ; preds = %entry
53 %arg = add nsw i32 %n, -1
54 %call_result = call i32 @fac(i32 %arg)
55 %result = mul nsw i32 %n, %call_result
56 br label %return
57
58 return: ; preds = %entry, %if.then
59 %final_result = phi i32 [ %result, %if.then ], [ 1, %entry ]
60 ret i32 %final_result
61 }
62
63 define i32 @entry() {
64 entry:
65 %result = call i32 @fac(i32 5)
66 ret i32 %result
67 }
68
69)";
70
71// A function object that creates a simple pass pipeline to apply to each
72// module as it passes through the IRTransformLayer.
73class MyOptimizationTransform {
74public:
75 MyOptimizationTransform() : PM(std::make_unique<legacy::PassManager>()) {
76 PM->add(P: createTailCallEliminationPass());
77 PM->add(P: createCFGSimplificationPass());
78 }
79
80 Expected<ThreadSafeModule> operator()(ThreadSafeModule TSM,
81 MaterializationResponsibility &R) {
82 TSM.withModuleDo(F: [this](Module &M) {
83 dbgs() << "--- BEFORE OPTIMIZATION ---\n" << M << "\n";
84 PM->run(M);
85 dbgs() << "--- AFTER OPTIMIZATION ---\n" << M << "\n";
86 });
87 return std::move(TSM);
88 }
89
90private:
91 std::unique_ptr<legacy::PassManager> PM;
92};
93
94int main(int argc, char *argv[]) {
95 // Initialize LLVM.
96 InitLLVM X(argc, argv);
97
98 InitializeNativeTarget();
99 InitializeNativeTargetAsmPrinter();
100
101 ExitOnErr.setBanner(std::string(argv[0]) + ": ");
102
103 // (1) Create LLJIT instance.
104 auto J = ExitOnErr(LLJITBuilder().create());
105
106 // (2) Install transform to optimize modules when they're materialized.
107 J->getIRTransformLayer().setTransform(MyOptimizationTransform());
108
109 // (3) Add modules.
110 ExitOnErr(J->addIRModule(TSM: ExitOnErr(parseExampleModule(Source: MainMod, Name: "MainMod"))));
111
112 // (4) Look up the JIT'd function and call it.
113 auto EntryAddr = ExitOnErr(J->lookup(UnmangledName: "entry"));
114 auto *Entry = EntryAddr.toPtr<int()>();
115
116 int Result = Entry();
117 outs() << "--- Result ---\n"
118 << "entry() = " << Result << "\n";
119
120 return 0;
121}
122

source code of llvm/examples/OrcV2Examples/LLJITWithOptimizingIRTransform/LLJITWithOptimizingIRTransform.cpp