1//===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the lowering of LLVM calls to machine code calls for
10// GlobalISel.
11//
12//===----------------------------------------------------------------------===//
13
14#include "SPIRVCallLowering.h"
15#include "MCTargetDesc/SPIRVBaseInfo.h"
16#include "SPIRV.h"
17#include "SPIRVBuiltins.h"
18#include "SPIRVGlobalRegistry.h"
19#include "SPIRVISelLowering.h"
20#include "SPIRVMetadata.h"
21#include "SPIRVRegisterInfo.h"
22#include "SPIRVSubtarget.h"
23#include "SPIRVUtils.h"
24#include "llvm/CodeGen/FunctionLoweringInfo.h"
25#include "llvm/IR/IntrinsicInst.h"
26#include "llvm/IR/IntrinsicsSPIRV.h"
27#include "llvm/Support/ModRef.h"
28
29using namespace llvm;
30
31SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
32 SPIRVGlobalRegistry *GR)
33 : CallLowering(&TLI), GR(GR) {}
34
35bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
36 const Value *Val, ArrayRef<Register> VRegs,
37 FunctionLoweringInfo &FLI,
38 Register SwiftErrorVReg) const {
39 // Maybe run postponed production of types for function pointers
40 if (IndirectCalls.size() > 0) {
41 produceIndirectPtrTypes(MIRBuilder);
42 IndirectCalls.clear();
43 }
44
45 // Currently all return types should use a single register.
46 // TODO: handle the case of multiple registers.
47 if (VRegs.size() > 1)
48 return false;
49 if (Val) {
50 const auto &STI = MIRBuilder.getMF().getSubtarget();
51 return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
52 .addUse(VRegs[0])
53 .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
54 *STI.getRegBankInfo());
55 }
56 MIRBuilder.buildInstr(SPIRV::OpReturn);
57 return true;
58}
59
60// Based on the LLVM function attributes, get a SPIR-V FunctionControl.
61static uint32_t getFunctionControl(const Function &F) {
62 MemoryEffects MemEffects = F.getMemoryEffects();
63
64 uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
65
66 if (F.hasFnAttribute(Attribute::AttrKind::NoInline))
67 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
68 else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline))
69 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
70
71 if (MemEffects.doesNotAccessMemory())
72 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
73 else if (MemEffects.onlyReadsMemory())
74 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
75
76 return FuncControl;
77}
78
79static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
80 if (MD->getNumOperands() > NumOp) {
81 auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MD->getOperand(I: NumOp));
82 if (CMeta)
83 return dyn_cast<ConstantInt>(Val: CMeta->getValue());
84 }
85 return nullptr;
86}
87
88// If the function has pointer arguments, we are forced to re-create this
89// function type from the very beginning, changing PointerType by
90// TypedPointerType for each pointer argument. Otherwise, the same `Type*`
91// potentially corresponds to different SPIR-V function type, effectively
92// invalidating logic behind global registry and duplicates tracker.
93static FunctionType *
94fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
95 FunctionType *FTy, const SPIRVType *SRetTy,
96 const SmallVector<SPIRVType *, 4> &SArgTys) {
97 if (F.getParent()->getNamedMetadata(Name: "spv.cloned_funcs"))
98 return FTy;
99
100 bool hasArgPtrs = false;
101 for (auto &Arg : F.args()) {
102 // check if it's an instance of a non-typed PointerType
103 if (Arg.getType()->isPointerTy()) {
104 hasArgPtrs = true;
105 break;
106 }
107 }
108 if (!hasArgPtrs) {
109 Type *RetTy = FTy->getReturnType();
110 // check if it's an instance of a non-typed PointerType
111 if (!RetTy->isPointerTy())
112 return FTy;
113 }
114
115 // re-create function type, using TypedPointerType instead of PointerType to
116 // properly trace argument types
117 const Type *RetTy = GR->getTypeForSPIRVType(Ty: SRetTy);
118 SmallVector<Type *, 4> ArgTys;
119 for (auto SArgTy : SArgTys)
120 ArgTys.push_back(Elt: const_cast<Type *>(GR->getTypeForSPIRVType(Ty: SArgTy)));
121 return FunctionType::get(Result: const_cast<Type *>(RetTy), Params: ArgTys, isVarArg: false);
122}
123
124// This code restores function args/retvalue types for composite cases
125// because the final types should still be aggregate whereas they're i32
126// during the translation to cope with aggregate flattening etc.
127static FunctionType *getOriginalFunctionType(const Function &F) {
128 auto *NamedMD = F.getParent()->getNamedMetadata(Name: "spv.cloned_funcs");
129 if (NamedMD == nullptr)
130 return F.getFunctionType();
131
132 Type *RetTy = F.getFunctionType()->getReturnType();
133 SmallVector<Type *, 4> ArgTypes;
134 for (auto &Arg : F.args())
135 ArgTypes.push_back(Elt: Arg.getType());
136
137 auto ThisFuncMDIt =
138 std::find_if(first: NamedMD->op_begin(), last: NamedMD->op_end(), pred: [&F](MDNode *N) {
139 return isa<MDString>(Val: N->getOperand(I: 0)) &&
140 cast<MDString>(Val: N->getOperand(I: 0))->getString() == F.getName();
141 });
142 // TODO: probably one function can have numerous type mutations,
143 // so we should support this.
144 if (ThisFuncMDIt != NamedMD->op_end()) {
145 auto *ThisFuncMD = *ThisFuncMDIt;
146 MDNode *MD = dyn_cast<MDNode>(Val: ThisFuncMD->getOperand(I: 1));
147 assert(MD && "MDNode operand is expected");
148 ConstantInt *Const = getConstInt(MD, NumOp: 0);
149 if (Const) {
150 auto *CMeta = dyn_cast<ConstantAsMetadata>(Val: MD->getOperand(I: 1));
151 assert(CMeta && "ConstantAsMetadata operand is expected");
152 assert(Const->getSExtValue() >= -1);
153 // Currently -1 indicates return value, greater values mean
154 // argument numbers.
155 if (Const->getSExtValue() == -1)
156 RetTy = CMeta->getType();
157 else
158 ArgTypes[Const->getSExtValue()] = CMeta->getType();
159 }
160 }
161
162 return FunctionType::get(Result: RetTy, Params: ArgTypes, isVarArg: F.isVarArg());
163}
164
165static SPIRV::AccessQualifier::AccessQualifier
166getArgAccessQual(const Function &F, unsigned ArgIdx) {
167 if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
168 return SPIRV::AccessQualifier::ReadWrite;
169
170 MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
171 if (!ArgAttribute)
172 return SPIRV::AccessQualifier::ReadWrite;
173
174 if (ArgAttribute->getString().compare("read_only") == 0)
175 return SPIRV::AccessQualifier::ReadOnly;
176 if (ArgAttribute->getString().compare("write_only") == 0)
177 return SPIRV::AccessQualifier::WriteOnly;
178 return SPIRV::AccessQualifier::ReadWrite;
179}
180
181static std::vector<SPIRV::Decoration::Decoration>
182getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
183 MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
184 if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
185 return {SPIRV::Decoration::Volatile};
186 return {};
187}
188
189static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
190 SPIRVGlobalRegistry *GR,
191 MachineIRBuilder &MIRBuilder,
192 const SPIRVSubtarget &ST) {
193 // Read argument's access qualifier from metadata or default.
194 SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
195 getArgAccessQual(F, ArgIdx);
196
197 Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(i: ArgIdx);
198
199 // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
200 // be legally reassigned later).
201 if (!isPointerTy(T: OriginalArgType))
202 return GR->getOrCreateSPIRVType(BitWidth: OriginalArgType, I&: MIRBuilder, TII: ArgAccessQual);
203
204 Argument *Arg = F.getArg(i: ArgIdx);
205 Type *ArgType = Arg->getType();
206 if (isTypedPointerTy(T: ArgType)) {
207 SPIRVType *ElementType = GR->getOrCreateSPIRVType(
208 cast<TypedPointerType>(Val: ArgType)->getElementType(), MIRBuilder);
209 return GR->getOrCreateSPIRVPointerType(
210 BaseType: ElementType, MIRBuilder,
211 SClass: addressSpaceToStorageClass(AddrSpace: getPointerAddressSpace(T: ArgType), STI: ST));
212 }
213
214 // In case OriginalArgType is of untyped pointer type, there are three
215 // possibilities:
216 // 1) This is a pointer of an LLVM IR element type, passed byval/byref.
217 // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
218 // intrinsic assigning a TargetExtType.
219 // 3) This is a pointer, try to retrieve pointer element type from a
220 // spv_assign_ptr_type intrinsic or otherwise use default pointer element
221 // type.
222 if (hasPointeeTypeAttr(Arg)) {
223 SPIRVType *ElementType =
224 GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder);
225 return GR->getOrCreateSPIRVPointerType(
226 BaseType: ElementType, MIRBuilder,
227 SClass: addressSpaceToStorageClass(AddrSpace: getPointerAddressSpace(T: ArgType), STI: ST));
228 }
229
230 for (auto User : Arg->users()) {
231 auto *II = dyn_cast<IntrinsicInst>(Val: User);
232 // Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type.
233 if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) {
234 MetadataAsValue *VMD = cast<MetadataAsValue>(Val: II->getOperand(i_nocapture: 1));
235 Type *BuiltinType =
236 cast<ConstantAsMetadata>(Val: VMD->getMetadata())->getType();
237 assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
238 return GR->getOrCreateSPIRVType(BitWidth: BuiltinType, I&: MIRBuilder, TII: ArgAccessQual);
239 }
240
241 // Check if this is spv_assign_ptr_type assigning pointer element type.
242 if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type)
243 continue;
244
245 MetadataAsValue *VMD = cast<MetadataAsValue>(Val: II->getOperand(i_nocapture: 1));
246 Type *ElementTy = cast<ConstantAsMetadata>(Val: VMD->getMetadata())->getType();
247 if (isUntypedPointerTy(T: ElementTy))
248 ElementTy =
249 TypedPointerType::get(ElementType: IntegerType::getInt8Ty(C&: II->getContext()),
250 AddressSpace: getPointerAddressSpace(T: ElementTy));
251 SPIRVType *ElementType = GR->getOrCreateSPIRVType(ElementTy, MIRBuilder);
252 return GR->getOrCreateSPIRVPointerType(
253 BaseType: ElementType, MIRBuilder,
254 SClass: addressSpaceToStorageClass(
255 AddrSpace: cast<ConstantInt>(Val: II->getOperand(i_nocapture: 2))->getZExtValue(), STI: ST));
256 }
257
258 // Replace PointerType with TypedPointerType to be able to map SPIR-V types to
259 // LLVM types in a consistent manner
260 if (isUntypedPointerTy(T: OriginalArgType)) {
261 OriginalArgType =
262 TypedPointerType::get(ElementType: Type::getInt8Ty(C&: F.getContext()),
263 AddressSpace: getPointerAddressSpace(T: OriginalArgType));
264 }
265 return GR->getOrCreateSPIRVType(BitWidth: OriginalArgType, I&: MIRBuilder, TII: ArgAccessQual);
266}
267
268static SPIRV::ExecutionModel::ExecutionModel
269getExecutionModel(const SPIRVSubtarget &STI, const Function &F) {
270 if (STI.isOpenCLEnv())
271 return SPIRV::ExecutionModel::Kernel;
272
273 auto attribute = F.getFnAttribute(Kind: "hlsl.shader");
274 if (!attribute.isValid()) {
275 report_fatal_error(
276 reason: "This entry point lacks mandatory hlsl.shader attribute.");
277 }
278
279 const auto value = attribute.getValueAsString();
280 if (value == "compute")
281 return SPIRV::ExecutionModel::GLCompute;
282
283 report_fatal_error(reason: "This HLSL entry point is not supported by this backend.");
284}
285
286bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
287 const Function &F,
288 ArrayRef<ArrayRef<Register>> VRegs,
289 FunctionLoweringInfo &FLI) const {
290 assert(GR && "Must initialize the SPIRV type registry before lowering args.");
291 GR->setCurrentFunc(MIRBuilder.getMF());
292
293 // Get access to information about available extensions
294 const SPIRVSubtarget *ST =
295 static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
296
297 // Assign types and names to all args, and store their types for later.
298 SmallVector<SPIRVType *, 4> ArgTypeVRegs;
299 if (VRegs.size() > 0) {
300 unsigned i = 0;
301 for (const auto &Arg : F.args()) {
302 // Currently formal args should use single registers.
303 // TODO: handle the case of multiple registers.
304 if (VRegs[i].size() > 1)
305 return false;
306 auto *SpirvTy = getArgSPIRVType(F, ArgIdx: i, GR, MIRBuilder, ST: *ST);
307 GR->assignSPIRVTypeToVReg(Type: SpirvTy, VReg: VRegs[i][0], MF&: MIRBuilder.getMF());
308 ArgTypeVRegs.push_back(Elt: SpirvTy);
309
310 if (Arg.hasName())
311 buildOpName(Target: VRegs[i][0], Name: Arg.getName(), MIRBuilder);
312 if (isPointerTy(T: Arg.getType())) {
313 auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
314 if (DerefBytes != 0)
315 buildOpDecorate(VRegs[i][0], MIRBuilder,
316 SPIRV::Decoration::MaxByteOffset, {DerefBytes});
317 }
318 if (Arg.hasAttribute(Attribute::Alignment)) {
319 auto Alignment = static_cast<unsigned>(
320 Arg.getAttribute(Attribute::Alignment).getValueAsInt());
321 buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
322 {Alignment});
323 }
324 if (Arg.hasAttribute(Attribute::ReadOnly)) {
325 auto Attr =
326 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
327 buildOpDecorate(VRegs[i][0], MIRBuilder,
328 SPIRV::Decoration::FuncParamAttr, {Attr});
329 }
330 if (Arg.hasAttribute(Attribute::ZExt)) {
331 auto Attr =
332 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
333 buildOpDecorate(VRegs[i][0], MIRBuilder,
334 SPIRV::Decoration::FuncParamAttr, {Attr});
335 }
336 if (Arg.hasAttribute(Attribute::NoAlias)) {
337 auto Attr =
338 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
339 buildOpDecorate(VRegs[i][0], MIRBuilder,
340 SPIRV::Decoration::FuncParamAttr, {Attr});
341 }
342 if (Arg.hasAttribute(Attribute::ByVal)) {
343 auto Attr =
344 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::ByVal);
345 buildOpDecorate(VRegs[i][0], MIRBuilder,
346 SPIRV::Decoration::FuncParamAttr, {Attr});
347 }
348
349 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
350 std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
351 getKernelArgTypeQual(F, i);
352 for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
353 buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
354 }
355
356 MDNode *Node = F.getMetadata(Kind: "spirv.ParameterDecorations");
357 if (Node && i < Node->getNumOperands() &&
358 isa<MDNode>(Val: Node->getOperand(I: i))) {
359 MDNode *MD = cast<MDNode>(Val: Node->getOperand(I: i));
360 for (const MDOperand &MDOp : MD->operands()) {
361 MDNode *MD2 = dyn_cast<MDNode>(Val: MDOp);
362 assert(MD2 && "Metadata operand is expected");
363 ConstantInt *Const = getConstInt(MD: MD2, NumOp: 0);
364 assert(Const && "MDOperand should be ConstantInt");
365 auto Dec =
366 static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
367 std::vector<uint32_t> DecVec;
368 for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
369 ConstantInt *Const = getConstInt(MD: MD2, NumOp: j);
370 assert(Const && "MDOperand should be ConstantInt");
371 DecVec.push_back(x: static_cast<uint32_t>(Const->getZExtValue()));
372 }
373 buildOpDecorate(Reg: VRegs[i][0], MIRBuilder, Dec, DecArgs: DecVec);
374 }
375 }
376 ++i;
377 }
378 }
379
380 auto MRI = MIRBuilder.getMRI();
381 Register FuncVReg = MRI->createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 32));
382 MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
383 if (F.isDeclaration())
384 GR->add(F: &F, MF: &MIRBuilder.getMF(), R: FuncVReg);
385 FunctionType *FTy = getOriginalFunctionType(F);
386 Type *FRetTy = FTy->getReturnType();
387 if (isUntypedPointerTy(T: FRetTy)) {
388 if (Type *FRetElemTy = GR->findDeducedElementType(Val: &F)) {
389 TypedPointerType *DerivedTy =
390 TypedPointerType::get(ElementType: FRetElemTy, AddressSpace: getPointerAddressSpace(T: FRetTy));
391 GR->addReturnType(ArgF: &F, DerivedTy);
392 FRetTy = DerivedTy;
393 }
394 }
395 SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder);
396 FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, SRetTy: RetTy, SArgTys: ArgTypeVRegs);
397 SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
398 Ty: FTy, RetType: RetTy, ArgTypes: ArgTypeVRegs, MIRBuilder);
399 uint32_t FuncControl = getFunctionControl(F);
400
401 // Add OpFunction instruction
402 MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
403 .addDef(FuncVReg)
404 .addUse(GR->getSPIRVTypeID(RetTy))
405 .addImm(FuncControl)
406 .addUse(GR->getSPIRVTypeID(FuncTy));
407 GR->recordFunctionDefinition(F: &F, MO: &MB.getInstr()->getOperand(i: 0));
408
409 // Add OpFunctionParameter instructions
410 int i = 0;
411 for (const auto &Arg : F.args()) {
412 assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
413 MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
414 MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
415 .addDef(VRegs[i][0])
416 .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
417 if (F.isDeclaration())
418 GR->add(Arg: &Arg, MF: &MIRBuilder.getMF(), R: VRegs[i][0]);
419 i++;
420 }
421 // Name the function.
422 if (F.hasName())
423 buildOpName(Target: FuncVReg, Name: F.getName(), MIRBuilder);
424
425 // Handle entry points and function linkage.
426 if (isEntryPoint(F)) {
427 const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();
428 auto executionModel = getExecutionModel(STI, F);
429 auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
430 .addImm(static_cast<uint32_t>(executionModel))
431 .addUse(FuncVReg);
432 addStringImm(F.getName(), MIB);
433 } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage ||
434 F.getLinkage() == GlobalValue::LinkOnceODRLinkage) {
435 SPIRV::LinkageType::LinkageType LnkTy =
436 F.isDeclaration()
437 ? SPIRV::LinkageType::Import
438 : (F.getLinkage() == GlobalValue::LinkOnceODRLinkage &&
439 ST->canUseExtension(
440 SPIRV::Extension::SPV_KHR_linkonce_odr)
441 ? SPIRV::LinkageType::LinkOnceODR
442 : SPIRV::LinkageType::Export);
443 buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
444 {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
445 }
446
447 // Handle function pointers decoration
448 bool hasFunctionPointers =
449 ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
450 if (hasFunctionPointers) {
451 if (F.hasFnAttribute(Kind: "referenced-indirectly")) {
452 assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
453 "Unexpected 'referenced-indirectly' attribute of the kernel "
454 "function");
455 buildOpDecorate(FuncVReg, MIRBuilder,
456 SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
457 }
458 }
459
460 return true;
461}
462
463// Used to postpone producing of indirect function pointer types after all
464// indirect calls info is collected
465// TODO:
466// - add a topological sort of IndirectCalls to ensure the best types knowledge
467// - we may need to fix function formal parameter types if they are opaque
468// pointers used as function pointers in these indirect calls
469void SPIRVCallLowering::produceIndirectPtrTypes(
470 MachineIRBuilder &MIRBuilder) const {
471 // Create indirect call data types if any
472 MachineFunction &MF = MIRBuilder.getMF();
473 for (auto const &IC : IndirectCalls) {
474 SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder);
475 SmallVector<SPIRVType *, 4> SpirvArgTypes;
476 for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
477 SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder);
478 SpirvArgTypes.push_back(Elt: SPIRVTy);
479 if (!GR->getSPIRVTypeForVReg(VReg: IC.ArgRegs[i]))
480 GR->assignSPIRVTypeToVReg(Type: SPIRVTy, VReg: IC.ArgRegs[i], MF);
481 }
482 // SPIR-V function type:
483 FunctionType *FTy =
484 FunctionType::get(Result: const_cast<Type *>(IC.RetTy), Params: IC.ArgTys, isVarArg: false);
485 SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
486 Ty: FTy, RetType: SpirvRetTy, ArgTypes: SpirvArgTypes, MIRBuilder);
487 // SPIR-V pointer to function type:
488 SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
489 SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
490 // Correct the Callee type
491 GR->assignSPIRVTypeToVReg(Type: IndirectFuncPtrTy, VReg: IC.Callee, MF);
492 }
493}
494
495bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
496 CallLoweringInfo &Info) const {
497 // Currently call returns should have single vregs.
498 // TODO: handle the case of multiple registers.
499 if (Info.OrigRet.Regs.size() > 1)
500 return false;
501 MachineFunction &MF = MIRBuilder.getMF();
502 GR->setCurrentFunc(MF);
503 const Function *CF = nullptr;
504 std::string DemangledName;
505 const Type *OrigRetTy = Info.OrigRet.Ty;
506
507 // Emit a regular OpFunctionCall. If it's an externally declared function,
508 // be sure to emit its type and function declaration here. It will be hoisted
509 // globally later.
510 if (Info.Callee.isGlobal()) {
511 std::string FuncName = Info.Callee.getGlobal()->getName().str();
512 DemangledName = getOclOrSpirvBuiltinDemangledName(Name: FuncName);
513 CF = dyn_cast_or_null<const Function>(Val: Info.Callee.getGlobal());
514 // TODO: support constexpr casts and indirect calls.
515 if (CF == nullptr)
516 return false;
517 if (FunctionType *FTy = getOriginalFunctionType(F: *CF)) {
518 OrigRetTy = FTy->getReturnType();
519 if (isUntypedPointerTy(T: OrigRetTy)) {
520 if (auto *DerivedRetTy = GR->findReturnType(ArgF: CF))
521 OrigRetTy = DerivedRetTy;
522 }
523 }
524 }
525
526 MachineRegisterInfo *MRI = MIRBuilder.getMRI();
527 Register ResVReg =
528 Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
529 const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
530
531 bool isFunctionDecl = CF && CF->isDeclaration();
532 bool canUseOpenCL = ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std);
533 bool canUseGLSL = ST->canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450);
534 assert(canUseGLSL != canUseOpenCL &&
535 "Scenario where both sets are enabled is not supported.");
536
537 if (isFunctionDecl && !DemangledName.empty() &&
538 (canUseGLSL || canUseOpenCL)) {
539 SmallVector<Register, 8> ArgVRegs;
540 for (auto Arg : Info.OrigArgs) {
541 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
542 ArgVRegs.push_back(Elt: Arg.Regs[0]);
543 SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
544 if (!GR->getSPIRVTypeForVReg(VReg: Arg.Regs[0]))
545 GR->assignSPIRVTypeToVReg(Type: SPIRVTy, VReg: Arg.Regs[0], MF);
546 }
547 auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
548 : SPIRV::InstructionSet::GLSL_std_450;
549 if (auto Res =
550 SPIRV::lowerBuiltin(DemangledName, instructionSet, MIRBuilder,
551 ResVReg, OrigRetTy, ArgVRegs, GR))
552 return *Res;
553 }
554
555 if (isFunctionDecl && !GR->find(F: CF, MF: &MF).isValid()) {
556 // Emit the type info and forward function declaration to the first MBB
557 // to ensure VReg definition dependencies are valid across all MBBs.
558 MachineIRBuilder FirstBlockBuilder;
559 FirstBlockBuilder.setMF(MF);
560 FirstBlockBuilder.setMBB(*MF.getBlockNumbered(N: 0));
561
562 SmallVector<ArrayRef<Register>, 8> VRegArgs;
563 SmallVector<SmallVector<Register, 1>, 8> ToInsert;
564 for (const Argument &Arg : CF->args()) {
565 if (MIRBuilder.getDataLayout().getTypeStoreSize(Ty: Arg.getType()).isZero())
566 continue; // Don't handle zero sized types.
567 Register Reg = MRI->createGenericVirtualRegister(Ty: LLT::scalar(SizeInBits: 32));
568 MRI->setRegClass(Reg, &SPIRV::IDRegClass);
569 ToInsert.push_back(Elt: {Reg});
570 VRegArgs.push_back(Elt: ToInsert.back());
571 }
572 // TODO: Reuse FunctionLoweringInfo
573 FunctionLoweringInfo FuncInfo;
574 lowerFormalArguments(MIRBuilder&: FirstBlockBuilder, F: *CF, VRegs: VRegArgs, FLI&: FuncInfo);
575 }
576
577 unsigned CallOp;
578 if (Info.CB->isIndirectCall()) {
579 if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
580 report_fatal_error(reason: "An indirect call is encountered but SPIR-V without "
581 "extensions does not support it",
582 gen_crash_diag: false);
583 // Set instruction operation according to SPV_INTEL_function_pointers
584 CallOp = SPIRV::OpFunctionPointerCallINTEL;
585 // Collect information about the indirect call to support possible
586 // specification of opaque ptr types of parent function's parameters
587 Register CalleeReg = Info.Callee.getReg();
588 if (CalleeReg.isValid()) {
589 SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
590 IndirectCall.Callee = CalleeReg;
591 IndirectCall.RetTy = OrigRetTy;
592 for (const auto &Arg : Info.OrigArgs) {
593 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
594 IndirectCall.ArgTys.push_back(Elt: Arg.Ty);
595 IndirectCall.ArgRegs.push_back(Elt: Arg.Regs[0]);
596 }
597 IndirectCalls.push_back(Elt: IndirectCall);
598 }
599 } else {
600 // Emit a regular OpFunctionCall
601 CallOp = SPIRV::OpFunctionCall;
602 }
603
604 // Make sure there's a valid return reg, even for functions returning void.
605 if (!ResVReg.isValid())
606 ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
607 SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
608
609 // Emit the call instruction and its args.
610 auto MIB = MIRBuilder.buildInstr(Opcode: CallOp)
611 .addDef(RegNo: ResVReg)
612 .addUse(RegNo: GR->getSPIRVTypeID(SpirvType: RetType))
613 .add(MO: Info.Callee);
614
615 for (const auto &Arg : Info.OrigArgs) {
616 // Currently call args should have single vregs.
617 if (Arg.Regs.size() > 1)
618 return false;
619 MIB.addUse(Arg.Regs[0]);
620 }
621 return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(),
622 *ST->getRegBankInfo());
623}
624

source code of llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp