1//===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a printer that converts from our internal representation
10// of machine-dependent LLVM code to NVPTX assembly language.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXAsmPrinter.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "MCTargetDesc/NVPTXInstPrinter.h"
17#include "MCTargetDesc/NVPTXMCAsmInfo.h"
18#include "MCTargetDesc/NVPTXTargetStreamer.h"
19#include "NVPTX.h"
20#include "NVPTXMCExpr.h"
21#include "NVPTXMachineFunctionInfo.h"
22#include "NVPTXRegisterInfo.h"
23#include "NVPTXSubtarget.h"
24#include "NVPTXTargetMachine.h"
25#include "NVPTXUtilities.h"
26#include "TargetInfo/NVPTXTargetInfo.h"
27#include "cl_common_defines.h"
28#include "llvm/ADT/APFloat.h"
29#include "llvm/ADT/APInt.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/DenseSet.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/SmallVector.h"
34#include "llvm/ADT/StringExtras.h"
35#include "llvm/ADT/StringRef.h"
36#include "llvm/ADT/Twine.h"
37#include "llvm/Analysis/ConstantFolding.h"
38#include "llvm/CodeGen/Analysis.h"
39#include "llvm/CodeGen/MachineBasicBlock.h"
40#include "llvm/CodeGen/MachineFrameInfo.h"
41#include "llvm/CodeGen/MachineFunction.h"
42#include "llvm/CodeGen/MachineInstr.h"
43#include "llvm/CodeGen/MachineLoopInfo.h"
44#include "llvm/CodeGen/MachineModuleInfo.h"
45#include "llvm/CodeGen/MachineOperand.h"
46#include "llvm/CodeGen/MachineRegisterInfo.h"
47#include "llvm/CodeGen/TargetRegisterInfo.h"
48#include "llvm/CodeGen/ValueTypes.h"
49#include "llvm/CodeGenTypes/MachineValueType.h"
50#include "llvm/IR/Attributes.h"
51#include "llvm/IR/BasicBlock.h"
52#include "llvm/IR/Constant.h"
53#include "llvm/IR/Constants.h"
54#include "llvm/IR/DataLayout.h"
55#include "llvm/IR/DebugInfo.h"
56#include "llvm/IR/DebugInfoMetadata.h"
57#include "llvm/IR/DebugLoc.h"
58#include "llvm/IR/DerivedTypes.h"
59#include "llvm/IR/Function.h"
60#include "llvm/IR/GlobalAlias.h"
61#include "llvm/IR/GlobalValue.h"
62#include "llvm/IR/GlobalVariable.h"
63#include "llvm/IR/Instruction.h"
64#include "llvm/IR/LLVMContext.h"
65#include "llvm/IR/Module.h"
66#include "llvm/IR/Operator.h"
67#include "llvm/IR/Type.h"
68#include "llvm/IR/User.h"
69#include "llvm/MC/MCExpr.h"
70#include "llvm/MC/MCInst.h"
71#include "llvm/MC/MCInstrDesc.h"
72#include "llvm/MC/MCStreamer.h"
73#include "llvm/MC/MCSymbol.h"
74#include "llvm/MC/TargetRegistry.h"
75#include "llvm/Support/Casting.h"
76#include "llvm/Support/CommandLine.h"
77#include "llvm/Support/Endian.h"
78#include "llvm/Support/ErrorHandling.h"
79#include "llvm/Support/NativeFormatting.h"
80#include "llvm/Support/Path.h"
81#include "llvm/Support/raw_ostream.h"
82#include "llvm/Target/TargetLoweringObjectFile.h"
83#include "llvm/Target/TargetMachine.h"
84#include "llvm/TargetParser/Triple.h"
85#include "llvm/Transforms/Utils/UnrollLoop.h"
86#include <cassert>
87#include <cstdint>
88#include <cstring>
89#include <new>
90#include <string>
91#include <utility>
92#include <vector>
93
94using namespace llvm;
95
96static cl::opt<bool>
97 LowerCtorDtor("nvptx-lower-global-ctor-dtor",
98 cl::desc("Lower GPU ctor / dtors to globals on the device."),
99 cl::init(Val: false), cl::Hidden);
100
101#define DEPOTNAME "__local_depot"
102
103/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
104/// depends.
105static void
106DiscoverDependentGlobals(const Value *V,
107 DenseSet<const GlobalVariable *> &Globals) {
108 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: V))
109 Globals.insert(V: GV);
110 else {
111 if (const User *U = dyn_cast<User>(Val: V)) {
112 for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
113 DiscoverDependentGlobals(V: U->getOperand(i), Globals);
114 }
115 }
116 }
117}
118
119/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
120/// instances to be emitted, but only after any dependents have been added
121/// first.s
122static void
123VisitGlobalVariableForEmission(const GlobalVariable *GV,
124 SmallVectorImpl<const GlobalVariable *> &Order,
125 DenseSet<const GlobalVariable *> &Visited,
126 DenseSet<const GlobalVariable *> &Visiting) {
127 // Have we already visited this one?
128 if (Visited.count(V: GV))
129 return;
130
131 // Do we have a circular dependency?
132 if (!Visiting.insert(V: GV).second)
133 report_fatal_error(reason: "Circular dependency found in global variable set");
134
135 // Make sure we visit all dependents first
136 DenseSet<const GlobalVariable *> Others;
137 for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
138 DiscoverDependentGlobals(V: GV->getOperand(i_nocapture: i), Globals&: Others);
139
140 for (const GlobalVariable *GV : Others)
141 VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
142
143 // Now we can visit ourself
144 Order.push_back(Elt: GV);
145 Visited.insert(V: GV);
146 Visiting.erase(V: GV);
147}
148
149void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
150 NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
151 getSubtargetInfo().getFeatureBits());
152
153 MCInst Inst;
154 lowerToMCInst(MI, OutMI&: Inst);
155 EmitToStreamer(S&: *OutStreamer, Inst);
156}
157
158// Handle symbol backtracking for targets that do not support image handles
159bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
160 unsigned OpNo, MCOperand &MCOp) {
161 const MachineOperand &MO = MI->getOperand(i: OpNo);
162 const MCInstrDesc &MCID = MI->getDesc();
163
164 if (MCID.TSFlags & NVPTXII::IsTexFlag) {
165 // This is a texture fetch, so operand 4 is a texref and operand 5 is
166 // a samplerref
167 if (OpNo == 4 && MO.isImm()) {
168 lowerImageHandleSymbol(Index: MO.getImm(), MCOp);
169 return true;
170 }
171 if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
172 lowerImageHandleSymbol(Index: MO.getImm(), MCOp);
173 return true;
174 }
175
176 return false;
177 } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
178 unsigned VecSize =
179 1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
180
181 // For a surface load of vector size N, the Nth operand will be the surfref
182 if (OpNo == VecSize && MO.isImm()) {
183 lowerImageHandleSymbol(Index: MO.getImm(), MCOp);
184 return true;
185 }
186
187 return false;
188 } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
189 // This is a surface store, so operand 0 is a surfref
190 if (OpNo == 0 && MO.isImm()) {
191 lowerImageHandleSymbol(Index: MO.getImm(), MCOp);
192 return true;
193 }
194
195 return false;
196 } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
197 // This is a query, so operand 1 is a surfref/texref
198 if (OpNo == 1 && MO.isImm()) {
199 lowerImageHandleSymbol(Index: MO.getImm(), MCOp);
200 return true;
201 }
202
203 return false;
204 }
205
206 return false;
207}
208
209void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
210 // Ewwww
211 LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
212 NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
213 const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
214 const char *Sym = MFI->getImageHandleSymbol(Idx: Index);
215 StringRef SymName = nvTM.getStrPool().save(S: Sym);
216 MCOp = GetSymbolRef(Symbol: OutContext.getOrCreateSymbol(Name: SymName));
217}
218
219void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
220 OutMI.setOpcode(MI->getOpcode());
221 // Special: Do not mangle symbol operand of CALL_PROTOTYPE
222 if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
223 const MachineOperand &MO = MI->getOperand(i: 0);
224 OutMI.addOperand(Op: GetSymbolRef(
225 Symbol: OutContext.getOrCreateSymbol(Name: Twine(MO.getSymbolName()))));
226 return;
227 }
228
229 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
230 for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
231 const MachineOperand &MO = MI->getOperand(i);
232
233 MCOperand MCOp;
234 if (!STI.hasImageHandles()) {
235 if (lowerImageHandleOperand(MI, OpNo: i, MCOp)) {
236 OutMI.addOperand(Op: MCOp);
237 continue;
238 }
239 }
240
241 if (lowerOperand(MO, MCOp))
242 OutMI.addOperand(Op: MCOp);
243 }
244}
245
246bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
247 MCOperand &MCOp) {
248 switch (MO.getType()) {
249 default: llvm_unreachable("unknown operand type");
250 case MachineOperand::MO_Register:
251 MCOp = MCOperand::createReg(Reg: encodeVirtualRegister(Reg: MO.getReg()));
252 break;
253 case MachineOperand::MO_Immediate:
254 MCOp = MCOperand::createImm(Val: MO.getImm());
255 break;
256 case MachineOperand::MO_MachineBasicBlock:
257 MCOp = MCOperand::createExpr(Val: MCSymbolRefExpr::create(
258 Symbol: MO.getMBB()->getSymbol(), Ctx&: OutContext));
259 break;
260 case MachineOperand::MO_ExternalSymbol:
261 MCOp = GetSymbolRef(Symbol: GetExternalSymbolSymbol(Sym: MO.getSymbolName()));
262 break;
263 case MachineOperand::MO_GlobalAddress:
264 MCOp = GetSymbolRef(Symbol: getSymbol(GV: MO.getGlobal()));
265 break;
266 case MachineOperand::MO_FPImmediate: {
267 const ConstantFP *Cnt = MO.getFPImm();
268 const APFloat &Val = Cnt->getValueAPF();
269
270 switch (Cnt->getType()->getTypeID()) {
271 default: report_fatal_error(reason: "Unsupported FP type"); break;
272 case Type::HalfTyID:
273 MCOp = MCOperand::createExpr(
274 Val: NVPTXFloatMCExpr::createConstantFPHalf(Flt: Val, Ctx&: OutContext));
275 break;
276 case Type::BFloatTyID:
277 MCOp = MCOperand::createExpr(
278 Val: NVPTXFloatMCExpr::createConstantBFPHalf(Flt: Val, Ctx&: OutContext));
279 break;
280 case Type::FloatTyID:
281 MCOp = MCOperand::createExpr(
282 Val: NVPTXFloatMCExpr::createConstantFPSingle(Flt: Val, Ctx&: OutContext));
283 break;
284 case Type::DoubleTyID:
285 MCOp = MCOperand::createExpr(
286 Val: NVPTXFloatMCExpr::createConstantFPDouble(Flt: Val, Ctx&: OutContext));
287 break;
288 }
289 break;
290 }
291 }
292 return true;
293}
294
295unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
296 if (Register::isVirtualRegister(Reg)) {
297 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
298
299 DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
300 unsigned RegNum = RegMap[Reg];
301
302 // Encode the register class in the upper 4 bits
303 // Must be kept in sync with NVPTXInstPrinter::printRegName
304 unsigned Ret = 0;
305 if (RC == &NVPTX::Int1RegsRegClass) {
306 Ret = (1 << 28);
307 } else if (RC == &NVPTX::Int16RegsRegClass) {
308 Ret = (2 << 28);
309 } else if (RC == &NVPTX::Int32RegsRegClass) {
310 Ret = (3 << 28);
311 } else if (RC == &NVPTX::Int64RegsRegClass) {
312 Ret = (4 << 28);
313 } else if (RC == &NVPTX::Float32RegsRegClass) {
314 Ret = (5 << 28);
315 } else if (RC == &NVPTX::Float64RegsRegClass) {
316 Ret = (6 << 28);
317 } else {
318 report_fatal_error(reason: "Bad register class");
319 }
320
321 // Insert the vreg number
322 Ret |= (RegNum & 0x0FFFFFFF);
323 return Ret;
324 } else {
325 // Some special-use registers are actually physical registers.
326 // Encode this as the register class ID of 0 and the real register ID.
327 return Reg & 0x0FFFFFFF;
328 }
329}
330
331MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
332 const MCExpr *Expr;
333 Expr = MCSymbolRefExpr::create(Symbol, Kind: MCSymbolRefExpr::VK_None,
334 Ctx&: OutContext);
335 return MCOperand::createExpr(Val: Expr);
336}
337
338static bool ShouldPassAsArray(Type *Ty) {
339 return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(Bitwidth: 128) ||
340 Ty->isHalfTy() || Ty->isBFloatTy();
341}
342
343void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
344 const DataLayout &DL = getDataLayout();
345 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(F: *F);
346 const auto *TLI = cast<NVPTXTargetLowering>(Val: STI.getTargetLowering());
347
348 Type *Ty = F->getReturnType();
349
350 bool isABI = (STI.getSmVersion() >= 20);
351
352 if (Ty->getTypeID() == Type::VoidTyID)
353 return;
354 O << " (";
355
356 if (isABI) {
357 if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
358 !ShouldPassAsArray(Ty)) {
359 unsigned size = 0;
360 if (auto *ITy = dyn_cast<IntegerType>(Val: Ty)) {
361 size = ITy->getBitWidth();
362 } else {
363 assert(Ty->isFloatingPointTy() && "Floating point type expected here");
364 size = Ty->getPrimitiveSizeInBits();
365 }
366 size = promoteScalarArgumentSize(size);
367 O << ".param .b" << size << " func_retval0";
368 } else if (isa<PointerType>(Val: Ty)) {
369 O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
370 << " func_retval0";
371 } else if (ShouldPassAsArray(Ty)) {
372 unsigned totalsz = DL.getTypeAllocSize(Ty);
373 unsigned retAlignment = 0;
374 if (!getAlign(*F, index: 0, retAlignment))
375 retAlignment = TLI->getFunctionParamOptimizedAlign(F, ArgTy: Ty, DL).value();
376 O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
377 << "]";
378 } else
379 llvm_unreachable("Unknown return type");
380 } else {
381 SmallVector<EVT, 16> vtparts;
382 ComputeValueVTs(TLI: *TLI, DL, Ty, ValueVTs&: vtparts);
383 unsigned idx = 0;
384 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
385 unsigned elems = 1;
386 EVT elemtype = vtparts[i];
387 if (vtparts[i].isVector()) {
388 elems = vtparts[i].getVectorNumElements();
389 elemtype = vtparts[i].getVectorElementType();
390 }
391
392 for (unsigned j = 0, je = elems; j != je; ++j) {
393 unsigned sz = elemtype.getSizeInBits();
394 if (elemtype.isInteger())
395 sz = promoteScalarArgumentSize(size: sz);
396 O << ".reg .b" << sz << " func_retval" << idx;
397 if (j < je - 1)
398 O << ", ";
399 ++idx;
400 }
401 if (i < e - 1)
402 O << ", ";
403 }
404 }
405 O << ") ";
406}
407
408void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
409 raw_ostream &O) {
410 const Function &F = MF.getFunction();
411 printReturnValStr(F: &F, O);
412}
413
414// Return true if MBB is the header of a loop marked with
415// llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
416bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
417 const MachineBasicBlock &MBB) const {
418 MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
419 // We insert .pragma "nounroll" only to the loop header.
420 if (!LI.isLoopHeader(BB: &MBB))
421 return false;
422
423 // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
424 // we iterate through each back edge of the loop with header MBB, and check
425 // whether its metadata contains llvm.loop.unroll.disable.
426 for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
427 if (LI.getLoopFor(BB: PMBB) != LI.getLoopFor(BB: &MBB)) {
428 // Edges from other loops to MBB are not back edges.
429 continue;
430 }
431 if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
432 if (MDNode *LoopID =
433 PBB->getTerminator()->getMetadata(KindID: LLVMContext::MD_loop)) {
434 if (GetUnrollMetadata(LoopID, Name: "llvm.loop.unroll.disable"))
435 return true;
436 if (MDNode *UnrollCountMD =
437 GetUnrollMetadata(LoopID, Name: "llvm.loop.unroll.count")) {
438 if (mdconst::extract<ConstantInt>(MD: UnrollCountMD->getOperand(I: 1))
439 ->isOne())
440 return true;
441 }
442 }
443 }
444 }
445 return false;
446}
447
448void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
449 AsmPrinter::emitBasicBlockStart(MBB);
450 if (isLoopHeaderOfNoUnroll(MBB))
451 OutStreamer->emitRawText(String: StringRef("\t.pragma \"nounroll\";\n"));
452}
453
454void NVPTXAsmPrinter::emitFunctionEntryLabel() {
455 SmallString<128> Str;
456 raw_svector_ostream O(Str);
457
458 if (!GlobalsEmitted) {
459 emitGlobals(M: *MF->getFunction().getParent());
460 GlobalsEmitted = true;
461 }
462
463 // Set up
464 MRI = &MF->getRegInfo();
465 F = &MF->getFunction();
466 emitLinkageDirective(V: F, O);
467 if (isKernelFunction(*F))
468 O << ".entry ";
469 else {
470 O << ".func ";
471 printReturnValStr(MF: *MF, O);
472 }
473
474 CurrentFnSym->print(OS&: O, MAI);
475
476 emitFunctionParamList(F, O);
477 O << "\n";
478
479 if (isKernelFunction(*F))
480 emitKernelFunctionDirectives(F: *F, O);
481
482 if (shouldEmitPTXNoReturn(V: F, TM))
483 O << ".noreturn";
484
485 OutStreamer->emitRawText(String: O.str());
486
487 VRegMapping.clear();
488 // Emit open brace for function body.
489 OutStreamer->emitRawText(String: StringRef("{\n"));
490 setAndEmitFunctionVirtualRegisters(*MF);
491 // Emit initial .loc debug directive for correct relocation symbol data.
492 if (const DISubprogram *SP = MF->getFunction().getSubprogram()) {
493 assert(SP->getUnit());
494 if (!SP->getUnit()->isDebugDirectivesOnly() && MMI && MMI->hasDebugInfo())
495 emitInitialRawDwarfLocDirective(MF: *MF);
496 }
497}
498
499bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
500 bool Result = AsmPrinter::runOnMachineFunction(MF&: F);
501 // Emit closing brace for the body of function F.
502 // The closing brace must be emitted here because we need to emit additional
503 // debug labels/data after the last basic block.
504 // We need to emit the closing brace here because we don't have function that
505 // finished emission of the function body.
506 OutStreamer->emitRawText(String: StringRef("}\n"));
507 return Result;
508}
509
510void NVPTXAsmPrinter::emitFunctionBodyStart() {
511 SmallString<128> Str;
512 raw_svector_ostream O(Str);
513 emitDemotedVars(&MF->getFunction(), O);
514 OutStreamer->emitRawText(String: O.str());
515}
516
517void NVPTXAsmPrinter::emitFunctionBodyEnd() {
518 VRegMapping.clear();
519}
520
521const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
522 SmallString<128> Str;
523 raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
524 return OutContext.getOrCreateSymbol(Name: Str);
525}
526
527void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
528 Register RegNo = MI->getOperand(i: 0).getReg();
529 if (RegNo.isVirtual()) {
530 OutStreamer->AddComment(T: Twine("implicit-def: ") +
531 getVirtualRegisterName(RegNo));
532 } else {
533 const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
534 OutStreamer->AddComment(T: Twine("implicit-def: ") +
535 STI.getRegisterInfo()->getName(RegNo));
536 }
537 OutStreamer->addBlankLine();
538}
539
540void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
541 raw_ostream &O) const {
542 // If the NVVM IR has some of reqntid* specified, then output
543 // the reqntid directive, and set the unspecified ones to 1.
544 // If none of Reqntid* is specified, don't output reqntid directive.
545 unsigned Reqntidx, Reqntidy, Reqntidz;
546 Reqntidx = Reqntidy = Reqntidz = 1;
547 bool ReqSpecified = false;
548 ReqSpecified |= getReqNTIDx(F, Reqntidx);
549 ReqSpecified |= getReqNTIDy(F, Reqntidy);
550 ReqSpecified |= getReqNTIDz(F, Reqntidz);
551
552 if (ReqSpecified)
553 O << ".reqntid " << Reqntidx << ", " << Reqntidy << ", " << Reqntidz
554 << "\n";
555
556 // If the NVVM IR has some of maxntid* specified, then output
557 // the maxntid directive, and set the unspecified ones to 1.
558 // If none of maxntid* is specified, don't output maxntid directive.
559 unsigned Maxntidx, Maxntidy, Maxntidz;
560 Maxntidx = Maxntidy = Maxntidz = 1;
561 bool MaxSpecified = false;
562 MaxSpecified |= getMaxNTIDx(F, Maxntidx);
563 MaxSpecified |= getMaxNTIDy(F, Maxntidy);
564 MaxSpecified |= getMaxNTIDz(F, Maxntidz);
565
566 if (MaxSpecified)
567 O << ".maxntid " << Maxntidx << ", " << Maxntidy << ", " << Maxntidz
568 << "\n";
569
570 unsigned Mincta = 0;
571 if (getMinCTASm(F, Mincta))
572 O << ".minnctapersm " << Mincta << "\n";
573
574 unsigned Maxnreg = 0;
575 if (getMaxNReg(F, Maxnreg))
576 O << ".maxnreg " << Maxnreg << "\n";
577
578 // .maxclusterrank directive requires SM_90 or higher, make sure that we
579 // filter it out for lower SM versions, as it causes a hard ptxas crash.
580 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
581 const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
582 unsigned Maxclusterrank = 0;
583 if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
584 O << ".maxclusterrank " << Maxclusterrank << "\n";
585}
586
587std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
588 const TargetRegisterClass *RC = MRI->getRegClass(Reg);
589
590 std::string Name;
591 raw_string_ostream NameStr(Name);
592
593 VRegRCMap::const_iterator I = VRegMapping.find(Val: RC);
594 assert(I != VRegMapping.end() && "Bad register class");
595 const DenseMap<unsigned, unsigned> &RegMap = I->second;
596
597 VRegMap::const_iterator VI = RegMap.find(Val: Reg);
598 assert(VI != RegMap.end() && "Bad virtual register");
599 unsigned MappedVR = VI->second;
600
601 NameStr << getNVPTXRegClassStr(RC) << MappedVR;
602
603 NameStr.flush();
604 return Name;
605}
606
607void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
608 raw_ostream &O) {
609 O << getVirtualRegisterName(Reg: vr);
610}
611
612void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,
613 raw_ostream &O) {
614 const Function *F = dyn_cast_or_null<Function>(Val: GA->getAliaseeObject());
615 if (!F || isKernelFunction(*F) || F->isDeclaration())
616 report_fatal_error(
617 reason: "NVPTX aliasee must be a non-kernel function definition");
618
619 if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||
620 GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())
621 report_fatal_error(reason: "NVPTX aliasee must not be '.weak'");
622
623 emitDeclarationWithName(F, getSymbol(GV: GA), O);
624}
625
626void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
627 emitDeclarationWithName(F, getSymbol(GV: F), O);
628}
629
630void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,
631 raw_ostream &O) {
632 emitLinkageDirective(V: F, O);
633 if (isKernelFunction(*F))
634 O << ".entry ";
635 else
636 O << ".func ";
637 printReturnValStr(F, O);
638 S->print(OS&: O, MAI);
639 O << "\n";
640 emitFunctionParamList(F, O);
641 O << "\n";
642 if (shouldEmitPTXNoReturn(V: F, TM))
643 O << ".noreturn";
644 O << ";\n";
645}
646
647static bool usedInGlobalVarDef(const Constant *C) {
648 if (!C)
649 return false;
650
651 if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Val: C)) {
652 return GV->getName() != "llvm.used";
653 }
654
655 for (const User *U : C->users())
656 if (const Constant *C = dyn_cast<Constant>(Val: U))
657 if (usedInGlobalVarDef(C))
658 return true;
659
660 return false;
661}
662
663static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
664 if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(Val: U)) {
665 if (othergv->getName() == "llvm.used")
666 return true;
667 }
668
669 if (const Instruction *instr = dyn_cast<Instruction>(Val: U)) {
670 if (instr->getParent() && instr->getParent()->getParent()) {
671 const Function *curFunc = instr->getParent()->getParent();
672 if (oneFunc && (curFunc != oneFunc))
673 return false;
674 oneFunc = curFunc;
675 return true;
676 } else
677 return false;
678 }
679
680 for (const User *UU : U->users())
681 if (!usedInOneFunc(U: UU, oneFunc))
682 return false;
683
684 return true;
685}
686
687/* Find out if a global variable can be demoted to local scope.
688 * Currently, this is valid for CUDA shared variables, which have local
689 * scope and global lifetime. So the conditions to check are :
690 * 1. Is the global variable in shared address space?
691 * 2. Does it have local linkage?
692 * 3. Is the global variable referenced only in one function?
693 */
694static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
695 if (!gv->hasLocalLinkage())
696 return false;
697 PointerType *Pty = gv->getType();
698 if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
699 return false;
700
701 const Function *oneFunc = nullptr;
702
703 bool flag = usedInOneFunc(U: gv, oneFunc);
704 if (!flag)
705 return false;
706 if (!oneFunc)
707 return false;
708 f = oneFunc;
709 return true;
710}
711
712static bool useFuncSeen(const Constant *C,
713 DenseMap<const Function *, bool> &seenMap) {
714 for (const User *U : C->users()) {
715 if (const Constant *cu = dyn_cast<Constant>(Val: U)) {
716 if (useFuncSeen(C: cu, seenMap))
717 return true;
718 } else if (const Instruction *I = dyn_cast<Instruction>(Val: U)) {
719 const BasicBlock *bb = I->getParent();
720 if (!bb)
721 continue;
722 const Function *caller = bb->getParent();
723 if (!caller)
724 continue;
725 if (seenMap.contains(Val: caller))
726 return true;
727 }
728 }
729 return false;
730}
731
732void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
733 DenseMap<const Function *, bool> seenMap;
734 for (const Function &F : M) {
735 if (F.getAttributes().hasFnAttr(Kind: "nvptx-libcall-callee")) {
736 emitDeclaration(F: &F, O);
737 continue;
738 }
739
740 if (F.isDeclaration()) {
741 if (F.use_empty())
742 continue;
743 if (F.getIntrinsicID())
744 continue;
745 emitDeclaration(F: &F, O);
746 continue;
747 }
748 for (const User *U : F.users()) {
749 if (const Constant *C = dyn_cast<Constant>(Val: U)) {
750 if (usedInGlobalVarDef(C)) {
751 // The use is in the initialization of a global variable
752 // that is a function pointer, so print a declaration
753 // for the original function
754 emitDeclaration(F: &F, O);
755 break;
756 }
757 // Emit a declaration of this function if the function that
758 // uses this constant expr has already been seen.
759 if (useFuncSeen(C, seenMap)) {
760 emitDeclaration(F: &F, O);
761 break;
762 }
763 }
764
765 if (!isa<Instruction>(Val: U))
766 continue;
767 const Instruction *instr = cast<Instruction>(Val: U);
768 const BasicBlock *bb = instr->getParent();
769 if (!bb)
770 continue;
771 const Function *caller = bb->getParent();
772 if (!caller)
773 continue;
774
775 // If a caller has already been seen, then the caller is
776 // appearing in the module before the callee. so print out
777 // a declaration for the callee.
778 if (seenMap.contains(Val: caller)) {
779 emitDeclaration(F: &F, O);
780 break;
781 }
782 }
783 seenMap[&F] = true;
784 }
785 for (const GlobalAlias &GA : M.aliases())
786 emitAliasDeclaration(GA: &GA, O);
787}
788
789static bool isEmptyXXStructor(GlobalVariable *GV) {
790 if (!GV) return true;
791 const ConstantArray *InitList = dyn_cast<ConstantArray>(Val: GV->getInitializer());
792 if (!InitList) return true; // Not an array; we don't know how to parse.
793 return InitList->getNumOperands() == 0;
794}
795
796void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
797 // Construct a default subtarget off of the TargetMachine defaults. The
798 // rest of NVPTX isn't friendly to change subtargets per function and
799 // so the default TargetMachine will have all of the options.
800 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
801 const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
802 SmallString<128> Str1;
803 raw_svector_ostream OS1(Str1);
804
805 // Emit header before any dwarf directives are emitted below.
806 emitHeader(M, O&: OS1, STI: *STI);
807 OutStreamer->emitRawText(String: OS1.str());
808}
809
810bool NVPTXAsmPrinter::doInitialization(Module &M) {
811 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
812 const NVPTXSubtarget &STI =
813 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
814 if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
815 report_fatal_error(reason: ".alias requires PTX version >= 6.3 and sm_30");
816
817 // OpenMP supports NVPTX global constructors and destructors.
818 bool IsOpenMP = M.getModuleFlag(Key: "openmp") != nullptr;
819
820 if (!isEmptyXXStructor(GV: M.getNamedGlobal(Name: "llvm.global_ctors")) &&
821 !LowerCtorDtor && !IsOpenMP) {
822 report_fatal_error(
823 reason: "Module has a nontrivial global ctor, which NVPTX does not support.");
824 return true; // error
825 }
826 if (!isEmptyXXStructor(GV: M.getNamedGlobal(Name: "llvm.global_dtors")) &&
827 !LowerCtorDtor && !IsOpenMP) {
828 report_fatal_error(
829 reason: "Module has a nontrivial global dtor, which NVPTX does not support.");
830 return true; // error
831 }
832
833 // We need to call the parent's one explicitly.
834 bool Result = AsmPrinter::doInitialization(M);
835
836 GlobalsEmitted = false;
837
838 return Result;
839}
840
841void NVPTXAsmPrinter::emitGlobals(const Module &M) {
842 SmallString<128> Str2;
843 raw_svector_ostream OS2(Str2);
844
845 emitDeclarations(M, O&: OS2);
846
847 // As ptxas does not support forward references of globals, we need to first
848 // sort the list of module-level globals in def-use order. We visit each
849 // global variable in order, and ensure that we emit it *after* its dependent
850 // globals. We use a little extra memory maintaining both a set and a list to
851 // have fast searches while maintaining a strict ordering.
852 SmallVector<const GlobalVariable *, 8> Globals;
853 DenseSet<const GlobalVariable *> GVVisited;
854 DenseSet<const GlobalVariable *> GVVisiting;
855
856 // Visit each global variable, in order
857 for (const GlobalVariable &I : M.globals())
858 VisitGlobalVariableForEmission(GV: &I, Order&: Globals, Visited&: GVVisited, Visiting&: GVVisiting);
859
860 assert(GVVisited.size() == M.global_size() && "Missed a global variable");
861 assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
862
863 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
864 const NVPTXSubtarget &STI =
865 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
866
867 // Print out module-level global variables in proper order
868 for (unsigned i = 0, e = Globals.size(); i != e; ++i)
869 printModuleLevelGV(GVar: Globals[i], O&: OS2, /*processDemoted=*/false, STI);
870
871 OS2 << '\n';
872
873 OutStreamer->emitRawText(String: OS2.str());
874}
875
876void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
877 SmallString<128> Str;
878 raw_svector_ostream OS(Str);
879
880 MCSymbol *Name = getSymbol(GV: &GA);
881
882 OS << ".alias " << Name->getName() << ", " << GA.getAliaseeObject()->getName()
883 << ";\n";
884
885 OutStreamer->emitRawText(String: OS.str());
886}
887
888void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
889 const NVPTXSubtarget &STI) {
890 O << "//\n";
891 O << "// Generated by LLVM NVPTX Back-End\n";
892 O << "//\n";
893 O << "\n";
894
895 unsigned PTXVersion = STI.getPTXVersion();
896 O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
897
898 O << ".target ";
899 O << STI.getTargetName();
900
901 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
902 if (NTM.getDrvInterface() == NVPTX::NVCL)
903 O << ", texmode_independent";
904
905 bool HasFullDebugInfo = false;
906 for (DICompileUnit *CU : M.debug_compile_units()) {
907 switch(CU->getEmissionKind()) {
908 case DICompileUnit::NoDebug:
909 case DICompileUnit::DebugDirectivesOnly:
910 break;
911 case DICompileUnit::LineTablesOnly:
912 case DICompileUnit::FullDebug:
913 HasFullDebugInfo = true;
914 break;
915 }
916 if (HasFullDebugInfo)
917 break;
918 }
919 if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
920 O << ", debug";
921
922 O << "\n";
923
924 O << ".address_size ";
925 if (NTM.is64Bit())
926 O << "64";
927 else
928 O << "32";
929 O << "\n";
930
931 O << "\n";
932}
933
934bool NVPTXAsmPrinter::doFinalization(Module &M) {
935 bool HasDebugInfo = MMI && MMI->hasDebugInfo();
936
937 // If we did not emit any functions, then the global declarations have not
938 // yet been emitted.
939 if (!GlobalsEmitted) {
940 emitGlobals(M);
941 GlobalsEmitted = true;
942 }
943
944 // call doFinalization
945 bool ret = AsmPrinter::doFinalization(M);
946
947 clearAnnotationCache(&M);
948
949 auto *TS =
950 static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
951 // Close the last emitted section
952 if (HasDebugInfo) {
953 TS->closeLastSection();
954 // Emit empty .debug_loc section for better support of the empty files.
955 OutStreamer->emitRawText(String: "\t.section\t.debug_loc\t{\t}");
956 }
957
958 // Output last DWARF .file directives, if any.
959 TS->outputDwarfFileDirectives();
960
961 return ret;
962}
963
964// This function emits appropriate linkage directives for
965// functions and global variables.
966//
967// extern function declaration -> .extern
968// extern function definition -> .visible
969// external global variable with init -> .visible
970// external without init -> .extern
971// appending -> not allowed, assert.
972// for any linkage other than
973// internal, private, linker_private,
974// linker_private_weak, linker_private_weak_def_auto,
975// we emit -> .weak.
976
977void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
978 raw_ostream &O) {
979 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
980 if (V->hasExternalLinkage()) {
981 if (isa<GlobalVariable>(Val: V)) {
982 const GlobalVariable *GVar = cast<GlobalVariable>(Val: V);
983 if (GVar) {
984 if (GVar->hasInitializer())
985 O << ".visible ";
986 else
987 O << ".extern ";
988 }
989 } else if (V->isDeclaration())
990 O << ".extern ";
991 else
992 O << ".visible ";
993 } else if (V->hasAppendingLinkage()) {
994 std::string msg;
995 msg.append(s: "Error: ");
996 msg.append(s: "Symbol ");
997 if (V->hasName())
998 msg.append(str: std::string(V->getName()));
999 msg.append(s: "has unsupported appending linkage type");
1000 llvm_unreachable(msg.c_str());
1001 } else if (!V->hasInternalLinkage() &&
1002 !V->hasPrivateLinkage()) {
1003 O << ".weak ";
1004 }
1005 }
1006}
1007
1008void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
1009 raw_ostream &O, bool processDemoted,
1010 const NVPTXSubtarget &STI) {
1011 // Skip meta data
1012 if (GVar->hasSection()) {
1013 if (GVar->getSection() == "llvm.metadata")
1014 return;
1015 }
1016
1017 // Skip LLVM intrinsic global variables
1018 if (GVar->getName().starts_with(Prefix: "llvm.") ||
1019 GVar->getName().starts_with(Prefix: "nvvm."))
1020 return;
1021
1022 const DataLayout &DL = getDataLayout();
1023
1024 // GlobalVariables are always constant pointers themselves.
1025 Type *ETy = GVar->getValueType();
1026
1027 if (GVar->hasExternalLinkage()) {
1028 if (GVar->hasInitializer())
1029 O << ".visible ";
1030 else
1031 O << ".extern ";
1032 } else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&
1033 GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {
1034 O << ".common ";
1035 } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
1036 GVar->hasAvailableExternallyLinkage() ||
1037 GVar->hasCommonLinkage()) {
1038 O << ".weak ";
1039 }
1040
1041 if (isTexture(*GVar)) {
1042 O << ".global .texref " << getTextureName(*GVar) << ";\n";
1043 return;
1044 }
1045
1046 if (isSurface(*GVar)) {
1047 O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
1048 return;
1049 }
1050
1051 if (GVar->isDeclaration()) {
1052 // (extern) declarations, no definition or initializer
1053 // Currently the only known declaration is for an automatic __local
1054 // (.shared) promoted to global.
1055 emitPTXGlobalVariable(GVar, O, STI);
1056 O << ";\n";
1057 return;
1058 }
1059
1060 if (isSampler(*GVar)) {
1061 O << ".global .samplerref " << getSamplerName(*GVar);
1062
1063 const Constant *Initializer = nullptr;
1064 if (GVar->hasInitializer())
1065 Initializer = GVar->getInitializer();
1066 const ConstantInt *CI = nullptr;
1067 if (Initializer)
1068 CI = dyn_cast<ConstantInt>(Val: Initializer);
1069 if (CI) {
1070 unsigned sample = CI->getZExtValue();
1071
1072 O << " = { ";
1073
1074 for (int i = 0,
1075 addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1076 i < 3; i++) {
1077 O << "addr_mode_" << i << " = ";
1078 switch (addr) {
1079 case 0:
1080 O << "wrap";
1081 break;
1082 case 1:
1083 O << "clamp_to_border";
1084 break;
1085 case 2:
1086 O << "clamp_to_edge";
1087 break;
1088 case 3:
1089 O << "wrap";
1090 break;
1091 case 4:
1092 O << "mirror";
1093 break;
1094 }
1095 O << ", ";
1096 }
1097 O << "filter_mode = ";
1098 switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1099 case 0:
1100 O << "nearest";
1101 break;
1102 case 1:
1103 O << "linear";
1104 break;
1105 case 2:
1106 llvm_unreachable("Anisotropic filtering is not supported");
1107 default:
1108 O << "nearest";
1109 break;
1110 }
1111 if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1112 O << ", force_unnormalized_coords = 1";
1113 }
1114 O << " }";
1115 }
1116
1117 O << ";\n";
1118 return;
1119 }
1120
1121 if (GVar->hasPrivateLinkage()) {
1122 if (strncmp(s1: GVar->getName().data(), s2: "unrollpragma", n: 12) == 0)
1123 return;
1124
1125 // FIXME - need better way (e.g. Metadata) to avoid generating this global
1126 if (strncmp(s1: GVar->getName().data(), s2: "filename", n: 8) == 0)
1127 return;
1128 if (GVar->use_empty())
1129 return;
1130 }
1131
1132 const Function *demotedFunc = nullptr;
1133 if (!processDemoted && canDemoteGlobalVar(gv: GVar, f&: demotedFunc)) {
1134 O << "// " << GVar->getName() << " has been demoted\n";
1135 if (localDecls.find(x: demotedFunc) != localDecls.end())
1136 localDecls[demotedFunc].push_back(x: GVar);
1137 else {
1138 std::vector<const GlobalVariable *> temp;
1139 temp.push_back(x: GVar);
1140 localDecls[demotedFunc] = temp;
1141 }
1142 return;
1143 }
1144
1145 O << ".";
1146 emitPTXAddressSpace(AddressSpace: GVar->getAddressSpace(), O);
1147
1148 if (isManaged(*GVar)) {
1149 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1150 report_fatal_error(
1151 reason: ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1152 }
1153 O << " .attribute(.managed)";
1154 }
1155
1156 if (MaybeAlign A = GVar->getAlign())
1157 O << " .align " << A->value();
1158 else
1159 O << " .align " << (int)DL.getPrefTypeAlign(Ty: ETy).value();
1160
1161 if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1162 (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1163 O << " .";
1164 // Special case: ABI requires that we use .u8 for predicates
1165 if (ETy->isIntegerTy(Bitwidth: 1))
1166 O << "u8";
1167 else
1168 O << getPTXFundamentalTypeStr(Ty: ETy, false);
1169 O << " ";
1170 getSymbol(GV: GVar)->print(OS&: O, MAI);
1171
1172 // Ptx allows variable initilization only for constant and global state
1173 // spaces.
1174 if (GVar->hasInitializer()) {
1175 if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1176 (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1177 const Constant *Initializer = GVar->getInitializer();
1178 // 'undef' is treated as there is no value specified.
1179 if (!Initializer->isNullValue() && !isa<UndefValue>(Val: Initializer)) {
1180 O << " = ";
1181 printScalarConstant(CPV: Initializer, O);
1182 }
1183 } else {
1184 // The frontend adds zero-initializer to device and constant variables
1185 // that don't have an initial value, and UndefValue to shared
1186 // variables, so skip warning for this case.
1187 if (!GVar->getInitializer()->isNullValue() &&
1188 !isa<UndefValue>(Val: GVar->getInitializer())) {
1189 report_fatal_error(reason: "initial value of '" + GVar->getName() +
1190 "' is not allowed in addrspace(" +
1191 Twine(GVar->getAddressSpace()) + ")");
1192 }
1193 }
1194 }
1195 } else {
1196 uint64_t ElementSize = 0;
1197
1198 // Although PTX has direct support for struct type and array type and
1199 // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1200 // targets that support these high level field accesses. Structs, arrays
1201 // and vectors are lowered into arrays of bytes.
1202 switch (ETy->getTypeID()) {
1203 case Type::IntegerTyID: // Integers larger than 64 bits
1204 case Type::StructTyID:
1205 case Type::ArrayTyID:
1206 case Type::FixedVectorTyID:
1207 ElementSize = DL.getTypeStoreSize(Ty: ETy);
1208 // Ptx allows variable initilization only for constant and
1209 // global state spaces.
1210 if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1211 (GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1212 GVar->hasInitializer()) {
1213 const Constant *Initializer = GVar->getInitializer();
1214 if (!isa<UndefValue>(Val: Initializer) && !Initializer->isNullValue()) {
1215 AggBuffer aggBuffer(ElementSize, *this);
1216 bufferAggregateConstant(CV: Initializer, aggBuffer: &aggBuffer);
1217 if (aggBuffer.numSymbols()) {
1218 unsigned int ptrSize = MAI->getCodePointerSize();
1219 if (ElementSize % ptrSize ||
1220 !aggBuffer.allSymbolsAligned(ptrSize)) {
1221 // Print in bytes and use the mask() operator for pointers.
1222 if (!STI.hasMaskOperator())
1223 report_fatal_error(
1224 reason: "initialized packed aggregate with pointers '" +
1225 GVar->getName() +
1226 "' requires at least PTX ISA version 7.1");
1227 O << " .u8 ";
1228 getSymbol(GV: GVar)->print(OS&: O, MAI);
1229 O << "[" << ElementSize << "] = {";
1230 aggBuffer.printBytes(os&: O);
1231 O << "}";
1232 } else {
1233 O << " .u" << ptrSize * 8 << " ";
1234 getSymbol(GV: GVar)->print(OS&: O, MAI);
1235 O << "[" << ElementSize / ptrSize << "] = {";
1236 aggBuffer.printWords(os&: O);
1237 O << "}";
1238 }
1239 } else {
1240 O << " .b8 ";
1241 getSymbol(GV: GVar)->print(OS&: O, MAI);
1242 O << "[" << ElementSize << "] = {";
1243 aggBuffer.printBytes(os&: O);
1244 O << "}";
1245 }
1246 } else {
1247 O << " .b8 ";
1248 getSymbol(GV: GVar)->print(OS&: O, MAI);
1249 if (ElementSize) {
1250 O << "[";
1251 O << ElementSize;
1252 O << "]";
1253 }
1254 }
1255 } else {
1256 O << " .b8 ";
1257 getSymbol(GV: GVar)->print(OS&: O, MAI);
1258 if (ElementSize) {
1259 O << "[";
1260 O << ElementSize;
1261 O << "]";
1262 }
1263 }
1264 break;
1265 default:
1266 llvm_unreachable("type not supported yet");
1267 }
1268 }
1269 O << ";\n";
1270}
1271
1272void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1273 const Value *v = Symbols[nSym];
1274 const Value *v0 = SymbolsBeforeStripping[nSym];
1275 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: v)) {
1276 MCSymbol *Name = AP.getSymbol(GV: GVar);
1277 PointerType *PTy = dyn_cast<PointerType>(Val: v0->getType());
1278 // Is v0 a generic pointer?
1279 bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1280 if (EmitGeneric && isGenericPointer && !isa<Function>(Val: v)) {
1281 os << "generic(";
1282 Name->print(OS&: os, MAI: AP.MAI);
1283 os << ")";
1284 } else {
1285 Name->print(OS&: os, MAI: AP.MAI);
1286 }
1287 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(Val: v0)) {
1288 const MCExpr *Expr = AP.lowerConstantForGV(CV: cast<Constant>(Val: CExpr), ProcessingGeneric: false);
1289 AP.printMCExpr(Expr: *Expr, OS&: os);
1290 } else
1291 llvm_unreachable("symbol type unknown");
1292}
1293
1294void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
1295 unsigned int ptrSize = AP.MAI->getCodePointerSize();
1296 // Do not emit trailing zero initializers. They will be zero-initialized by
1297 // ptxas. This saves on both space requirements for the generated PTX and on
1298 // memory use by ptxas. (See:
1299 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
1300 unsigned int InitializerCount = size;
1301 // TODO: symbols make this harder, but it would still be good to trim trailing
1302 // 0s for aggs with symbols as well.
1303 if (numSymbols() == 0)
1304 while (InitializerCount >= 1 && !buffer[InitializerCount - 1])
1305 InitializerCount--;
1306
1307 symbolPosInBuffer.push_back(Elt: InitializerCount);
1308 unsigned int nSym = 0;
1309 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1310 for (unsigned int pos = 0; pos < InitializerCount;) {
1311 if (pos)
1312 os << ", ";
1313 if (pos != nextSymbolPos) {
1314 os << (unsigned int)buffer[pos];
1315 ++pos;
1316 continue;
1317 }
1318 // Generate a per-byte mask() operator for the symbol, which looks like:
1319 // .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
1320 // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
1321 std::string symText;
1322 llvm::raw_string_ostream oss(symText);
1323 printSymbol(nSym, os&: oss);
1324 for (unsigned i = 0; i < ptrSize; ++i) {
1325 if (i)
1326 os << ", ";
1327 llvm::write_hex(S&: os, N: 0xFFULL << i * 8, Style: HexPrintStyle::PrefixUpper);
1328 os << "(" << symText << ")";
1329 }
1330 pos += ptrSize;
1331 nextSymbolPos = symbolPosInBuffer[++nSym];
1332 assert(nextSymbolPos >= pos);
1333 }
1334}
1335
1336void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
1337 unsigned int ptrSize = AP.MAI->getCodePointerSize();
1338 symbolPosInBuffer.push_back(Elt: size);
1339 unsigned int nSym = 0;
1340 unsigned int nextSymbolPos = symbolPosInBuffer[nSym];
1341 assert(nextSymbolPos % ptrSize == 0);
1342 for (unsigned int pos = 0; pos < size; pos += ptrSize) {
1343 if (pos)
1344 os << ", ";
1345 if (pos == nextSymbolPos) {
1346 printSymbol(nSym, os);
1347 nextSymbolPos = symbolPosInBuffer[++nSym];
1348 assert(nextSymbolPos % ptrSize == 0);
1349 assert(nextSymbolPos >= pos + ptrSize);
1350 } else if (ptrSize == 4)
1351 os << support::endian::read32le(P: &buffer[pos]);
1352 else
1353 os << support::endian::read64le(P: &buffer[pos]);
1354 }
1355}
1356
1357void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1358 if (localDecls.find(x: f) == localDecls.end())
1359 return;
1360
1361 std::vector<const GlobalVariable *> &gvars = localDecls[f];
1362
1363 const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1364 const NVPTXSubtarget &STI =
1365 *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1366
1367 for (const GlobalVariable *GV : gvars) {
1368 O << "\t// demoted variable\n\t";
1369 printModuleLevelGV(GVar: GV, O, /*processDemoted=*/true, STI);
1370 }
1371}
1372
1373void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1374 raw_ostream &O) const {
1375 switch (AddressSpace) {
1376 case ADDRESS_SPACE_LOCAL:
1377 O << "local";
1378 break;
1379 case ADDRESS_SPACE_GLOBAL:
1380 O << "global";
1381 break;
1382 case ADDRESS_SPACE_CONST:
1383 O << "const";
1384 break;
1385 case ADDRESS_SPACE_SHARED:
1386 O << "shared";
1387 break;
1388 default:
1389 report_fatal_error(reason: "Bad address space found while emitting PTX: " +
1390 llvm::Twine(AddressSpace));
1391 break;
1392 }
1393}
1394
1395std::string
1396NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1397 switch (Ty->getTypeID()) {
1398 case Type::IntegerTyID: {
1399 unsigned NumBits = cast<IntegerType>(Val: Ty)->getBitWidth();
1400 if (NumBits == 1)
1401 return "pred";
1402 else if (NumBits <= 64) {
1403 std::string name = "u";
1404 return name + utostr(X: NumBits);
1405 } else {
1406 llvm_unreachable("Integer too large");
1407 break;
1408 }
1409 break;
1410 }
1411 case Type::BFloatTyID:
1412 case Type::HalfTyID:
1413 // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1414 // PTX assembly.
1415 return "b16";
1416 case Type::FloatTyID:
1417 return "f32";
1418 case Type::DoubleTyID:
1419 return "f64";
1420 case Type::PointerTyID: {
1421 unsigned PtrSize = TM.getPointerSizeInBits(AS: Ty->getPointerAddressSpace());
1422 assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1423
1424 if (PtrSize == 64)
1425 if (useB4PTR)
1426 return "b64";
1427 else
1428 return "u64";
1429 else if (useB4PTR)
1430 return "b32";
1431 else
1432 return "u32";
1433 }
1434 default:
1435 break;
1436 }
1437 llvm_unreachable("unexpected type");
1438}
1439
1440void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1441 raw_ostream &O,
1442 const NVPTXSubtarget &STI) {
1443 const DataLayout &DL = getDataLayout();
1444
1445 // GlobalVariables are always constant pointers themselves.
1446 Type *ETy = GVar->getValueType();
1447
1448 O << ".";
1449 emitPTXAddressSpace(AddressSpace: GVar->getType()->getAddressSpace(), O);
1450 if (isManaged(*GVar)) {
1451 if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1452 report_fatal_error(
1453 reason: ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1454 }
1455 O << " .attribute(.managed)";
1456 }
1457 if (MaybeAlign A = GVar->getAlign())
1458 O << " .align " << A->value();
1459 else
1460 O << " .align " << (int)DL.getPrefTypeAlign(Ty: ETy).value();
1461
1462 // Special case for i128
1463 if (ETy->isIntegerTy(Bitwidth: 128)) {
1464 O << " .b8 ";
1465 getSymbol(GV: GVar)->print(OS&: O, MAI);
1466 O << "[16]";
1467 return;
1468 }
1469
1470 if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1471 O << " .";
1472 O << getPTXFundamentalTypeStr(Ty: ETy);
1473 O << " ";
1474 getSymbol(GV: GVar)->print(OS&: O, MAI);
1475 return;
1476 }
1477
1478 int64_t ElementSize = 0;
1479
1480 // Although PTX has direct support for struct type and array type and LLVM IR
1481 // is very similar to PTX, the LLVM CodeGen does not support for targets that
1482 // support these high level field accesses. Structs and arrays are lowered
1483 // into arrays of bytes.
1484 switch (ETy->getTypeID()) {
1485 case Type::StructTyID:
1486 case Type::ArrayTyID:
1487 case Type::FixedVectorTyID:
1488 ElementSize = DL.getTypeStoreSize(Ty: ETy);
1489 O << " .b8 ";
1490 getSymbol(GV: GVar)->print(OS&: O, MAI);
1491 O << "[";
1492 if (ElementSize) {
1493 O << ElementSize;
1494 }
1495 O << "]";
1496 break;
1497 default:
1498 llvm_unreachable("type not supported yet");
1499 }
1500}
1501
1502void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1503 const DataLayout &DL = getDataLayout();
1504 const AttributeList &PAL = F->getAttributes();
1505 const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(F: *F);
1506 const auto *TLI = cast<NVPTXTargetLowering>(Val: STI.getTargetLowering());
1507
1508 Function::const_arg_iterator I, E;
1509 unsigned paramIndex = 0;
1510 bool first = true;
1511 bool isKernelFunc = isKernelFunction(*F);
1512 bool isABI = (STI.getSmVersion() >= 20);
1513 bool hasImageHandles = STI.hasImageHandles();
1514
1515 if (F->arg_empty() && !F->isVarArg()) {
1516 O << "()";
1517 return;
1518 }
1519
1520 O << "(\n";
1521
1522 for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1523 Type *Ty = I->getType();
1524
1525 if (!first)
1526 O << ",\n";
1527
1528 first = false;
1529
1530 // Handle image/sampler parameters
1531 if (isKernelFunction(*F)) {
1532 if (isSampler(*I) || isImage(*I)) {
1533 if (isImage(*I)) {
1534 if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1535 if (hasImageHandles)
1536 O << "\t.param .u64 .ptr .surfref ";
1537 else
1538 O << "\t.param .surfref ";
1539 O << TLI->getParamName(F, Idx: paramIndex);
1540 }
1541 else { // Default image is read_only
1542 if (hasImageHandles)
1543 O << "\t.param .u64 .ptr .texref ";
1544 else
1545 O << "\t.param .texref ";
1546 O << TLI->getParamName(F, Idx: paramIndex);
1547 }
1548 } else {
1549 if (hasImageHandles)
1550 O << "\t.param .u64 .ptr .samplerref ";
1551 else
1552 O << "\t.param .samplerref ";
1553 O << TLI->getParamName(F, Idx: paramIndex);
1554 }
1555 continue;
1556 }
1557 }
1558
1559 auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
1560 paramIndex](Type *Ty) -> Align {
1561 Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, ArgTy: Ty, DL);
1562 MaybeAlign ParamAlign = PAL.getParamAlignment(ArgNo: paramIndex);
1563 return std::max(a: TypeAlign, b: ParamAlign.valueOrOne());
1564 };
1565
1566 if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
1567 if (ShouldPassAsArray(Ty)) {
1568 // Just print .param .align <a> .b8 .param[size];
1569 // <a> = optimal alignment for the element type; always multiple of
1570 // PAL.getParamAlignment
1571 // size = typeallocsize of element type
1572 Align OptimalAlign = getOptimalAlignForParam(Ty);
1573
1574 O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1575 O << TLI->getParamName(F, Idx: paramIndex);
1576 O << "[" << DL.getTypeAllocSize(Ty) << "]";
1577
1578 continue;
1579 }
1580 // Just a scalar
1581 auto *PTy = dyn_cast<PointerType>(Val: Ty);
1582 unsigned PTySizeInBits = 0;
1583 if (PTy) {
1584 PTySizeInBits =
1585 TLI->getPointerTy(DL, AS: PTy->getAddressSpace()).getSizeInBits();
1586 assert(PTySizeInBits && "Invalid pointer size");
1587 }
1588
1589 if (isKernelFunc) {
1590 if (PTy) {
1591 // Special handling for pointer arguments to kernel
1592 O << "\t.param .u" << PTySizeInBits << " ";
1593
1594 if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1595 NVPTX::CUDA) {
1596 int addrSpace = PTy->getAddressSpace();
1597 switch (addrSpace) {
1598 default:
1599 O << ".ptr ";
1600 break;
1601 case ADDRESS_SPACE_CONST:
1602 O << ".ptr .const ";
1603 break;
1604 case ADDRESS_SPACE_SHARED:
1605 O << ".ptr .shared ";
1606 break;
1607 case ADDRESS_SPACE_GLOBAL:
1608 O << ".ptr .global ";
1609 break;
1610 }
1611 Align ParamAlign = I->getParamAlign().valueOrOne();
1612 O << ".align " << ParamAlign.value() << " ";
1613 }
1614 O << TLI->getParamName(F, Idx: paramIndex);
1615 continue;
1616 }
1617
1618 // non-pointer scalar to kernel func
1619 O << "\t.param .";
1620 // Special case: predicate operands become .u8 types
1621 if (Ty->isIntegerTy(Bitwidth: 1))
1622 O << "u8";
1623 else
1624 O << getPTXFundamentalTypeStr(Ty);
1625 O << " ";
1626 O << TLI->getParamName(F, Idx: paramIndex);
1627 continue;
1628 }
1629 // Non-kernel function, just print .param .b<size> for ABI
1630 // and .reg .b<size> for non-ABI
1631 unsigned sz = 0;
1632 if (isa<IntegerType>(Val: Ty)) {
1633 sz = cast<IntegerType>(Val: Ty)->getBitWidth();
1634 sz = promoteScalarArgumentSize(size: sz);
1635 } else if (PTy) {
1636 assert(PTySizeInBits && "Invalid pointer size");
1637 sz = PTySizeInBits;
1638 } else
1639 sz = Ty->getPrimitiveSizeInBits();
1640 if (isABI)
1641 O << "\t.param .b" << sz << " ";
1642 else
1643 O << "\t.reg .b" << sz << " ";
1644 O << TLI->getParamName(F, Idx: paramIndex);
1645 continue;
1646 }
1647
1648 // param has byVal attribute.
1649 Type *ETy = PAL.getParamByValType(ArgNo: paramIndex);
1650 assert(ETy && "Param should have byval type");
1651
1652 if (isABI || isKernelFunc) {
1653 // Just print .param .align <a> .b8 .param[size];
1654 // <a> = optimal alignment for the element type; always multiple of
1655 // PAL.getParamAlignment
1656 // size = typeallocsize of element type
1657 Align OptimalAlign =
1658 isKernelFunc
1659 ? getOptimalAlignForParam(ETy)
1660 : TLI->getFunctionByValParamAlign(
1661 F, ArgTy: ETy, InitialAlign: PAL.getParamAlignment(ArgNo: paramIndex).valueOrOne(), DL);
1662
1663 unsigned sz = DL.getTypeAllocSize(Ty: ETy);
1664 O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1665 O << TLI->getParamName(F, Idx: paramIndex);
1666 O << "[" << sz << "]";
1667 continue;
1668 } else {
1669 // Split the ETy into constituent parts and
1670 // print .param .b<size> <name> for each part.
1671 // Further, if a part is vector, print the above for
1672 // each vector element.
1673 SmallVector<EVT, 16> vtparts;
1674 ComputeValueVTs(TLI: *TLI, DL, Ty: ETy, ValueVTs&: vtparts);
1675 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1676 unsigned elems = 1;
1677 EVT elemtype = vtparts[i];
1678 if (vtparts[i].isVector()) {
1679 elems = vtparts[i].getVectorNumElements();
1680 elemtype = vtparts[i].getVectorElementType();
1681 }
1682
1683 for (unsigned j = 0, je = elems; j != je; ++j) {
1684 unsigned sz = elemtype.getSizeInBits();
1685 if (elemtype.isInteger())
1686 sz = promoteScalarArgumentSize(size: sz);
1687 O << "\t.reg .b" << sz << " ";
1688 O << TLI->getParamName(F, Idx: paramIndex);
1689 if (j < je - 1)
1690 O << ",\n";
1691 ++paramIndex;
1692 }
1693 if (i < e - 1)
1694 O << ",\n";
1695 }
1696 --paramIndex;
1697 continue;
1698 }
1699 }
1700
1701 if (F->isVarArg()) {
1702 if (!first)
1703 O << ",\n";
1704 O << "\t.param .align " << STI.getMaxRequiredAlignment();
1705 O << " .b8 ";
1706 O << TLI->getParamName(F, /* vararg */ Idx: -1) << "[]";
1707 }
1708
1709 O << "\n)";
1710}
1711
1712void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1713 const MachineFunction &MF) {
1714 SmallString<128> Str;
1715 raw_svector_ostream O(Str);
1716
1717 // Map the global virtual register number to a register class specific
1718 // virtual register number starting from 1 with that class.
1719 const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1720 //unsigned numRegClasses = TRI->getNumRegClasses();
1721
1722 // Emit the Fake Stack Object
1723 const MachineFrameInfo &MFI = MF.getFrameInfo();
1724 int64_t NumBytes = MFI.getStackSize();
1725 if (NumBytes) {
1726 O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1727 << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1728 if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1729 O << "\t.reg .b64 \t%SP;\n";
1730 O << "\t.reg .b64 \t%SPL;\n";
1731 } else {
1732 O << "\t.reg .b32 \t%SP;\n";
1733 O << "\t.reg .b32 \t%SPL;\n";
1734 }
1735 }
1736
1737 // Go through all virtual registers to establish the mapping between the
1738 // global virtual
1739 // register number and the per class virtual register number.
1740 // We use the per class virtual register number in the ptx output.
1741 unsigned int numVRs = MRI->getNumVirtRegs();
1742 for (unsigned i = 0; i < numVRs; i++) {
1743 Register vr = Register::index2VirtReg(Index: i);
1744 const TargetRegisterClass *RC = MRI->getRegClass(Reg: vr);
1745 DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1746 int n = regmap.size();
1747 regmap.insert(KV: std::make_pair(x&: vr, y: n + 1));
1748 }
1749
1750 // Emit register declarations
1751 // @TODO: Extract out the real register usage
1752 // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1753 // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1754 // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1755 // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1756 // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1757 // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1758 // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1759
1760 // Emit declaration of the virtual registers or 'physical' registers for
1761 // each register class
1762 for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1763 const TargetRegisterClass *RC = TRI->getRegClass(i);
1764 DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1765 std::string rcname = getNVPTXRegClassName(RC);
1766 std::string rcStr = getNVPTXRegClassStr(RC);
1767 int n = regmap.size();
1768
1769 // Only declare those registers that may be used.
1770 if (n) {
1771 O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1772 << ">;\n";
1773 }
1774 }
1775
1776 OutStreamer->emitRawText(String: O.str());
1777}
1778
1779void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1780 APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1781 bool ignored;
1782 unsigned int numHex;
1783 const char *lead;
1784
1785 if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1786 numHex = 8;
1787 lead = "0f";
1788 APF.convert(ToSemantics: APFloat::IEEEsingle(), RM: APFloat::rmNearestTiesToEven, losesInfo: &ignored);
1789 } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1790 numHex = 16;
1791 lead = "0d";
1792 APF.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven, losesInfo: &ignored);
1793 } else
1794 llvm_unreachable("unsupported fp type");
1795
1796 APInt API = APF.bitcastToAPInt();
1797 O << lead << format_hex_no_prefix(N: API.getZExtValue(), Width: numHex, /*Upper=*/true);
1798}
1799
1800void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1801 if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CPV)) {
1802 O << CI->getValue();
1803 return;
1804 }
1805 if (const ConstantFP *CFP = dyn_cast<ConstantFP>(Val: CPV)) {
1806 printFPConstant(Fp: CFP, O);
1807 return;
1808 }
1809 if (isa<ConstantPointerNull>(Val: CPV)) {
1810 O << "0";
1811 return;
1812 }
1813 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: CPV)) {
1814 bool IsNonGenericPointer = false;
1815 if (GVar->getType()->getAddressSpace() != 0) {
1816 IsNonGenericPointer = true;
1817 }
1818 if (EmitGeneric && !isa<Function>(Val: CPV) && !IsNonGenericPointer) {
1819 O << "generic(";
1820 getSymbol(GV: GVar)->print(OS&: O, MAI);
1821 O << ")";
1822 } else {
1823 getSymbol(GV: GVar)->print(OS&: O, MAI);
1824 }
1825 return;
1826 }
1827 if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
1828 const MCExpr *E = lowerConstantForGV(CV: cast<Constant>(Val: Cexpr), ProcessingGeneric: false);
1829 printMCExpr(Expr: *E, OS&: O);
1830 return;
1831 }
1832 llvm_unreachable("Not scalar type found in printScalarConstant()");
1833}
1834
1835void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1836 AggBuffer *AggBuffer) {
1837 const DataLayout &DL = getDataLayout();
1838 int AllocSize = DL.getTypeAllocSize(Ty: CPV->getType());
1839 if (isa<UndefValue>(Val: CPV) || CPV->isNullValue()) {
1840 // Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1841 // only the space allocated by CPV.
1842 AggBuffer->addZeros(Num: Bytes ? Bytes : AllocSize);
1843 return;
1844 }
1845
1846 // Helper for filling AggBuffer with APInts.
1847 auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1848 size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1849 SmallVector<unsigned char, 16> Buf(NumBytes);
1850 for (unsigned I = 0; I < NumBytes; ++I) {
1851 Buf[I] = Val.extractBitsAsZExtValue(numBits: 8, bitPosition: I * 8);
1852 }
1853 AggBuffer->addBytes(Ptr: Buf.data(), Num: NumBytes, Bytes);
1854 };
1855
1856 switch (CPV->getType()->getTypeID()) {
1857 case Type::IntegerTyID:
1858 if (const auto CI = dyn_cast<ConstantInt>(Val: CPV)) {
1859 AddIntToBuffer(CI->getValue());
1860 break;
1861 }
1862 if (const auto *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
1863 if (const auto *CI =
1864 dyn_cast<ConstantInt>(Val: ConstantFoldConstant(C: Cexpr, DL))) {
1865 AddIntToBuffer(CI->getValue());
1866 break;
1867 }
1868 if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1869 Value *V = Cexpr->getOperand(i_nocapture: 0)->stripPointerCasts();
1870 AggBuffer->addSymbol(GVar: V, GVarBeforeStripping: Cexpr->getOperand(i_nocapture: 0));
1871 AggBuffer->addZeros(Num: AllocSize);
1872 break;
1873 }
1874 }
1875 llvm_unreachable("unsupported integer const type");
1876 break;
1877
1878 case Type::HalfTyID:
1879 case Type::BFloatTyID:
1880 case Type::FloatTyID:
1881 case Type::DoubleTyID:
1882 AddIntToBuffer(cast<ConstantFP>(Val: CPV)->getValueAPF().bitcastToAPInt());
1883 break;
1884
1885 case Type::PointerTyID: {
1886 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(Val: CPV)) {
1887 AggBuffer->addSymbol(GVar, GVarBeforeStripping: GVar);
1888 } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(Val: CPV)) {
1889 const Value *v = Cexpr->stripPointerCasts();
1890 AggBuffer->addSymbol(GVar: v, GVarBeforeStripping: Cexpr);
1891 }
1892 AggBuffer->addZeros(Num: AllocSize);
1893 break;
1894 }
1895
1896 case Type::ArrayTyID:
1897 case Type::FixedVectorTyID:
1898 case Type::StructTyID: {
1899 if (isa<ConstantAggregate>(Val: CPV) || isa<ConstantDataSequential>(Val: CPV)) {
1900 bufferAggregateConstant(CV: CPV, aggBuffer: AggBuffer);
1901 if (Bytes > AllocSize)
1902 AggBuffer->addZeros(Num: Bytes - AllocSize);
1903 } else if (isa<ConstantAggregateZero>(Val: CPV))
1904 AggBuffer->addZeros(Num: Bytes);
1905 else
1906 llvm_unreachable("Unexpected Constant type");
1907 break;
1908 }
1909
1910 default:
1911 llvm_unreachable("unsupported type");
1912 }
1913}
1914
1915void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1916 AggBuffer *aggBuffer) {
1917 const DataLayout &DL = getDataLayout();
1918 int Bytes;
1919
1920 // Integers of arbitrary width
1921 if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CPV)) {
1922 APInt Val = CI->getValue();
1923 for (unsigned I = 0, E = DL.getTypeAllocSize(Ty: CPV->getType()); I < E; ++I) {
1924 uint8_t Byte = Val.getLoBits(numBits: 8).getZExtValue();
1925 aggBuffer->addBytes(Ptr: &Byte, Num: 1, Bytes: 1);
1926 Val.lshrInPlace(ShiftAmt: 8);
1927 }
1928 return;
1929 }
1930
1931 // Old constants
1932 if (isa<ConstantArray>(Val: CPV) || isa<ConstantVector>(Val: CPV)) {
1933 if (CPV->getNumOperands())
1934 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1935 bufferLEByte(CPV: cast<Constant>(Val: CPV->getOperand(i)), Bytes: 0, AggBuffer: aggBuffer);
1936 return;
1937 }
1938
1939 if (const ConstantDataSequential *CDS =
1940 dyn_cast<ConstantDataSequential>(Val: CPV)) {
1941 if (CDS->getNumElements())
1942 for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1943 bufferLEByte(CPV: cast<Constant>(Val: CDS->getElementAsConstant(i)), Bytes: 0,
1944 AggBuffer: aggBuffer);
1945 return;
1946 }
1947
1948 if (isa<ConstantStruct>(Val: CPV)) {
1949 if (CPV->getNumOperands()) {
1950 StructType *ST = cast<StructType>(Val: CPV->getType());
1951 for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1952 if (i == (e - 1))
1953 Bytes = DL.getStructLayout(Ty: ST)->getElementOffset(Idx: 0) +
1954 DL.getTypeAllocSize(Ty: ST) -
1955 DL.getStructLayout(Ty: ST)->getElementOffset(Idx: i);
1956 else
1957 Bytes = DL.getStructLayout(Ty: ST)->getElementOffset(Idx: i + 1) -
1958 DL.getStructLayout(Ty: ST)->getElementOffset(Idx: i);
1959 bufferLEByte(CPV: cast<Constant>(Val: CPV->getOperand(i)), Bytes, AggBuffer: aggBuffer);
1960 }
1961 }
1962 return;
1963 }
1964 llvm_unreachable("unsupported constant type in printAggregateConstant()");
1965}
1966
1967/// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
1968/// a copy from AsmPrinter::lowerConstant, except customized to only handle
1969/// expressions that are representable in PTX and create
1970/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1971const MCExpr *
1972NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1973 MCContext &Ctx = OutContext;
1974
1975 if (CV->isNullValue() || isa<UndefValue>(Val: CV))
1976 return MCConstantExpr::create(Value: 0, Ctx);
1977
1978 if (const ConstantInt *CI = dyn_cast<ConstantInt>(Val: CV))
1979 return MCConstantExpr::create(Value: CI->getZExtValue(), Ctx);
1980
1981 if (const GlobalValue *GV = dyn_cast<GlobalValue>(Val: CV)) {
1982 const MCSymbolRefExpr *Expr =
1983 MCSymbolRefExpr::create(Symbol: getSymbol(GV), Ctx);
1984 if (ProcessingGeneric) {
1985 return NVPTXGenericMCSymbolRefExpr::create(SymExpr: Expr, Ctx);
1986 } else {
1987 return Expr;
1988 }
1989 }
1990
1991 const ConstantExpr *CE = dyn_cast<ConstantExpr>(Val: CV);
1992 if (!CE) {
1993 llvm_unreachable("Unknown constant value to lower!");
1994 }
1995
1996 switch (CE->getOpcode()) {
1997 default:
1998 break; // Error
1999
2000 case Instruction::AddrSpaceCast: {
2001 // Strip the addrspacecast and pass along the operand
2002 PointerType *DstTy = cast<PointerType>(Val: CE->getType());
2003 if (DstTy->getAddressSpace() == 0)
2004 return lowerConstantForGV(CV: cast<const Constant>(Val: CE->getOperand(i_nocapture: 0)), ProcessingGeneric: true);
2005
2006 break; // Error
2007 }
2008
2009 case Instruction::GetElementPtr: {
2010 const DataLayout &DL = getDataLayout();
2011
2012 // Generate a symbolic expression for the byte address
2013 APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
2014 cast<GEPOperator>(Val: CE)->accumulateConstantOffset(DL, Offset&: OffsetAI);
2015
2016 const MCExpr *Base = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0),
2017 ProcessingGeneric);
2018 if (!OffsetAI)
2019 return Base;
2020
2021 int64_t Offset = OffsetAI.getSExtValue();
2022 return MCBinaryExpr::createAdd(LHS: Base, RHS: MCConstantExpr::create(Value: Offset, Ctx),
2023 Ctx);
2024 }
2025
2026 case Instruction::Trunc:
2027 // We emit the value and depend on the assembler to truncate the generated
2028 // expression properly. This is important for differences between
2029 // blockaddress labels. Since the two labels are in the same function, it
2030 // is reasonable to treat their delta as a 32-bit value.
2031 [[fallthrough]];
2032 case Instruction::BitCast:
2033 return lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0), ProcessingGeneric);
2034
2035 case Instruction::IntToPtr: {
2036 const DataLayout &DL = getDataLayout();
2037
2038 // Handle casts to pointers by changing them into casts to the appropriate
2039 // integer type. This promotes constant folding and simplifies this code.
2040 Constant *Op = CE->getOperand(i_nocapture: 0);
2041 Op = ConstantFoldIntegerCast(C: Op, DestTy: DL.getIntPtrType(CV->getType()),
2042 /*IsSigned*/ false, DL);
2043 if (Op)
2044 return lowerConstantForGV(CV: Op, ProcessingGeneric);
2045
2046 break; // Error
2047 }
2048
2049 case Instruction::PtrToInt: {
2050 const DataLayout &DL = getDataLayout();
2051
2052 // Support only foldable casts to/from pointers that can be eliminated by
2053 // changing the pointer to the appropriately sized integer type.
2054 Constant *Op = CE->getOperand(i_nocapture: 0);
2055 Type *Ty = CE->getType();
2056
2057 const MCExpr *OpExpr = lowerConstantForGV(CV: Op, ProcessingGeneric);
2058
2059 // We can emit the pointer value into this slot if the slot is an
2060 // integer slot equal to the size of the pointer.
2061 if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Ty: Op->getType()))
2062 return OpExpr;
2063
2064 // Otherwise the pointer is smaller than the resultant integer, mask off
2065 // the high bits so we are sure to get a proper truncation if the input is
2066 // a constant expr.
2067 unsigned InBits = DL.getTypeAllocSizeInBits(Ty: Op->getType());
2068 const MCExpr *MaskExpr = MCConstantExpr::create(Value: ~0ULL >> (64-InBits), Ctx);
2069 return MCBinaryExpr::createAnd(LHS: OpExpr, RHS: MaskExpr, Ctx);
2070 }
2071
2072 // The MC library also has a right-shift operator, but it isn't consistently
2073 // signed or unsigned between different targets.
2074 case Instruction::Add: {
2075 const MCExpr *LHS = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 0), ProcessingGeneric);
2076 const MCExpr *RHS = lowerConstantForGV(CV: CE->getOperand(i_nocapture: 1), ProcessingGeneric);
2077 switch (CE->getOpcode()) {
2078 default: llvm_unreachable("Unknown binary operator constant cast expr");
2079 case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
2080 }
2081 }
2082 }
2083
2084 // If the code isn't optimized, there may be outstanding folding
2085 // opportunities. Attempt to fold the expression using DataLayout as a
2086 // last resort before giving up.
2087 Constant *C = ConstantFoldConstant(C: CE, DL: getDataLayout());
2088 if (C != CE)
2089 return lowerConstantForGV(CV: C, ProcessingGeneric);
2090
2091 // Otherwise report the problem to the user.
2092 std::string S;
2093 raw_string_ostream OS(S);
2094 OS << "Unsupported expression in static initializer: ";
2095 CE->printAsOperand(O&: OS, /*PrintType=*/false,
2096 M: !MF ? nullptr : MF->getFunction().getParent());
2097 report_fatal_error(reason: Twine(OS.str()));
2098}
2099
2100// Copy of MCExpr::print customized for NVPTX
2101void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2102 switch (Expr.getKind()) {
2103 case MCExpr::Target:
2104 return cast<MCTargetExpr>(Val: &Expr)->printImpl(OS, MAI);
2105 case MCExpr::Constant:
2106 OS << cast<MCConstantExpr>(Val: Expr).getValue();
2107 return;
2108
2109 case MCExpr::SymbolRef: {
2110 const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Val: Expr);
2111 const MCSymbol &Sym = SRE.getSymbol();
2112 Sym.print(OS, MAI);
2113 return;
2114 }
2115
2116 case MCExpr::Unary: {
2117 const MCUnaryExpr &UE = cast<MCUnaryExpr>(Val: Expr);
2118 switch (UE.getOpcode()) {
2119 case MCUnaryExpr::LNot: OS << '!'; break;
2120 case MCUnaryExpr::Minus: OS << '-'; break;
2121 case MCUnaryExpr::Not: OS << '~'; break;
2122 case MCUnaryExpr::Plus: OS << '+'; break;
2123 }
2124 printMCExpr(Expr: *UE.getSubExpr(), OS);
2125 return;
2126 }
2127
2128 case MCExpr::Binary: {
2129 const MCBinaryExpr &BE = cast<MCBinaryExpr>(Val: Expr);
2130
2131 // Only print parens around the LHS if it is non-trivial.
2132 if (isa<MCConstantExpr>(Val: BE.getLHS()) || isa<MCSymbolRefExpr>(Val: BE.getLHS()) ||
2133 isa<NVPTXGenericMCSymbolRefExpr>(Val: BE.getLHS())) {
2134 printMCExpr(Expr: *BE.getLHS(), OS);
2135 } else {
2136 OS << '(';
2137 printMCExpr(Expr: *BE.getLHS(), OS);
2138 OS<< ')';
2139 }
2140
2141 switch (BE.getOpcode()) {
2142 case MCBinaryExpr::Add:
2143 // Print "X-42" instead of "X+-42".
2144 if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(Val: BE.getRHS())) {
2145 if (RHSC->getValue() < 0) {
2146 OS << RHSC->getValue();
2147 return;
2148 }
2149 }
2150
2151 OS << '+';
2152 break;
2153 default: llvm_unreachable("Unhandled binary operator");
2154 }
2155
2156 // Only print parens around the LHS if it is non-trivial.
2157 if (isa<MCConstantExpr>(Val: BE.getRHS()) || isa<MCSymbolRefExpr>(Val: BE.getRHS())) {
2158 printMCExpr(Expr: *BE.getRHS(), OS);
2159 } else {
2160 OS << '(';
2161 printMCExpr(Expr: *BE.getRHS(), OS);
2162 OS << ')';
2163 }
2164 return;
2165 }
2166 }
2167
2168 llvm_unreachable("Invalid expression kind!");
2169}
2170
2171/// PrintAsmOperand - Print out an operand for an inline asm expression.
2172///
2173bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2174 const char *ExtraCode, raw_ostream &O) {
2175 if (ExtraCode && ExtraCode[0]) {
2176 if (ExtraCode[1] != 0)
2177 return true; // Unknown modifier.
2178
2179 switch (ExtraCode[0]) {
2180 default:
2181 // See if this is a generic print operand
2182 return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, OS&: O);
2183 case 'r':
2184 break;
2185 }
2186 }
2187
2188 printOperand(MI, OpNum: OpNo, O);
2189
2190 return false;
2191}
2192
2193bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2194 unsigned OpNo,
2195 const char *ExtraCode,
2196 raw_ostream &O) {
2197 if (ExtraCode && ExtraCode[0])
2198 return true; // Unknown modifier
2199
2200 O << '[';
2201 printMemOperand(MI, OpNum: OpNo, O);
2202 O << ']';
2203
2204 return false;
2205}
2206
2207void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
2208 raw_ostream &O) {
2209 const MachineOperand &MO = MI->getOperand(i: OpNum);
2210 switch (MO.getType()) {
2211 case MachineOperand::MO_Register:
2212 if (MO.getReg().isPhysical()) {
2213 if (MO.getReg() == NVPTX::VRDepot)
2214 O << DEPOTNAME << getFunctionNumber();
2215 else
2216 O << NVPTXInstPrinter::getRegisterName(Reg: MO.getReg());
2217 } else {
2218 emitVirtualRegister(vr: MO.getReg(), O);
2219 }
2220 break;
2221
2222 case MachineOperand::MO_Immediate:
2223 O << MO.getImm();
2224 break;
2225
2226 case MachineOperand::MO_FPImmediate:
2227 printFPConstant(Fp: MO.getFPImm(), O);
2228 break;
2229
2230 case MachineOperand::MO_GlobalAddress:
2231 PrintSymbolOperand(MO, OS&: O);
2232 break;
2233
2234 case MachineOperand::MO_MachineBasicBlock:
2235 MO.getMBB()->getSymbol()->print(OS&: O, MAI);
2236 break;
2237
2238 default:
2239 llvm_unreachable("Operand type not supported.");
2240 }
2241}
2242
2243void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
2244 raw_ostream &O, const char *Modifier) {
2245 printOperand(MI, OpNum, O);
2246
2247 if (Modifier && strcmp(s1: Modifier, s2: "add") == 0) {
2248 O << ", ";
2249 printOperand(MI, OpNum: OpNum + 1, O);
2250 } else {
2251 if (MI->getOperand(i: OpNum + 1).isImm() &&
2252 MI->getOperand(i: OpNum + 1).getImm() == 0)
2253 return; // don't print ',0' or '+0'
2254 O << "+";
2255 printOperand(MI, OpNum: OpNum + 1, O);
2256 }
2257}
2258
2259// Force static initialization.
2260extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2261 RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2262 RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2263}
2264

source code of llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp