1 | //===- RDFRegisters.h -------------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef LLVM_CODEGEN_RDFREGISTERS_H |
10 | #define LLVM_CODEGEN_RDFREGISTERS_H |
11 | |
12 | #include "llvm/ADT/BitVector.h" |
13 | #include "llvm/ADT/STLExtras.h" |
14 | #include "llvm/ADT/iterator_range.h" |
15 | #include "llvm/CodeGen/TargetRegisterInfo.h" |
16 | #include "llvm/MC/LaneBitmask.h" |
17 | #include "llvm/MC/MCRegister.h" |
18 | #include <cassert> |
19 | #include <cstdint> |
20 | #include <map> |
21 | #include <set> |
22 | #include <vector> |
23 | |
24 | namespace llvm { |
25 | |
26 | class MachineFunction; |
27 | class raw_ostream; |
28 | |
29 | namespace rdf { |
30 | struct RegisterAggr; |
31 | |
// Numeric identifier shared by every kind of entity a RegisterRef can name
// (registers, register units and register masks); see RegisterRef::isRegId,
// isUnitId and isMaskId for how the kinds are told apart.
using RegisterId = std::uint32_t;
33 | |
// Returns true if the ordered sets A and B have no element in common.
//
// Both sets are walked in lock-step, so the check needs O(|A| + |B|)
// comparisons. The walk orders elements with the sets' own comparator (taken
// from A) rather than a raw `<`. That ordering is the one the elements are
// actually sorted by, so the merge stays correct for custom or stateful
// comparators. It also makes the function usable with sets such as
// std::set<RegisterRef, std::less<RegisterRef>>, whose element type has no
// usable operator<.
// Compare is deduced, so existing calls like disjoint(A, B) and
// disjoint<T>(A, B) on plain std::set<T> keep working.
template <typename T, typename Compare>
bool disjoint(const std::set<T, Compare> &A, const std::set<T, Compare> &B) {
  Compare Less = A.key_comp();
  auto ItA = A.begin(), EndA = A.end();
  auto ItB = B.begin(), EndB = B.end();
  while (ItA != EndA && ItB != EndB) {
    if (Less(*ItA, *ItB))
      ++ItA;
    else if (Less(*ItB, *ItA))
      ++ItB;
    else
      return false; // Neither precedes the other: a common element.
  }
  return true;
}
48 | |
// Bidirectional mapping between values of type T and uint32_t indices that
// behaves like an indexed set. Inserting a value not seen before appends it
// and hands out the next index; inserting a value already present returns the
// index it was given before. Indices are 1-based: 0 means "invalid" and is
// never handed out. Lookups scan linearly, so this suits small collections.
// N is only a capacity hint reserved up front.
template <typename T, unsigned N = 32> struct IndexedSet {
  IndexedSet() { Map.reserve(N); }

  // Value stored under index Idx (Idx must be a previously returned index).
  T get(uint32_t Idx) const {
    // Index Idx lives in slot Idx-1 of the vector.
    assert(Idx != 0 && !Map.empty() && Idx - 1 < Map.size());
    return Map[Idx - 1];
  }

  // Index of Val, appending Val first if it is not stored yet.
  uint32_t insert(T Val) {
    uint32_t Slot = slotOf(Val);
    if (Slot == Map.size())
      Map.push_back(Val); // New value: it occupies the slot just past the end.
    return Slot + 1;
  }

  // Index of Val, which is required to be present.
  uint32_t find(T Val) const {
    uint32_t Slot = slotOf(Val);
    assert(Slot != Map.size());
    return Slot + 1;
  }

  uint32_t size() const { return Map.size(); }

  using const_iterator = typename std::vector<T>::const_iterator;

  const_iterator begin() const { return Map.begin(); }
  const_iterator end() const { return Map.end(); }

private:
  // Zero-based slot holding Val, or Map.size() when Val is not stored.
  uint32_t slotOf(const T &Val) const {
    uint32_t Slot = 0;
    for (uint32_t Count = Map.size(); Slot != Count; ++Slot)
      if (Map[Slot] == Val)
        break;
    return Slot;
  }

  std::vector<T> Map;
};
87 | |
88 | struct RegisterRef { |
89 | RegisterId Reg = 0; |
90 | LaneBitmask Mask = LaneBitmask::getNone(); // Only for registers. |
91 | |
92 | constexpr RegisterRef() = default; |
93 | constexpr explicit RegisterRef(RegisterId R, |
94 | LaneBitmask M = LaneBitmask::getAll()) |
95 | : Reg(R), Mask(isRegId(Id: R) && R != 0 ? M : LaneBitmask::getNone()) {} |
96 | |
97 | // Classify null register as a "register". |
98 | constexpr bool isReg() const { return Reg == 0 || isRegId(Id: Reg); } |
99 | constexpr bool isUnit() const { return isUnitId(Id: Reg); } |
100 | constexpr bool isMask() const { return isMaskId(Id: Reg); } |
101 | |
102 | constexpr unsigned idx() const { return toIdx(Id: Reg); } |
103 | |
104 | constexpr operator bool() const { |
105 | return !isReg() || (Reg != 0 && Mask.any()); |
106 | } |
107 | |
108 | size_t hash() const { |
109 | return std::hash<RegisterId>{}(Reg) ^ |
110 | std::hash<LaneBitmask::Type>{}(Mask.getAsInteger()); |
111 | } |
112 | |
113 | static constexpr bool isRegId(unsigned Id) { |
114 | return Register::isPhysicalRegister(Reg: Id); |
115 | } |
116 | static constexpr bool isUnitId(unsigned Id) { |
117 | return Register::isVirtualRegister(Reg: Id); |
118 | } |
119 | static constexpr bool isMaskId(unsigned Id) { |
120 | return Register::isStackSlot(Reg: Id); |
121 | } |
122 | |
123 | static constexpr RegisterId toUnitId(unsigned Idx) { |
124 | return Idx | MCRegister::VirtualRegFlag; |
125 | } |
126 | |
127 | static constexpr unsigned toIdx(RegisterId Id) { |
128 | // Not using virtReg2Index or stackSlot2Index, because they are |
129 | // not constexpr. |
130 | if (isUnitId(Id)) |
131 | return Id & ~MCRegister::VirtualRegFlag; |
132 | // RegId and MaskId are unchanged. |
133 | return Id; |
134 | } |
135 | |
136 | bool operator<(RegisterRef) const = delete; |
137 | bool operator==(RegisterRef) const = delete; |
138 | bool operator!=(RegisterRef) const = delete; |
139 | }; |
140 | |
141 | struct PhysicalRegisterInfo { |
142 | PhysicalRegisterInfo(const TargetRegisterInfo &tri, |
143 | const MachineFunction &mf); |
144 | |
145 | RegisterId getRegMaskId(const uint32_t *RM) const { |
146 | return Register::index2StackSlot(FI: RegMasks.find(Val: RM)); |
147 | } |
148 | |
149 | const uint32_t *getRegMaskBits(RegisterId R) const { |
150 | return RegMasks.get(Idx: Register::stackSlot2Index(Reg: R)); |
151 | } |
152 | |
153 | bool alias(RegisterRef RA, RegisterRef RB) const; |
154 | |
155 | // Returns the set of aliased physical registers. |
156 | std::set<RegisterId> getAliasSet(RegisterId Reg) const; |
157 | |
158 | RegisterRef getRefForUnit(uint32_t U) const { |
159 | return RegisterRef(UnitInfos[U].Reg, UnitInfos[U].Mask); |
160 | } |
161 | |
162 | const BitVector &getMaskUnits(RegisterId MaskId) const { |
163 | return MaskInfos[Register::stackSlot2Index(Reg: MaskId)].Units; |
164 | } |
165 | |
166 | std::set<RegisterId> getUnits(RegisterRef RR) const; |
167 | |
168 | const BitVector &getUnitAliases(uint32_t U) const { |
169 | return AliasInfos[U].Regs; |
170 | } |
171 | |
172 | RegisterRef mapTo(RegisterRef RR, unsigned R) const; |
173 | const TargetRegisterInfo &getTRI() const { return TRI; } |
174 | |
175 | bool equal_to(RegisterRef A, RegisterRef B) const; |
176 | bool less(RegisterRef A, RegisterRef B) const; |
177 | |
178 | void print(raw_ostream &OS, RegisterRef A) const; |
179 | void print(raw_ostream &OS, const RegisterAggr &A) const; |
180 | |
181 | private: |
182 | struct RegInfo { |
183 | const TargetRegisterClass *RegClass = nullptr; |
184 | }; |
185 | struct UnitInfo { |
186 | RegisterId Reg = 0; |
187 | LaneBitmask Mask; |
188 | }; |
189 | struct MaskInfo { |
190 | BitVector Units; |
191 | }; |
192 | struct AliasInfo { |
193 | BitVector Regs; |
194 | }; |
195 | |
196 | const TargetRegisterInfo &TRI; |
197 | IndexedSet<const uint32_t *> RegMasks; |
198 | std::vector<RegInfo> RegInfos; |
199 | std::vector<UnitInfo> UnitInfos; |
200 | std::vector<MaskInfo> MaskInfos; |
201 | std::vector<AliasInfo> AliasInfos; |
202 | }; |
203 | |
204 | struct RegisterAggr { |
205 | RegisterAggr(const PhysicalRegisterInfo &pri) |
206 | : Units(pri.getTRI().getNumRegUnits()), PRI(pri) {} |
207 | RegisterAggr(const RegisterAggr &RG) = default; |
208 | |
209 | unsigned size() const { return Units.count(); } |
210 | bool empty() const { return Units.none(); } |
211 | bool hasAliasOf(RegisterRef RR) const; |
212 | bool hasCoverOf(RegisterRef RR) const; |
213 | |
214 | const PhysicalRegisterInfo &getPRI() const { return PRI; } |
215 | |
216 | bool operator==(const RegisterAggr &A) const { |
217 | return DenseMapInfo<BitVector>::isEqual(LHS: Units, RHS: A.Units); |
218 | } |
219 | |
220 | static bool isCoverOf(RegisterRef RA, RegisterRef RB, |
221 | const PhysicalRegisterInfo &PRI) { |
222 | return RegisterAggr(PRI).insert(RR: RA).hasCoverOf(RR: RB); |
223 | } |
224 | |
225 | RegisterAggr &insert(RegisterRef RR); |
226 | RegisterAggr &insert(const RegisterAggr &RG); |
227 | RegisterAggr &intersect(RegisterRef RR); |
228 | RegisterAggr &intersect(const RegisterAggr &RG); |
229 | RegisterAggr &clear(RegisterRef RR); |
230 | RegisterAggr &clear(const RegisterAggr &RG); |
231 | |
232 | RegisterRef intersectWith(RegisterRef RR) const; |
233 | RegisterRef clearIn(RegisterRef RR) const; |
234 | RegisterRef makeRegRef() const; |
235 | |
236 | size_t hash() const { return DenseMapInfo<BitVector>::getHashValue(V: Units); } |
237 | |
238 | struct ref_iterator { |
239 | using MapType = std::map<RegisterId, LaneBitmask>; |
240 | |
241 | private: |
242 | MapType Masks; |
243 | MapType::iterator Pos; |
244 | unsigned Index; |
245 | const RegisterAggr *Owner; |
246 | |
247 | public: |
248 | ref_iterator(const RegisterAggr &RG, bool End); |
249 | |
250 | RegisterRef operator*() const { |
251 | return RegisterRef(Pos->first, Pos->second); |
252 | } |
253 | |
254 | ref_iterator &operator++() { |
255 | ++Pos; |
256 | ++Index; |
257 | return *this; |
258 | } |
259 | |
260 | bool operator==(const ref_iterator &I) const { |
261 | assert(Owner == I.Owner); |
262 | (void)Owner; |
263 | return Index == I.Index; |
264 | } |
265 | |
266 | bool operator!=(const ref_iterator &I) const { return !(*this == I); } |
267 | }; |
268 | |
269 | ref_iterator ref_begin() const { return ref_iterator(*this, false); } |
270 | ref_iterator ref_end() const { return ref_iterator(*this, true); } |
271 | |
272 | using unit_iterator = typename BitVector::const_set_bits_iterator; |
273 | unit_iterator unit_begin() const { return Units.set_bits_begin(); } |
274 | unit_iterator unit_end() const { return Units.set_bits_end(); } |
275 | |
276 | iterator_range<ref_iterator> refs() const { |
277 | return make_range(x: ref_begin(), y: ref_end()); |
278 | } |
279 | iterator_range<unit_iterator> units() const { |
280 | return make_range(x: unit_begin(), y: unit_end()); |
281 | } |
282 | |
283 | private: |
284 | BitVector Units; |
285 | const PhysicalRegisterInfo &PRI; |
286 | }; |
287 | |
288 | // This is really a std::map, except that it provides a non-trivial |
289 | // default constructor to the element accessed via []. |
290 | template <typename KeyType> struct RegisterAggrMap { |
291 | RegisterAggrMap(const PhysicalRegisterInfo &pri) : Empty(pri) {} |
292 | |
293 | RegisterAggr &operator[](KeyType Key) { |
294 | return Map.emplace(Key, Empty).first->second; |
295 | } |
296 | |
297 | auto begin() { return Map.begin(); } |
298 | auto end() { return Map.end(); } |
299 | auto begin() const { return Map.begin(); } |
300 | auto end() const { return Map.end(); } |
301 | auto find(const KeyType &Key) const { return Map.find(Key); } |
302 | |
303 | private: |
304 | RegisterAggr Empty; |
305 | std::map<KeyType, RegisterAggr> Map; |
306 | |
307 | public: |
308 | using key_type = typename decltype(Map)::key_type; |
309 | using mapped_type = typename decltype(Map)::mapped_type; |
310 | using value_type = typename decltype(Map)::value_type; |
311 | }; |
312 | |
313 | raw_ostream &operator<<(raw_ostream &OS, const RegisterAggr &A); |
314 | |
315 | // Print the lane mask in a short form (or not at all if all bits are set). |
316 | struct PrintLaneMaskShort { |
317 | PrintLaneMaskShort(LaneBitmask M) : Mask(M) {} |
318 | LaneBitmask Mask; |
319 | }; |
320 | raw_ostream &operator<<(raw_ostream &OS, const PrintLaneMaskShort &P); |
321 | |
322 | } // end namespace rdf |
323 | } // end namespace llvm |
324 | |
325 | namespace std { |
326 | |
327 | template <> struct hash<llvm::rdf::RegisterRef> { |
328 | size_t operator()(llvm::rdf::RegisterRef A) const { // |
329 | return A.hash(); |
330 | } |
331 | }; |
332 | |
333 | template <> struct hash<llvm::rdf::RegisterAggr> { |
334 | size_t operator()(const llvm::rdf::RegisterAggr &A) const { // |
335 | return A.hash(); |
336 | } |
337 | }; |
338 | |
339 | template <> struct equal_to<llvm::rdf::RegisterRef> { |
340 | constexpr equal_to(const llvm::rdf::PhysicalRegisterInfo &pri) : PRI(&pri) {} |
341 | |
342 | bool operator()(llvm::rdf::RegisterRef A, llvm::rdf::RegisterRef B) const { |
343 | return PRI->equal_to(A, B); |
344 | } |
345 | |
346 | private: |
347 | // Make it a pointer just in case. See comment in `less` below. |
348 | const llvm::rdf::PhysicalRegisterInfo *PRI; |
349 | }; |
350 | |
351 | template <> struct equal_to<llvm::rdf::RegisterAggr> { |
352 | bool operator()(const llvm::rdf::RegisterAggr &A, |
353 | const llvm::rdf::RegisterAggr &B) const { |
354 | return A == B; |
355 | } |
356 | }; |
357 | |
358 | template <> struct less<llvm::rdf::RegisterRef> { |
359 | constexpr less(const llvm::rdf::PhysicalRegisterInfo &pri) : PRI(&pri) {} |
360 | |
361 | bool operator()(llvm::rdf::RegisterRef A, llvm::rdf::RegisterRef B) const { |
362 | return PRI->less(A, B); |
363 | } |
364 | |
365 | private: |
366 | // Make it a pointer because apparently some versions of MSVC use std::swap |
367 | // on the std::less specialization. |
368 | const llvm::rdf::PhysicalRegisterInfo *PRI; |
369 | }; |
370 | |
371 | } // namespace std |
372 | |
373 | namespace llvm::rdf { |
374 | using RegisterSet = std::set<RegisterRef, std::less<RegisterRef>>; |
375 | } // namespace llvm::rdf |
376 | |
377 | #endif // LLVM_CODEGEN_RDFREGISTERS_H |
378 | |