1//===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
10#define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
11
12#include "llvm/ADT/SmallVector.h"
13#include "llvm/CodeGen/ISDOpcodes.h"
14#include "llvm/CodeGen/SelectionDAGNodes.h"
15#include "llvm/IR/InstrTypes.h"
16#include "llvm/Support/BranchProbability.h"
17#include <vector>
18
19namespace llvm {
20
21class BlockFrequencyInfo;
22class ConstantInt;
23class FunctionLoweringInfo;
24class MachineBasicBlock;
25class ProfileSummaryInfo;
26class TargetLowering;
27class TargetMachine;
28
29namespace SwitchCG {
30
/// Classifies how a CaseCluster is going to be lowered.
enum CaseClusterKind {
  /// A run of adjacent case labels that share one destination (or a single
  /// case on its own).
  CC_Range = 0,
  /// A group of cases to be lowered through a jump table.
  CC_JumpTable = 1,
  /// A group of cases to be lowered through bit tests.
  CC_BitTests = 2
};
40
41/// A cluster of case labels.
42struct CaseCluster {
43 CaseClusterKind Kind;
44 const ConstantInt *Low, *High;
45 union {
46 MachineBasicBlock *MBB;
47 unsigned JTCasesIndex;
48 unsigned BTCasesIndex;
49 };
50 BranchProbability Prob;
51
52 static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
53 MachineBasicBlock *MBB, BranchProbability Prob) {
54 CaseCluster C;
55 C.Kind = CC_Range;
56 C.Low = Low;
57 C.High = High;
58 C.MBB = MBB;
59 C.Prob = Prob;
60 return C;
61 }
62
63 static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High,
64 unsigned JTCasesIndex, BranchProbability Prob) {
65 CaseCluster C;
66 C.Kind = CC_JumpTable;
67 C.Low = Low;
68 C.High = High;
69 C.JTCasesIndex = JTCasesIndex;
70 C.Prob = Prob;
71 return C;
72 }
73
74 static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
75 unsigned BTCasesIndex, BranchProbability Prob) {
76 CaseCluster C;
77 C.Kind = CC_BitTests;
78 C.Low = Low;
79 C.High = High;
80 C.BTCasesIndex = BTCasesIndex;
81 C.Prob = Prob;
82 return C;
83 }
84};
85
86using CaseClusterVector = std::vector<CaseCluster>;
87using CaseClusterIt = CaseClusterVector::iterator;
88
89/// Sort Clusters and merge adjacent cases.
90void sortAndRangeify(CaseClusterVector &Clusters);
91
92struct CaseBits {
93 uint64_t Mask = 0;
94 MachineBasicBlock *BB = nullptr;
95 unsigned Bits = 0;
96 BranchProbability ExtraProb;
97
98 CaseBits() = default;
99 CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits,
100 BranchProbability Prob)
101 : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
102};
103
104using CaseBitsVector = std::vector<CaseBits>;
105
106/// This structure is used to communicate between SelectionDAGBuilder and
107/// SDISel for the code generation of additional basic blocks needed by
108/// multi-case switch statements.
109struct CaseBlock {
110 // For the GISel interface.
111 struct PredInfoPair {
112 CmpInst::Predicate Pred;
113 // Set when no comparison should be emitted.
114 bool NoCmp;
115 };
116 union {
117 // The condition code to use for the case block's setcc node.
118 // Besides the integer condition codes, this can also be SETTRUE, in which
119 // case no comparison gets emitted.
120 ISD::CondCode CC;
121 struct PredInfoPair PredInfo;
122 };
123
124 // The LHS/MHS/RHS of the comparison to emit.
125 // Emit by default LHS op RHS. MHS is used for range comparisons:
126 // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
127 const Value *CmpLHS, *CmpMHS, *CmpRHS;
128
129 // The block to branch to if the setcc is true/false.
130 MachineBasicBlock *TrueBB, *FalseBB;
131
132 // The block into which to emit the code for the setcc and branches.
133 MachineBasicBlock *ThisBB;
134
135 /// The debug location of the instruction this CaseBlock was
136 /// produced from.
137 SDLoc DL;
138 DebugLoc DbgLoc;
139
140 // Branch weights.
141 BranchProbability TrueProb, FalseProb;
142
143 // Constructor for SelectionDAG.
144 CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
145 const Value *cmpmiddle, MachineBasicBlock *truebb,
146 MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
147 BranchProbability trueprob = BranchProbability::getUnknown(),
148 BranchProbability falseprob = BranchProbability::getUnknown())
149 : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
150 TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
151 TrueProb(trueprob), FalseProb(falseprob) {}
152
153 // Constructor for GISel.
154 CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
155 const Value *cmprhs, const Value *cmpmiddle,
156 MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
157 MachineBasicBlock *me, DebugLoc dl,
158 BranchProbability trueprob = BranchProbability::getUnknown(),
159 BranchProbability falseprob = BranchProbability::getUnknown())
160 : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle),
161 CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
162 DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {}
163};
164
165struct JumpTable {
166 /// The virtual register containing the index of the jump table entry
167 /// to jump to.
168 unsigned Reg;
169 /// The JumpTableIndex for this jump table in the function.
170 unsigned JTI;
171 /// The MBB into which to emit the code for the indirect jump.
172 MachineBasicBlock *MBB;
173 /// The MBB of the default bb, which is a successor of the range
174 /// check MBB. This is when updating PHI nodes in successors.
175 MachineBasicBlock *Default;
176
177 JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D)
178 : Reg(R), JTI(J), MBB(M), Default(D) {}
179};
180struct JumpTableHeader {
181 APInt First;
182 APInt Last;
183 const Value *SValue;
184 MachineBasicBlock *HeaderBB;
185 bool Emitted;
186 bool OmitRangeCheck;
187
188 JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
189 bool E = false)
190 : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
191 Emitted(E), OmitRangeCheck(false) {}
192};
193using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
194
195struct BitTestCase {
196 uint64_t Mask;
197 MachineBasicBlock *ThisBB;
198 MachineBasicBlock *TargetBB;
199 BranchProbability ExtraProb;
200
201 BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr,
202 BranchProbability Prob)
203 : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
204};
205
206using BitTestInfo = SmallVector<BitTestCase, 3>;
207
208struct BitTestBlock {
209 APInt First;
210 APInt Range;
211 const Value *SValue;
212 unsigned Reg;
213 MVT RegVT;
214 bool Emitted;
215 bool ContiguousRange;
216 MachineBasicBlock *Parent;
217 MachineBasicBlock *Default;
218 BitTestInfo Cases;
219 BranchProbability Prob;
220 BranchProbability DefaultProb;
221 bool OmitRangeCheck;
222
223 BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E,
224 bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
225 BitTestInfo C, BranchProbability Pr)
226 : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
227 RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
228 Cases(std::move(C)), Prob(Pr), OmitRangeCheck(false) {}
229};
230
231/// Return the range of values within a range.
232uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
233 unsigned Last);
234
235/// Return the number of cases within a range.
236uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
237 unsigned First, unsigned Last);
238
239struct SwitchWorkListItem {
240 MachineBasicBlock *MBB;
241 CaseClusterIt FirstCluster;
242 CaseClusterIt LastCluster;
243 const ConstantInt *GE;
244 const ConstantInt *LT;
245 BranchProbability DefaultProb;
246};
247using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
248
249class SwitchLowering {
250public:
251 SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {}
252
253 void init(const TargetLowering &tli, const TargetMachine &tm,
254 const DataLayout &dl) {
255 TLI = &tli;
256 TM = &tm;
257 DL = &dl;
258 }
259
260 /// Vector of CaseBlock structures used to communicate SwitchInst code
261 /// generation information.
262 std::vector<CaseBlock> SwitchCases;
263
264 /// Vector of JumpTable structures used to communicate SwitchInst code
265 /// generation information.
266 std::vector<JumpTableBlock> JTCases;
267
268 /// Vector of BitTestBlock structures used to communicate SwitchInst code
269 /// generation information.
270 std::vector<BitTestBlock> BitTestCases;
271
272 void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
273 MachineBasicBlock *DefaultMBB,
274 ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI);
275
276 bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
277 unsigned Last, const SwitchInst *SI,
278 MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
279
280
281 void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
282
283 /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
284 /// decides it's not a good idea.
285 bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
286 const SwitchInst *SI, CaseCluster &BTCluster);
287
288 virtual void addSuccessorWithProb(
289 MachineBasicBlock *Src, MachineBasicBlock *Dst,
290 BranchProbability Prob = BranchProbability::getUnknown()) = 0;
291
292 virtual ~SwitchLowering() = default;
293
294private:
295 const TargetLowering *TLI;
296 const TargetMachine *TM;
297 const DataLayout *DL;
298 FunctionLoweringInfo &FuncInfo;
299};
300
301} // namespace SwitchCG
302} // namespace llvm
303
304#endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
305