1 | //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // This file contains the AArch64 / Cortex-A57 specific register allocation |
9 | // constraints for use by the PBQP register allocator. |
10 | // |
11 | // It is essentially a transcription of what is contained in |
12 | // AArch64A57FPLoadBalancing, which tries to use a balanced |
13 | // mix of odd and even D-registers when performing a critical sequence of |
14 | // independent, non-quadword FP/ASIMD floating-point multiply-accumulates. |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #include "AArch64PBQPRegAlloc.h" |
18 | #include "AArch64.h" |
19 | #include "AArch64RegisterInfo.h" |
20 | #include "llvm/CodeGen/LiveIntervals.h" |
21 | #include "llvm/CodeGen/MachineBasicBlock.h" |
22 | #include "llvm/CodeGen/MachineFunction.h" |
23 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
24 | #include "llvm/CodeGen/RegAllocPBQP.h" |
25 | #include "llvm/Support/Debug.h" |
26 | #include "llvm/Support/ErrorHandling.h" |
27 | #include "llvm/Support/raw_ostream.h" |
28 | |
29 | #define DEBUG_TYPE "aarch64-pbqp" |
30 | |
31 | using namespace llvm; |
32 | |
33 | namespace { |
34 | |
35 | #ifndef NDEBUG |
36 | bool isFPReg(unsigned reg) { |
37 | return AArch64::FPR32RegClass.contains(reg) || |
38 | AArch64::FPR64RegClass.contains(reg) || |
39 | AArch64::FPR128RegClass.contains(reg); |
40 | } |
41 | #endif |
42 | |
43 | bool isOdd(unsigned reg) { |
44 | switch (reg) { |
45 | default: |
46 | llvm_unreachable("Register is not from the expected class !" ); |
47 | case AArch64::S1: |
48 | case AArch64::S3: |
49 | case AArch64::S5: |
50 | case AArch64::S7: |
51 | case AArch64::S9: |
52 | case AArch64::S11: |
53 | case AArch64::S13: |
54 | case AArch64::S15: |
55 | case AArch64::S17: |
56 | case AArch64::S19: |
57 | case AArch64::S21: |
58 | case AArch64::S23: |
59 | case AArch64::S25: |
60 | case AArch64::S27: |
61 | case AArch64::S29: |
62 | case AArch64::S31: |
63 | case AArch64::D1: |
64 | case AArch64::D3: |
65 | case AArch64::D5: |
66 | case AArch64::D7: |
67 | case AArch64::D9: |
68 | case AArch64::D11: |
69 | case AArch64::D13: |
70 | case AArch64::D15: |
71 | case AArch64::D17: |
72 | case AArch64::D19: |
73 | case AArch64::D21: |
74 | case AArch64::D23: |
75 | case AArch64::D25: |
76 | case AArch64::D27: |
77 | case AArch64::D29: |
78 | case AArch64::D31: |
79 | case AArch64::Q1: |
80 | case AArch64::Q3: |
81 | case AArch64::Q5: |
82 | case AArch64::Q7: |
83 | case AArch64::Q9: |
84 | case AArch64::Q11: |
85 | case AArch64::Q13: |
86 | case AArch64::Q15: |
87 | case AArch64::Q17: |
88 | case AArch64::Q19: |
89 | case AArch64::Q21: |
90 | case AArch64::Q23: |
91 | case AArch64::Q25: |
92 | case AArch64::Q27: |
93 | case AArch64::Q29: |
94 | case AArch64::Q31: |
95 | return true; |
96 | case AArch64::S0: |
97 | case AArch64::S2: |
98 | case AArch64::S4: |
99 | case AArch64::S6: |
100 | case AArch64::S8: |
101 | case AArch64::S10: |
102 | case AArch64::S12: |
103 | case AArch64::S14: |
104 | case AArch64::S16: |
105 | case AArch64::S18: |
106 | case AArch64::S20: |
107 | case AArch64::S22: |
108 | case AArch64::S24: |
109 | case AArch64::S26: |
110 | case AArch64::S28: |
111 | case AArch64::S30: |
112 | case AArch64::D0: |
113 | case AArch64::D2: |
114 | case AArch64::D4: |
115 | case AArch64::D6: |
116 | case AArch64::D8: |
117 | case AArch64::D10: |
118 | case AArch64::D12: |
119 | case AArch64::D14: |
120 | case AArch64::D16: |
121 | case AArch64::D18: |
122 | case AArch64::D20: |
123 | case AArch64::D22: |
124 | case AArch64::D24: |
125 | case AArch64::D26: |
126 | case AArch64::D28: |
127 | case AArch64::D30: |
128 | case AArch64::Q0: |
129 | case AArch64::Q2: |
130 | case AArch64::Q4: |
131 | case AArch64::Q6: |
132 | case AArch64::Q8: |
133 | case AArch64::Q10: |
134 | case AArch64::Q12: |
135 | case AArch64::Q14: |
136 | case AArch64::Q16: |
137 | case AArch64::Q18: |
138 | case AArch64::Q20: |
139 | case AArch64::Q22: |
140 | case AArch64::Q24: |
141 | case AArch64::Q26: |
142 | case AArch64::Q28: |
143 | case AArch64::Q30: |
144 | return false; |
145 | |
146 | } |
147 | } |
148 | |
149 | bool haveSameParity(unsigned reg1, unsigned reg2) { |
150 | assert(isFPReg(reg1) && "Expecting an FP register for reg1" ); |
151 | assert(isFPReg(reg2) && "Expecting an FP register for reg2" ); |
152 | |
153 | return isOdd(reg: reg1) == isOdd(reg: reg2); |
154 | } |
155 | |
156 | } |
157 | |
158 | bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd, |
159 | unsigned Ra) { |
160 | if (Rd == Ra) |
161 | return false; |
162 | |
163 | LiveIntervals &LIs = G.getMetadata().LIS; |
164 | |
165 | if (Register::isPhysicalRegister(Reg: Rd) || Register::isPhysicalRegister(Reg: Ra)) { |
166 | LLVM_DEBUG(dbgs() << "Rd is a physical reg:" |
167 | << Register::isPhysicalRegister(Rd) << '\n'); |
168 | LLVM_DEBUG(dbgs() << "Ra is a physical reg:" |
169 | << Register::isPhysicalRegister(Ra) << '\n'); |
170 | return false; |
171 | } |
172 | |
173 | PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(VReg: Rd); |
174 | PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(VReg: Ra); |
175 | |
176 | const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = |
177 | &G.getNodeMetadata(NId: node1).getAllowedRegs(); |
178 | const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed = |
179 | &G.getNodeMetadata(NId: node2).getAllowedRegs(); |
180 | |
181 | PBQPRAGraph::EdgeId edge = G.findEdge(N1Id: node1, N2Id: node2); |
182 | |
183 | // The edge does not exist. Create one with the appropriate interference |
184 | // costs. |
185 | if (edge == G.invalidEdgeId()) { |
186 | const LiveInterval &ld = LIs.getInterval(Reg: Rd); |
187 | const LiveInterval &la = LIs.getInterval(Reg: Ra); |
188 | bool livesOverlap = ld.overlaps(other: la); |
189 | |
190 | PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1, |
191 | vRaAllowed->size() + 1, 0); |
192 | for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { |
193 | unsigned pRd = (*vRdAllowed)[i]; |
194 | for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { |
195 | unsigned pRa = (*vRaAllowed)[j]; |
196 | if (livesOverlap && TRI->regsOverlap(RegA: pRd, RegB: pRa)) |
197 | costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity(); |
198 | else |
199 | costs[i + 1][j + 1] = haveSameParity(reg1: pRd, reg2: pRa) ? 0.0 : 1.0; |
200 | } |
201 | } |
202 | G.addEdge(N1Id: node1, N2Id: node2, Costs: std::move(costs)); |
203 | return true; |
204 | } |
205 | |
206 | if (G.getEdgeNode1Id(EId: edge) == node2) { |
207 | std::swap(a&: node1, b&: node2); |
208 | std::swap(a&: vRdAllowed, b&: vRaAllowed); |
209 | } |
210 | |
211 | // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass)) |
212 | PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(EId: edge)); |
213 | for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { |
214 | unsigned pRd = (*vRdAllowed)[i]; |
215 | |
216 | // Get the maximum cost (excluding unallocatable reg) for same parity |
217 | // registers |
218 | PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); |
219 | for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { |
220 | unsigned pRa = (*vRaAllowed)[j]; |
221 | if (haveSameParity(reg1: pRd, reg2: pRa)) |
222 | if (costs[i + 1][j + 1] != |
223 | std::numeric_limits<PBQP::PBQPNum>::infinity() && |
224 | costs[i + 1][j + 1] > sameParityMax) |
225 | sameParityMax = costs[i + 1][j + 1]; |
226 | } |
227 | |
228 | // Ensure all registers with a different parity have a higher cost |
229 | // than sameParityMax |
230 | for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { |
231 | unsigned pRa = (*vRaAllowed)[j]; |
232 | if (!haveSameParity(reg1: pRd, reg2: pRa)) |
233 | if (sameParityMax > costs[i + 1][j + 1]) |
234 | costs[i + 1][j + 1] = sameParityMax + 1.0; |
235 | } |
236 | } |
237 | G.updateEdgeCosts(EId: edge, Costs: std::move(costs)); |
238 | |
239 | return true; |
240 | } |
241 | |
242 | void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd, |
243 | unsigned Ra) { |
244 | LiveIntervals &LIs = G.getMetadata().LIS; |
245 | |
246 | // Do some Chain management |
247 | if (Chains.count(key: Ra)) { |
248 | if (Rd != Ra) { |
249 | LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI) |
250 | << " to " << printReg(Rd, TRI) << '\n';); |
251 | Chains.remove(X: Ra); |
252 | Chains.insert(X: Rd); |
253 | } |
254 | } else { |
255 | LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI) |
256 | << '\n';); |
257 | Chains.insert(X: Rd); |
258 | } |
259 | |
260 | PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(VReg: Rd); |
261 | |
262 | const LiveInterval &ld = LIs.getInterval(Reg: Rd); |
263 | for (auto r : Chains) { |
264 | // Skip self |
265 | if (r == Rd) |
266 | continue; |
267 | |
268 | const LiveInterval &lr = LIs.getInterval(Reg: r); |
269 | if (ld.overlaps(other: lr)) { |
270 | const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = |
271 | &G.getNodeMetadata(NId: node1).getAllowedRegs(); |
272 | |
273 | PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(VReg: r); |
274 | const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed = |
275 | &G.getNodeMetadata(NId: node2).getAllowedRegs(); |
276 | |
277 | PBQPRAGraph::EdgeId edge = G.findEdge(N1Id: node1, N2Id: node2); |
278 | assert(edge != G.invalidEdgeId() && |
279 | "PBQP error ! The edge should exist !" ); |
280 | |
281 | LLVM_DEBUG(dbgs() << "Refining constraint !\n" ;); |
282 | |
283 | if (G.getEdgeNode1Id(EId: edge) == node2) { |
284 | std::swap(a&: node1, b&: node2); |
285 | std::swap(a&: vRdAllowed, b&: vRrAllowed); |
286 | } |
287 | |
288 | // Enforce that cost is higher with all other Chains of the same parity |
289 | PBQP::Matrix costs(G.getEdgeCosts(EId: edge)); |
290 | for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { |
291 | unsigned pRd = (*vRdAllowed)[i]; |
292 | |
293 | // Get the maximum cost (excluding unallocatable reg) for all other |
294 | // parity registers |
295 | PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); |
296 | for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) { |
297 | unsigned pRa = (*vRrAllowed)[j]; |
298 | if (!haveSameParity(reg1: pRd, reg2: pRa)) |
299 | if (costs[i + 1][j + 1] != |
300 | std::numeric_limits<PBQP::PBQPNum>::infinity() && |
301 | costs[i + 1][j + 1] > sameParityMax) |
302 | sameParityMax = costs[i + 1][j + 1]; |
303 | } |
304 | |
305 | // Ensure all registers with same parity have a higher cost |
306 | // than sameParityMax |
307 | for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) { |
308 | unsigned pRa = (*vRrAllowed)[j]; |
309 | if (haveSameParity(reg1: pRd, reg2: pRa)) |
310 | if (sameParityMax > costs[i + 1][j + 1]) |
311 | costs[i + 1][j + 1] = sameParityMax + 1.0; |
312 | } |
313 | } |
314 | G.updateEdgeCosts(EId: edge, Costs: std::move(costs)); |
315 | } |
316 | } |
317 | } |
318 | |
319 | static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg, |
320 | const MachineInstr &MI) { |
321 | const LiveInterval &LI = LIs.getInterval(Reg: reg); |
322 | SlotIndex SI = LIs.getInstructionIndex(Instr: MI); |
323 | return LI.expiredAt(index: SI); |
324 | } |
325 | |
326 | void A57ChainingConstraint::apply(PBQPRAGraph &G) { |
327 | const MachineFunction &MF = G.getMetadata().MF; |
328 | LiveIntervals &LIs = G.getMetadata().LIS; |
329 | |
330 | TRI = MF.getSubtarget().getRegisterInfo(); |
331 | LLVM_DEBUG(MF.dump()); |
332 | |
333 | for (const auto &MBB: MF) { |
334 | Chains.clear(); // FIXME: really needed ? Could not work at MF level ? |
335 | |
336 | for (const auto &MI: MBB) { |
337 | |
338 | // Forget Chains which have expired |
339 | for (auto r : Chains) { |
340 | SmallVector<unsigned, 8> toDel; |
341 | if(regJustKilledBefore(LIs, reg: r, MI)) { |
342 | LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at " ; |
343 | MI.print(dbgs());); |
344 | toDel.push_back(Elt: r); |
345 | } |
346 | |
347 | while (!toDel.empty()) { |
348 | Chains.remove(X: toDel.back()); |
349 | toDel.pop_back(); |
350 | } |
351 | } |
352 | |
353 | switch (MI.getOpcode()) { |
354 | case AArch64::FMSUBSrrr: |
355 | case AArch64::FMADDSrrr: |
356 | case AArch64::FNMSUBSrrr: |
357 | case AArch64::FNMADDSrrr: |
358 | case AArch64::FMSUBDrrr: |
359 | case AArch64::FMADDDrrr: |
360 | case AArch64::FNMSUBDrrr: |
361 | case AArch64::FNMADDDrrr: { |
362 | Register Rd = MI.getOperand(i: 0).getReg(); |
363 | Register Ra = MI.getOperand(i: 3).getReg(); |
364 | |
365 | if (addIntraChainConstraint(G, Rd, Ra)) |
366 | addInterChainConstraint(G, Rd, Ra); |
367 | break; |
368 | } |
369 | |
370 | case AArch64::FMLAv2f32: |
371 | case AArch64::FMLSv2f32: { |
372 | Register Rd = MI.getOperand(i: 0).getReg(); |
373 | addInterChainConstraint(G, Rd, Ra: Rd); |
374 | break; |
375 | } |
376 | |
377 | default: |
378 | break; |
379 | } |
380 | } |
381 | } |
382 | } |
383 | |