1 | //===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H |
10 | #define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H |
11 | |
12 | #include "llvm/ADT/SmallVector.h" |
13 | #include "llvm/CodeGen/ISDOpcodes.h" |
14 | #include "llvm/CodeGen/SelectionDAGNodes.h" |
15 | #include "llvm/IR/InstrTypes.h" |
16 | #include "llvm/Support/BranchProbability.h" |
17 | #include <vector> |
18 | |
19 | namespace llvm { |
20 | |
21 | class BlockFrequencyInfo; |
22 | class ConstantInt; |
23 | class FunctionLoweringInfo; |
24 | class MachineBasicBlock; |
25 | class ProfileSummaryInfo; |
26 | class TargetLowering; |
27 | class TargetMachine; |
28 | |
29 | namespace SwitchCG { |
30 | |
/// Tag stored in CaseCluster::Kind that says which form of cluster a
/// CaseCluster object represents.
enum CaseClusterKind {
  /// A cluster of adjacent case labels with the same destination, or just one
  /// case.
  CC_Range,
  /// A cluster of cases suitable for jump table lowering.
  CC_JumpTable,
  /// A cluster of cases suitable for bit test lowering.
  CC_BitTests
};
40 | |
41 | /// A cluster of case labels. |
42 | struct CaseCluster { |
43 | CaseClusterKind Kind; |
44 | const ConstantInt *Low, *High; |
45 | union { |
46 | MachineBasicBlock *MBB; |
47 | unsigned JTCasesIndex; |
48 | unsigned BTCasesIndex; |
49 | }; |
50 | BranchProbability Prob; |
51 | |
52 | static CaseCluster range(const ConstantInt *Low, const ConstantInt *High, |
53 | MachineBasicBlock *MBB, BranchProbability Prob) { |
54 | CaseCluster C; |
55 | C.Kind = CC_Range; |
56 | C.Low = Low; |
57 | C.High = High; |
58 | C.MBB = MBB; |
59 | C.Prob = Prob; |
60 | return C; |
61 | } |
62 | |
63 | static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High, |
64 | unsigned JTCasesIndex, BranchProbability Prob) { |
65 | CaseCluster C; |
66 | C.Kind = CC_JumpTable; |
67 | C.Low = Low; |
68 | C.High = High; |
69 | C.JTCasesIndex = JTCasesIndex; |
70 | C.Prob = Prob; |
71 | return C; |
72 | } |
73 | |
74 | static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High, |
75 | unsigned BTCasesIndex, BranchProbability Prob) { |
76 | CaseCluster C; |
77 | C.Kind = CC_BitTests; |
78 | C.Low = Low; |
79 | C.High = High; |
80 | C.BTCasesIndex = BTCasesIndex; |
81 | C.Prob = Prob; |
82 | return C; |
83 | } |
84 | }; |
85 | |
/// Container of case clusters, and the iterator type used to address a
/// position or sub-range inside it.
using CaseClusterVector = std::vector<CaseCluster>;
using CaseClusterIt = CaseClusterVector::iterator;
88 | |
/// Sort Clusters and merge adjacent cases.
/// Clusters is taken by non-const reference, so the result is written back
/// into the caller's vector.
void sortAndRangeify(CaseClusterVector &Clusters);
91 | |
92 | struct CaseBits { |
93 | uint64_t Mask = 0; |
94 | MachineBasicBlock *BB = nullptr; |
95 | unsigned Bits = 0; |
96 | BranchProbability ; |
97 | |
98 | CaseBits() = default; |
99 | CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits, |
100 | BranchProbability Prob) |
101 | : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {} |
102 | }; |
103 | |
/// Container of CaseBits records.
using CaseBitsVector = std::vector<CaseBits>;
105 | |
106 | /// This structure is used to communicate between SelectionDAGBuilder and |
107 | /// SDISel for the code generation of additional basic blocks needed by |
108 | /// multi-case switch statements. |
109 | struct CaseBlock { |
110 | // For the GISel interface. |
111 | struct PredInfoPair { |
112 | CmpInst::Predicate Pred; |
113 | // Set when no comparison should be emitted. |
114 | bool NoCmp; |
115 | }; |
116 | union { |
117 | // The condition code to use for the case block's setcc node. |
118 | // Besides the integer condition codes, this can also be SETTRUE, in which |
119 | // case no comparison gets emitted. |
120 | ISD::CondCode CC; |
121 | struct PredInfoPair PredInfo; |
122 | }; |
123 | |
124 | // The LHS/MHS/RHS of the comparison to emit. |
125 | // Emit by default LHS op RHS. MHS is used for range comparisons: |
126 | // If MHS is not null: (LHS <= MHS) and (MHS <= RHS). |
127 | const Value *CmpLHS, *CmpMHS, *CmpRHS; |
128 | |
129 | // The block to branch to if the setcc is true/false. |
130 | MachineBasicBlock *TrueBB, *FalseBB; |
131 | |
132 | // The block into which to emit the code for the setcc and branches. |
133 | MachineBasicBlock *ThisBB; |
134 | |
135 | /// The debug location of the instruction this CaseBlock was |
136 | /// produced from. |
137 | SDLoc DL; |
138 | DebugLoc DbgLoc; |
139 | |
140 | // Branch weights. |
141 | BranchProbability TrueProb, FalseProb; |
142 | |
143 | // Constructor for SelectionDAG. |
144 | CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs, |
145 | const Value *cmpmiddle, MachineBasicBlock *truebb, |
146 | MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl, |
147 | BranchProbability trueprob = BranchProbability::getUnknown(), |
148 | BranchProbability falseprob = BranchProbability::getUnknown()) |
149 | : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs), |
150 | TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl), |
151 | TrueProb(trueprob), FalseProb(falseprob) {} |
152 | |
153 | // Constructor for GISel. |
154 | CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs, |
155 | const Value *cmprhs, const Value *cmpmiddle, |
156 | MachineBasicBlock *truebb, MachineBasicBlock *falsebb, |
157 | MachineBasicBlock *me, DebugLoc dl, |
158 | BranchProbability trueprob = BranchProbability::getUnknown(), |
159 | BranchProbability falseprob = BranchProbability::getUnknown()) |
160 | : PredInfo({.Pred: pred, .NoCmp: nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle), |
161 | CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me), |
162 | DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {} |
163 | }; |
164 | |
165 | struct JumpTable { |
166 | /// The virtual register containing the index of the jump table entry |
167 | /// to jump to. |
168 | unsigned Reg; |
169 | /// The JumpTableIndex for this jump table in the function. |
170 | unsigned JTI; |
171 | /// The MBB into which to emit the code for the indirect jump. |
172 | MachineBasicBlock *MBB; |
173 | /// The MBB of the default bb, which is a successor of the range |
174 | /// check MBB. This is when updating PHI nodes in successors. |
175 | MachineBasicBlock *Default; |
176 | |
177 | /// The debug location of the instruction this JumpTable was produced from. |
178 | std::optional<SDLoc> SL; // For SelectionDAG |
179 | |
180 | JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D, |
181 | std::optional<SDLoc> SL) |
182 | : Reg(R), JTI(J), MBB(M), Default(D), SL(SL) {} |
183 | }; |
184 | struct { |
185 | APInt ; |
186 | APInt ; |
187 | const Value *; |
188 | MachineBasicBlock *; |
189 | bool ; |
190 | bool = false; |
191 | |
192 | (APInt F, APInt L, const Value *SV, MachineBasicBlock *H, |
193 | bool E = false) |
194 | : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H), |
195 | Emitted(E) {} |
196 | }; |
/// A jump table together with its header: `first` is the JumpTableHeader and
/// `second` is the JumpTable.
using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
198 | |
199 | struct BitTestCase { |
200 | uint64_t Mask; |
201 | MachineBasicBlock *ThisBB; |
202 | MachineBasicBlock *TargetBB; |
203 | BranchProbability ; |
204 | |
205 | BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr, |
206 | BranchProbability Prob) |
207 | : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {} |
208 | }; |
209 | |
/// List of BitTestCase entries, with inline storage for three elements.
using BitTestInfo = SmallVector<BitTestCase, 3>;
211 | |
212 | struct BitTestBlock { |
213 | APInt First; |
214 | APInt Range; |
215 | const Value *SValue; |
216 | unsigned Reg; |
217 | MVT RegVT; |
218 | bool Emitted; |
219 | bool ContiguousRange; |
220 | MachineBasicBlock *Parent; |
221 | MachineBasicBlock *Default; |
222 | BitTestInfo Cases; |
223 | BranchProbability Prob; |
224 | BranchProbability DefaultProb; |
225 | bool FallthroughUnreachable = false; |
226 | |
227 | BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E, |
228 | bool CR, MachineBasicBlock *P, MachineBasicBlock *D, |
229 | BitTestInfo C, BranchProbability Pr) |
230 | : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg), |
231 | RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D), |
232 | Cases(std::move(C)), Prob(Pr) {} |
233 | }; |
234 | |
/// Return the range of values within a range.
/// First and Last are indices into Clusters.
/// NOTE(review): presumably both indices are inclusive and the result is the
/// span of case values from Clusters[First] to Clusters[Last] — confirm
/// against the definition.
uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
                           unsigned Last);
238 | |
/// Return the number of cases within a range.
/// First and Last are indices into TotalCases.
/// NOTE(review): presumably TotalCases holds running totals of case counts
/// and both indices are inclusive — confirm against the definition.
uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
                              unsigned First, unsigned Last);
242 | |
/// One pending unit of switch-lowering work: a block plus the sub-range of
/// clusters, delimited by FirstCluster and LastCluster, that is associated
/// with it.
struct SwitchWorkListItem {
  /// Block this work item belongs to.
  MachineBasicBlock *MBB = nullptr;
  /// Iterators delimiting the clusters of this item.
  /// NOTE(review): presumably LastCluster is inclusive — confirm against the
  /// code that walks the work list.
  CaseClusterIt FirstCluster;
  CaseClusterIt LastCluster;
  /// NOTE(review): presumably bounds already known for the switch value
  /// (value >= GE, value < LT), with null meaning "no known bound" — confirm.
  const ConstantInt *GE = nullptr;
  const ConstantInt *LT = nullptr;
  /// Branch probability of the default destination for this item.
  BranchProbability DefaultProb;
};
/// Work list of SwitchWorkListItem, with inline storage for four elements.
using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
252 | |
/// Shared state and helper routines for lowering switch instructions.
/// The class is abstract: addSuccessorWithProb is pure virtual and has to be
/// provided by a subclass.
class SwitchLowering {
public:
  /// Bind this object to the per-function lowering state. Only a reference is
  /// kept, so funcinfo must outlive this object.
  SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {}

  /// Record the target lowering, target machine and data layout to use.
  /// Only the addresses of the arguments are stored, so the referenced
  /// objects must stay alive for as long as they are used through this
  /// object.
  void init(const TargetLowering &tli, const TargetMachine &tm,
            const DataLayout &dl) {
    TLI = &tli;
    TM = &tm;
    DL = &dl;
  }

  /// Vector of CaseBlock structures used to communicate SwitchInst code
  /// generation information.
  std::vector<CaseBlock> SwitchCases;

  /// Vector of JumpTable structures used to communicate SwitchInst code
  /// generation information.
  std::vector<JumpTableBlock> JTCases;

  /// Vector of BitTestBlock structures used to communicate SwitchInst code
  /// generation information.
  std::vector<BitTestBlock> BitTestCases;

  /// NOTE(review): presumably scans Clusters for runs suitable for jump table
  /// lowering and rewrites Clusters in place — confirm against the
  /// definition.
  void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
                      std::optional<SDLoc> SL, MachineBasicBlock *DefaultMBB,
                      ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI);

  /// NOTE(review): presumably builds a jump table cluster from
  /// Clusters[First..Last] into JTCluster and returns false when it declines,
  /// by analogy with buildBitTests — confirm against the definition.
  bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
                      unsigned Last, const SwitchInst *SI,
                      const std::optional<SDLoc> &SL,
                      MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);

  /// NOTE(review): presumably the bit-test counterpart of findJumpTables —
  /// confirm against the definition.
  void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);

  /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
  /// decides it's not a good idea.
  bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
                     const SwitchInst *SI, CaseCluster &BTCluster);

  /// Hook that subclasses must implement. Judging by its name and
  /// parameters, it adds Dst as a successor of Src with probability Prob;
  /// Prob defaults to the unknown probability.
  virtual void addSuccessorWithProb(
      MachineBasicBlock *Src, MachineBasicBlock *Dst,
      BranchProbability Prob = BranchProbability::getUnknown()) = 0;

  /// Determine the rank by weight of CC in [First,Last]. If CC has more weight
  /// than each cluster in the range, its rank is 0.
  unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First,
                           CaseClusterIt Last);

  /// Result of computeSplitWorkItemInfo: the last cluster of the left part,
  /// the first cluster of the right part, and the probability of each part.
  struct SplitWorkItemInfo {
    CaseClusterIt LastLeft;
    CaseClusterIt FirstRight;
    BranchProbability LeftProb;
    BranchProbability RightProb;
  };
  /// Compute information to balance the tree based on branch probabilities to
  /// create a near-optimal (in terms of search time given key frequency) binary
  /// search tree. See e.g. Kurt Mehlhorn "Nearly Optimal Binary Search Trees"
  /// (1975).
  SplitWorkItemInfo computeSplitWorkItemInfo(const SwitchWorkListItem &W);
  /// Virtual so that subclasses can be destroyed through a SwitchLowering
  /// pointer.
  virtual ~SwitchLowering() = default;

private:
  // Set by init(); null until then.
  const TargetLowering *TLI = nullptr;
  const TargetMachine *TM = nullptr;
  const DataLayout *DL = nullptr;
  // Bound at construction time.
  FunctionLoweringInfo &FuncInfo;
};
320 | |
321 | } // namespace SwitchCG |
322 | } // namespace llvm |
323 | |
324 | #endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H |
325 | |