1//===-- SpeculateAnalyses.cpp --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/ExecutionEngine/Orc/SpeculateAnalyses.h"
10#include "llvm/ADT/ArrayRef.h"
11#include "llvm/ADT/DenseMap.h"
12#include "llvm/ADT/STLExtras.h"
13#include "llvm/ADT/SmallVector.h"
14#include "llvm/Analysis/BlockFrequencyInfo.h"
15#include "llvm/Analysis/BranchProbabilityInfo.h"
16#include "llvm/Analysis/CFG.h"
17#include "llvm/IR/PassManager.h"
18#include "llvm/Passes/PassBuilder.h"
19#include "llvm/Support/ErrorHandling.h"
20
21#include <algorithm>
22
23namespace {
24using namespace llvm;
25SmallVector<const BasicBlock *, 8> findBBwithCalls(const Function &F,
26 bool IndirectCall = false) {
27 SmallVector<const BasicBlock *, 8> BBs;
28
29 auto findCallInst = [&IndirectCall](const Instruction &I) {
30 if (auto Call = dyn_cast<CallBase>(Val: &I))
31 return Call->isIndirectCall() ? IndirectCall : true;
32 else
33 return false;
34 };
35 for (auto &BB : F)
36 if (findCallInst(*BB.getTerminator()) ||
37 llvm::any_of(Range: BB.instructionsWithoutDebug(), P: findCallInst))
38 BBs.emplace_back(Args: &BB);
39
40 return BBs;
41}
42} // namespace
43
// Implementations of Queries shouldn't need to lock resources such as the
// LLVMContext: each argument (function) has a non-shared LLVMContext.
// If Queries ever carry state, a suitable locking scheme must be provided.
47namespace llvm {
48namespace orc {
49
50// Collect direct calls only
51void SpeculateQuery::findCalles(const BasicBlock *BB,
52 DenseSet<StringRef> &CallesNames) {
53 assert(BB != nullptr && "Traversing Null BB to find calls?");
54
55 auto getCalledFunction = [&CallesNames](const CallBase *Call) {
56 auto CalledValue = Call->getCalledOperand()->stripPointerCasts();
57 if (auto DirectCall = dyn_cast<Function>(Val: CalledValue))
58 CallesNames.insert(V: DirectCall->getName());
59 };
60 for (auto &I : BB->instructionsWithoutDebug())
61 if (auto CI = dyn_cast<CallInst>(Val: &I))
62 getCalledFunction(CI);
63
64 if (auto II = dyn_cast<InvokeInst>(Val: BB->getTerminator()))
65 getCalledFunction(II);
66}
67
68bool SpeculateQuery::isStraightLine(const Function &F) {
69 return llvm::all_of(Range: F, P: [](const BasicBlock &BB) {
70 return BB.getSingleSuccessor() != nullptr;
71 });
72}
73
74// BlockFreqQuery Implementations
75
76size_t BlockFreqQuery::numBBToGet(size_t numBB) {
77 // small CFG
78 if (numBB < 4)
79 return numBB;
80 // mid-size CFG
81 else if (numBB < 20)
82 return (numBB / 2);
83 else
84 return (numBB / 2) + (numBB / 4);
85}
86
87BlockFreqQuery::ResultTy BlockFreqQuery::operator()(Function &F) {
88 DenseMap<StringRef, DenseSet<StringRef>> CallerAndCalles;
89 DenseSet<StringRef> Calles;
90 SmallVector<std::pair<const BasicBlock *, uint64_t>, 8> BBFreqs;
91
92 PassBuilder PB;
93 FunctionAnalysisManager FAM;
94 PB.registerFunctionAnalyses(FAM);
95
96 auto IBBs = findBBwithCalls(F);
97
98 if (IBBs.empty())
99 return std::nullopt;
100
101 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(IR&: F);
102
103 for (const auto I : IBBs)
104 BBFreqs.push_back(Elt: {I, BFI.getBlockFreq(BB: I).getFrequency()});
105
106 assert(IBBs.size() == BBFreqs.size() && "BB Count Mismatch");
107
108 llvm::sort(C&: BBFreqs, Comp: [](decltype(BBFreqs)::const_reference BBF,
109 decltype(BBFreqs)::const_reference BBS) {
110 return BBF.second > BBS.second ? true : false;
111 });
112
113 // ignoring number of direct calls in a BB
114 auto Topk = numBBToGet(numBB: BBFreqs.size());
115
116 for (size_t i = 0; i < Topk; i++)
117 findCalles(BB: BBFreqs[i].first, CallesNames&: Calles);
118
119 assert(!Calles.empty() && "Running Analysis on Function with no calls?");
120
121 CallerAndCalles.insert(KV: {F.getName(), std::move(Calles)});
122
123 return CallerAndCalles;
124}
125
126// SequenceBBQuery Implementation
127std::size_t SequenceBBQuery::getHottestBlocks(std::size_t TotalBlocks) {
128 if (TotalBlocks == 1)
129 return TotalBlocks;
130 return TotalBlocks / 2;
131}
132
133// FIXME : find good implementation.
134SequenceBBQuery::BlockListTy
135SequenceBBQuery::rearrangeBB(const Function &F, const BlockListTy &BBList) {
136 BlockListTy RearrangedBBSet;
137
138 for (auto &Block : F)
139 if (llvm::is_contained(Range: BBList, Element: &Block))
140 RearrangedBBSet.push_back(Elt: &Block);
141
142 assert(RearrangedBBSet.size() == BBList.size() &&
143 "BasicBlock missing while rearranging?");
144 return RearrangedBBSet;
145}
146
147void SequenceBBQuery::traverseToEntryBlock(const BasicBlock *AtBB,
148 const BlockListTy &CallerBlocks,
149 const BackEdgesInfoTy &BackEdgesInfo,
150 const BranchProbabilityInfo *BPI,
151 VisitedBlocksInfoTy &VisitedBlocks) {
152 auto Itr = VisitedBlocks.find(Val: AtBB);
153 if (Itr != VisitedBlocks.end()) { // already visited.
154 if (!Itr->second.Upward)
155 return;
156 Itr->second.Upward = false;
157 } else {
158 // Create hint for newly discoverd blocks.
159 WalkDirection BlockHint;
160 BlockHint.Upward = false;
161 // FIXME: Expensive Check
162 if (llvm::is_contained(Range: CallerBlocks, Element: AtBB))
163 BlockHint.CallerBlock = true;
164 VisitedBlocks.insert(KV: std::make_pair(x&: AtBB, y&: BlockHint));
165 }
166
167 const_pred_iterator PIt = pred_begin(BB: AtBB), EIt = pred_end(BB: AtBB);
168 // Move this check to top, when we have code setup to launch speculative
169 // compiles for function in entry BB, this triggers the speculative compiles
170 // before running the program.
171 if (PIt == EIt) // No Preds.
172 return;
173
174 DenseSet<const BasicBlock *> PredSkipNodes;
175
176 // Since we are checking for predecessor's backedges, this Block
177 // occurs in second position.
178 for (auto &I : BackEdgesInfo)
179 if (I.second == AtBB)
180 PredSkipNodes.insert(V: I.first);
181
182 // Skip predecessors which source of back-edges.
183 for (; PIt != EIt; ++PIt)
184 // checking EdgeHotness is cheaper
185 if (BPI->isEdgeHot(Src: *PIt, Dst: AtBB) && !PredSkipNodes.count(V: *PIt))
186 traverseToEntryBlock(AtBB: *PIt, CallerBlocks, BackEdgesInfo, BPI,
187 VisitedBlocks);
188}
189
190void SequenceBBQuery::traverseToExitBlock(const BasicBlock *AtBB,
191 const BlockListTy &CallerBlocks,
192 const BackEdgesInfoTy &BackEdgesInfo,
193 const BranchProbabilityInfo *BPI,
194 VisitedBlocksInfoTy &VisitedBlocks) {
195 auto Itr = VisitedBlocks.find(Val: AtBB);
196 if (Itr != VisitedBlocks.end()) { // already visited.
197 if (!Itr->second.Downward)
198 return;
199 Itr->second.Downward = false;
200 } else {
201 // Create hint for newly discoverd blocks.
202 WalkDirection BlockHint;
203 BlockHint.Downward = false;
204 // FIXME: Expensive Check
205 if (llvm::is_contained(Range: CallerBlocks, Element: AtBB))
206 BlockHint.CallerBlock = true;
207 VisitedBlocks.insert(KV: std::make_pair(x&: AtBB, y&: BlockHint));
208 }
209
210 const_succ_iterator PIt = succ_begin(BB: AtBB), EIt = succ_end(BB: AtBB);
211 if (PIt == EIt) // No succs.
212 return;
213
214 // If there are hot edges, then compute SuccSkipNodes.
215 DenseSet<const BasicBlock *> SuccSkipNodes;
216
217 // Since we are checking for successor's backedges, this Block
218 // occurs in first position.
219 for (auto &I : BackEdgesInfo)
220 if (I.first == AtBB)
221 SuccSkipNodes.insert(V: I.second);
222
223 for (; PIt != EIt; ++PIt)
224 if (BPI->isEdgeHot(Src: AtBB, Dst: *PIt) && !SuccSkipNodes.count(V: *PIt))
225 traverseToExitBlock(AtBB: *PIt, CallerBlocks, BackEdgesInfo, BPI,
226 VisitedBlocks);
227}
228
229// Get Block frequencies for blocks and take most frequently executed block,
230// walk towards the entry block from those blocks and discover the basic blocks
231// with call.
232SequenceBBQuery::BlockListTy
233SequenceBBQuery::queryCFG(Function &F, const BlockListTy &CallerBlocks) {
234
235 BlockFreqInfoTy BBFreqs;
236 VisitedBlocksInfoTy VisitedBlocks;
237 BackEdgesInfoTy BackEdgesInfo;
238
239 PassBuilder PB;
240 FunctionAnalysisManager FAM;
241 PB.registerFunctionAnalyses(FAM);
242
243 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(IR&: F);
244
245 llvm::FindFunctionBackedges(F, Result&: BackEdgesInfo);
246
247 for (const auto I : CallerBlocks)
248 BBFreqs.push_back(Elt: {I, BFI.getBlockFreq(BB: I).getFrequency()});
249
250 llvm::sort(C&: BBFreqs, Comp: [](decltype(BBFreqs)::const_reference Bbf,
251 decltype(BBFreqs)::const_reference Bbs) {
252 return Bbf.second > Bbs.second;
253 });
254
255 ArrayRef<std::pair<const BasicBlock *, uint64_t>> HotBlocksRef(BBFreqs);
256 HotBlocksRef =
257 HotBlocksRef.drop_back(N: BBFreqs.size() - getHottestBlocks(TotalBlocks: BBFreqs.size()));
258
259 BranchProbabilityInfo *BPI =
260 FAM.getCachedResult<BranchProbabilityAnalysis>(IR&: F);
261
262 // visit NHotBlocks,
263 // traverse upwards to entry
264 // traverse downwards to end.
265
266 for (auto I : HotBlocksRef) {
267 traverseToEntryBlock(AtBB: I.first, CallerBlocks, BackEdgesInfo, BPI,
268 VisitedBlocks);
269 traverseToExitBlock(AtBB: I.first, CallerBlocks, BackEdgesInfo, BPI,
270 VisitedBlocks);
271 }
272
273 BlockListTy MinCallerBlocks;
274 for (auto &I : VisitedBlocks)
275 if (I.second.CallerBlock)
276 MinCallerBlocks.push_back(Elt: std::move(I.first));
277
278 return rearrangeBB(F, BBList: MinCallerBlocks);
279}
280
281SpeculateQuery::ResultTy SequenceBBQuery::operator()(Function &F) {
282 // reduce the number of lists!
283 DenseMap<StringRef, DenseSet<StringRef>> CallerAndCalles;
284 DenseSet<StringRef> Calles;
285 BlockListTy SequencedBlocks;
286 BlockListTy CallerBlocks;
287
288 CallerBlocks = findBBwithCalls(F);
289 if (CallerBlocks.empty())
290 return std::nullopt;
291
292 if (isStraightLine(F))
293 SequencedBlocks = rearrangeBB(F, BBList: CallerBlocks);
294 else
295 SequencedBlocks = queryCFG(F, CallerBlocks);
296
297 for (const auto *BB : SequencedBlocks)
298 findCalles(BB, CallesNames&: Calles);
299
300 CallerAndCalles.insert(KV: {F.getName(), std::move(Calles)});
301 return CallerAndCalles;
302}
303
304} // namespace orc
305} // namespace llvm
306

source code of llvm/lib/ExecutionEngine/Orc/SpeculateAnalyses.cpp