1//===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This file contains the AArch64 / Cortex-A57 specific register allocation
9// constraints for use by the PBQP register allocator.
10//
11// It is essentially a transcription of what is contained in
12// AArch64A57FPLoadBalancing, which tries to use a balanced
13// mix of odd and even D-registers when performing a critical sequence of
14// independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
15//===----------------------------------------------------------------------===//
16
17#include "AArch64PBQPRegAlloc.h"
18#include "AArch64.h"
19#include "AArch64RegisterInfo.h"
20#include "llvm/CodeGen/LiveIntervals.h"
21#include "llvm/CodeGen/MachineBasicBlock.h"
22#include "llvm/CodeGen/MachineFunction.h"
23#include "llvm/CodeGen/MachineRegisterInfo.h"
24#include "llvm/CodeGen/RegAllocPBQP.h"
25#include "llvm/Support/Debug.h"
26#include "llvm/Support/ErrorHandling.h"
27#include "llvm/Support/raw_ostream.h"
28
29#define DEBUG_TYPE "aarch64-pbqp"
30
31using namespace llvm;
32
33namespace {
34
35#ifndef NDEBUG
36bool isFPReg(unsigned reg) {
37 return AArch64::FPR32RegClass.contains(reg) ||
38 AArch64::FPR64RegClass.contains(reg) ||
39 AArch64::FPR128RegClass.contains(reg);
40}
41#endif
42
43bool isOdd(unsigned reg) {
44 switch (reg) {
45 default:
46 llvm_unreachable("Register is not from the expected class !");
47 case AArch64::S1:
48 case AArch64::S3:
49 case AArch64::S5:
50 case AArch64::S7:
51 case AArch64::S9:
52 case AArch64::S11:
53 case AArch64::S13:
54 case AArch64::S15:
55 case AArch64::S17:
56 case AArch64::S19:
57 case AArch64::S21:
58 case AArch64::S23:
59 case AArch64::S25:
60 case AArch64::S27:
61 case AArch64::S29:
62 case AArch64::S31:
63 case AArch64::D1:
64 case AArch64::D3:
65 case AArch64::D5:
66 case AArch64::D7:
67 case AArch64::D9:
68 case AArch64::D11:
69 case AArch64::D13:
70 case AArch64::D15:
71 case AArch64::D17:
72 case AArch64::D19:
73 case AArch64::D21:
74 case AArch64::D23:
75 case AArch64::D25:
76 case AArch64::D27:
77 case AArch64::D29:
78 case AArch64::D31:
79 case AArch64::Q1:
80 case AArch64::Q3:
81 case AArch64::Q5:
82 case AArch64::Q7:
83 case AArch64::Q9:
84 case AArch64::Q11:
85 case AArch64::Q13:
86 case AArch64::Q15:
87 case AArch64::Q17:
88 case AArch64::Q19:
89 case AArch64::Q21:
90 case AArch64::Q23:
91 case AArch64::Q25:
92 case AArch64::Q27:
93 case AArch64::Q29:
94 case AArch64::Q31:
95 return true;
96 case AArch64::S0:
97 case AArch64::S2:
98 case AArch64::S4:
99 case AArch64::S6:
100 case AArch64::S8:
101 case AArch64::S10:
102 case AArch64::S12:
103 case AArch64::S14:
104 case AArch64::S16:
105 case AArch64::S18:
106 case AArch64::S20:
107 case AArch64::S22:
108 case AArch64::S24:
109 case AArch64::S26:
110 case AArch64::S28:
111 case AArch64::S30:
112 case AArch64::D0:
113 case AArch64::D2:
114 case AArch64::D4:
115 case AArch64::D6:
116 case AArch64::D8:
117 case AArch64::D10:
118 case AArch64::D12:
119 case AArch64::D14:
120 case AArch64::D16:
121 case AArch64::D18:
122 case AArch64::D20:
123 case AArch64::D22:
124 case AArch64::D24:
125 case AArch64::D26:
126 case AArch64::D28:
127 case AArch64::D30:
128 case AArch64::Q0:
129 case AArch64::Q2:
130 case AArch64::Q4:
131 case AArch64::Q6:
132 case AArch64::Q8:
133 case AArch64::Q10:
134 case AArch64::Q12:
135 case AArch64::Q14:
136 case AArch64::Q16:
137 case AArch64::Q18:
138 case AArch64::Q20:
139 case AArch64::Q22:
140 case AArch64::Q24:
141 case AArch64::Q26:
142 case AArch64::Q28:
143 case AArch64::Q30:
144 return false;
145
146 }
147}
148
149bool haveSameParity(unsigned reg1, unsigned reg2) {
150 assert(isFPReg(reg1) && "Expecting an FP register for reg1");
151 assert(isFPReg(reg2) && "Expecting an FP register for reg2");
152
153 return isOdd(reg: reg1) == isOdd(reg: reg2);
154}
155
156}
157
158bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
159 unsigned Ra) {
160 if (Rd == Ra)
161 return false;
162
163 LiveIntervals &LIs = G.getMetadata().LIS;
164
165 if (Register::isPhysicalRegister(Reg: Rd) || Register::isPhysicalRegister(Reg: Ra)) {
166 LLVM_DEBUG(dbgs() << "Rd is a physical reg:"
167 << Register::isPhysicalRegister(Rd) << '\n');
168 LLVM_DEBUG(dbgs() << "Ra is a physical reg:"
169 << Register::isPhysicalRegister(Ra) << '\n');
170 return false;
171 }
172
173 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(VReg: Rd);
174 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(VReg: Ra);
175
176 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
177 &G.getNodeMetadata(NId: node1).getAllowedRegs();
178 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
179 &G.getNodeMetadata(NId: node2).getAllowedRegs();
180
181 PBQPRAGraph::EdgeId edge = G.findEdge(N1Id: node1, N2Id: node2);
182
183 // The edge does not exist. Create one with the appropriate interference
184 // costs.
185 if (edge == G.invalidEdgeId()) {
186 const LiveInterval &ld = LIs.getInterval(Reg: Rd);
187 const LiveInterval &la = LIs.getInterval(Reg: Ra);
188 bool livesOverlap = ld.overlaps(other: la);
189
190 PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
191 vRaAllowed->size() + 1, 0);
192 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
193 unsigned pRd = (*vRdAllowed)[i];
194 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
195 unsigned pRa = (*vRaAllowed)[j];
196 if (livesOverlap && TRI->regsOverlap(RegA: pRd, RegB: pRa))
197 costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
198 else
199 costs[i + 1][j + 1] = haveSameParity(reg1: pRd, reg2: pRa) ? 0.0 : 1.0;
200 }
201 }
202 G.addEdge(N1Id: node1, N2Id: node2, Costs: std::move(costs));
203 return true;
204 }
205
206 if (G.getEdgeNode1Id(EId: edge) == node2) {
207 std::swap(a&: node1, b&: node2);
208 std::swap(a&: vRdAllowed, b&: vRaAllowed);
209 }
210
211 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
212 PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(EId: edge));
213 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
214 unsigned pRd = (*vRdAllowed)[i];
215
216 // Get the maximum cost (excluding unallocatable reg) for same parity
217 // registers
218 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
219 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
220 unsigned pRa = (*vRaAllowed)[j];
221 if (haveSameParity(reg1: pRd, reg2: pRa))
222 if (costs[i + 1][j + 1] !=
223 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
224 costs[i + 1][j + 1] > sameParityMax)
225 sameParityMax = costs[i + 1][j + 1];
226 }
227
228 // Ensure all registers with a different parity have a higher cost
229 // than sameParityMax
230 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
231 unsigned pRa = (*vRaAllowed)[j];
232 if (!haveSameParity(reg1: pRd, reg2: pRa))
233 if (sameParityMax > costs[i + 1][j + 1])
234 costs[i + 1][j + 1] = sameParityMax + 1.0;
235 }
236 }
237 G.updateEdgeCosts(EId: edge, Costs: std::move(costs));
238
239 return true;
240}
241
242void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
243 unsigned Ra) {
244 LiveIntervals &LIs = G.getMetadata().LIS;
245
246 // Do some Chain management
247 if (Chains.count(key: Ra)) {
248 if (Rd != Ra) {
249 LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
250 << " to " << printReg(Rd, TRI) << '\n';);
251 Chains.remove(X: Ra);
252 Chains.insert(X: Rd);
253 }
254 } else {
255 LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
256 << '\n';);
257 Chains.insert(X: Rd);
258 }
259
260 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(VReg: Rd);
261
262 const LiveInterval &ld = LIs.getInterval(Reg: Rd);
263 for (auto r : Chains) {
264 // Skip self
265 if (r == Rd)
266 continue;
267
268 const LiveInterval &lr = LIs.getInterval(Reg: r);
269 if (ld.overlaps(other: lr)) {
270 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
271 &G.getNodeMetadata(NId: node1).getAllowedRegs();
272
273 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(VReg: r);
274 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
275 &G.getNodeMetadata(NId: node2).getAllowedRegs();
276
277 PBQPRAGraph::EdgeId edge = G.findEdge(N1Id: node1, N2Id: node2);
278 assert(edge != G.invalidEdgeId() &&
279 "PBQP error ! The edge should exist !");
280
281 LLVM_DEBUG(dbgs() << "Refining constraint !\n";);
282
283 if (G.getEdgeNode1Id(EId: edge) == node2) {
284 std::swap(a&: node1, b&: node2);
285 std::swap(a&: vRdAllowed, b&: vRrAllowed);
286 }
287
288 // Enforce that cost is higher with all other Chains of the same parity
289 PBQP::Matrix costs(G.getEdgeCosts(EId: edge));
290 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
291 unsigned pRd = (*vRdAllowed)[i];
292
293 // Get the maximum cost (excluding unallocatable reg) for all other
294 // parity registers
295 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
296 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
297 unsigned pRa = (*vRrAllowed)[j];
298 if (!haveSameParity(reg1: pRd, reg2: pRa))
299 if (costs[i + 1][j + 1] !=
300 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
301 costs[i + 1][j + 1] > sameParityMax)
302 sameParityMax = costs[i + 1][j + 1];
303 }
304
305 // Ensure all registers with same parity have a higher cost
306 // than sameParityMax
307 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
308 unsigned pRa = (*vRrAllowed)[j];
309 if (haveSameParity(reg1: pRd, reg2: pRa))
310 if (sameParityMax > costs[i + 1][j + 1])
311 costs[i + 1][j + 1] = sameParityMax + 1.0;
312 }
313 }
314 G.updateEdgeCosts(EId: edge, Costs: std::move(costs));
315 }
316 }
317}
318
319static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
320 const MachineInstr &MI) {
321 const LiveInterval &LI = LIs.getInterval(Reg: reg);
322 SlotIndex SI = LIs.getInstructionIndex(Instr: MI);
323 return LI.expiredAt(index: SI);
324}
325
326void A57ChainingConstraint::apply(PBQPRAGraph &G) {
327 const MachineFunction &MF = G.getMetadata().MF;
328 LiveIntervals &LIs = G.getMetadata().LIS;
329
330 TRI = MF.getSubtarget().getRegisterInfo();
331 LLVM_DEBUG(MF.dump());
332
333 for (const auto &MBB: MF) {
334 Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
335
336 for (const auto &MI: MBB) {
337
338 // Forget Chains which have expired
339 for (auto r : Chains) {
340 SmallVector<unsigned, 8> toDel;
341 if(regJustKilledBefore(LIs, reg: r, MI)) {
342 LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
343 MI.print(dbgs()););
344 toDel.push_back(Elt: r);
345 }
346
347 while (!toDel.empty()) {
348 Chains.remove(X: toDel.back());
349 toDel.pop_back();
350 }
351 }
352
353 switch (MI.getOpcode()) {
354 case AArch64::FMSUBSrrr:
355 case AArch64::FMADDSrrr:
356 case AArch64::FNMSUBSrrr:
357 case AArch64::FNMADDSrrr:
358 case AArch64::FMSUBDrrr:
359 case AArch64::FMADDDrrr:
360 case AArch64::FNMSUBDrrr:
361 case AArch64::FNMADDDrrr: {
362 Register Rd = MI.getOperand(i: 0).getReg();
363 Register Ra = MI.getOperand(i: 3).getReg();
364
365 if (addIntraChainConstraint(G, Rd, Ra))
366 addInterChainConstraint(G, Rd, Ra);
367 break;
368 }
369
370 case AArch64::FMLAv2f32:
371 case AArch64::FMLSv2f32: {
372 Register Rd = MI.getOperand(i: 0).getReg();
373 addInterChainConstraint(G, Rd, Ra: Rd);
374 break;
375 }
376
377 default:
378 break;
379 }
380 }
381 }
382}
383

source code of llvm/lib/Target/AArch64/AArch64PBQPRegAlloc.cpp