//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | /// |
/// \file This file contains a class that helps build DXIL op functions.
10 | //===----------------------------------------------------------------------===// |
11 | |
12 | #include "DXILOpBuilder.h" |
13 | #include "DXILConstants.h" |
14 | #include "llvm/IR/IRBuilder.h" |
15 | #include "llvm/IR/Module.h" |
16 | #include "llvm/Support/DXILABI.h" |
17 | #include "llvm/Support/ErrorHandling.h" |
18 | |
19 | using namespace llvm; |
20 | using namespace llvm::dxil; |
21 | |
22 | constexpr StringLiteral DXILOpNamePrefix = "dx.op." ; |
23 | |
namespace {

/// Overload kinds a DXIL operation can be instantiated with. Each value is a
/// distinct single bit so a set of kinds fits in a uint16_t mask (see
/// OpCodeProperty::OverloadTys).
enum OverloadKind : uint16_t {
  VOID = 0x001,
  HALF = 0x002,
  FLOAT = 0x004,
  DOUBLE = 0x008,
  I1 = 0x010,
  I8 = 0x020,
  I16 = 0x040,
  I32 = 0x080,
  I64 = 0x100,
  UserDefineType = 0x200,
  ObjectType = 0x400,
};

} // namespace
41 | |
42 | static const char *getOverloadTypeName(OverloadKind Kind) { |
43 | switch (Kind) { |
44 | case OverloadKind::HALF: |
45 | return "f16" ; |
46 | case OverloadKind::FLOAT: |
47 | return "f32" ; |
48 | case OverloadKind::DOUBLE: |
49 | return "f64" ; |
50 | case OverloadKind::I1: |
51 | return "i1" ; |
52 | case OverloadKind::I8: |
53 | return "i8" ; |
54 | case OverloadKind::I16: |
55 | return "i16" ; |
56 | case OverloadKind::I32: |
57 | return "i32" ; |
58 | case OverloadKind::I64: |
59 | return "i64" ; |
60 | case OverloadKind::VOID: |
61 | case OverloadKind::ObjectType: |
62 | case OverloadKind::UserDefineType: |
63 | break; |
64 | } |
65 | llvm_unreachable("invalid overload type for name" ); |
66 | return "void" ; |
67 | } |
68 | |
69 | static OverloadKind getOverloadKind(Type *Ty) { |
70 | Type::TypeID T = Ty->getTypeID(); |
71 | switch (T) { |
72 | case Type::VoidTyID: |
73 | return OverloadKind::VOID; |
74 | case Type::HalfTyID: |
75 | return OverloadKind::HALF; |
76 | case Type::FloatTyID: |
77 | return OverloadKind::FLOAT; |
78 | case Type::DoubleTyID: |
79 | return OverloadKind::DOUBLE; |
80 | case Type::IntegerTyID: { |
81 | IntegerType *ITy = cast<IntegerType>(Val: Ty); |
82 | unsigned Bits = ITy->getBitWidth(); |
83 | switch (Bits) { |
84 | case 1: |
85 | return OverloadKind::I1; |
86 | case 8: |
87 | return OverloadKind::I8; |
88 | case 16: |
89 | return OverloadKind::I16; |
90 | case 32: |
91 | return OverloadKind::I32; |
92 | case 64: |
93 | return OverloadKind::I64; |
94 | default: |
95 | llvm_unreachable("invalid overload type" ); |
96 | return OverloadKind::VOID; |
97 | } |
98 | } |
99 | case Type::PointerTyID: |
100 | return OverloadKind::UserDefineType; |
101 | case Type::StructTyID: |
102 | return OverloadKind::ObjectType; |
103 | default: |
104 | llvm_unreachable("invalid overload type" ); |
105 | return OverloadKind::VOID; |
106 | } |
107 | } |
108 | |
109 | static std::string getTypeName(OverloadKind Kind, Type *Ty) { |
110 | if (Kind < OverloadKind::UserDefineType) { |
111 | return getOverloadTypeName(Kind); |
112 | } else if (Kind == OverloadKind::UserDefineType) { |
113 | StructType *ST = cast<StructType>(Val: Ty); |
114 | return ST->getStructName().str(); |
115 | } else if (Kind == OverloadKind::ObjectType) { |
116 | StructType *ST = cast<StructType>(Val: Ty); |
117 | return ST->getStructName().str(); |
118 | } else { |
119 | std::string Str; |
120 | raw_string_ostream OS(Str); |
121 | Ty->print(O&: OS); |
122 | return OS.str(); |
123 | } |
124 | } |
125 | |
// Static, per-opcode properties of a DXIL operation. Instances are provided by
// the TableGen-generated code included from DXILOperation.inc.
// NOTE(review): the generated table presumably initializes this struct
// positionally, so keep the field order in sync with the generator.
struct OpCodeProperty {
  dxil::OpCode OpCode;
  // Offset in DXILOpCodeNameTable.
  unsigned OpCodeNameOffset;
  dxil::OpCodeClass OpCodeClass;
  // Offset in DXILOpCodeClassNameTable.
  unsigned OpCodeClassNameOffset;
  // Bitmask of OverloadKind values this operation accepts.
  uint16_t OverloadTys;
  llvm::Attribute::AttrKind FuncAttr;
  int OverloadParamIndex; // Index of the parameter that controls the overload.
                          // When negative, the op is not overloaded and
                          // OverloadTys should contain exactly one type.
  unsigned NumOfParameters; // Number of parameters, including the return value.
                            // NOTE(review): getDXILOpFunctionType treats these
                            // entries as argument kinds only; confirm.
  unsigned ParameterTableOffset; // Offset in ParameterTable.
};
141 | |
// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
144 | #define DXIL_OP_OPERATION_TABLE |
145 | #include "DXILOperation.inc" |
146 | #undef DXIL_OP_OPERATION_TABLE |
147 | |
148 | static std::string constructOverloadName(OverloadKind Kind, Type *Ty, |
149 | const OpCodeProperty &Prop) { |
150 | if (Kind == OverloadKind::VOID) { |
151 | return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); |
152 | } |
153 | return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + |
154 | getTypeName(Kind, Ty)) |
155 | .str(); |
156 | } |
157 | |
158 | static std::string constructOverloadTypeName(OverloadKind Kind, |
159 | StringRef TypeName) { |
160 | if (Kind == OverloadKind::VOID) |
161 | return TypeName.str(); |
162 | |
163 | assert(Kind < OverloadKind::UserDefineType && "invalid overload kind" ); |
164 | return (Twine(TypeName) + getOverloadTypeName(Kind)).str(); |
165 | } |
166 | |
167 | static StructType *getOrCreateStructType(StringRef Name, |
168 | ArrayRef<Type *> EltTys, |
169 | LLVMContext &Ctx) { |
170 | StructType *ST = StructType::getTypeByName(C&: Ctx, Name); |
171 | if (ST) |
172 | return ST; |
173 | |
174 | return StructType::create(Context&: Ctx, Elements: EltTys, Name); |
175 | } |
176 | |
177 | static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) { |
178 | OverloadKind Kind = getOverloadKind(Ty: OverloadTy); |
179 | std::string TypeName = constructOverloadTypeName(Kind, TypeName: "dx.types.ResRet." ); |
180 | Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy, |
181 | Type::getInt32Ty(C&: Ctx)}; |
182 | return getOrCreateStructType(Name: TypeName, EltTys: FieldTypes, Ctx); |
183 | } |
184 | |
185 | static StructType *getHandleType(LLVMContext &Ctx) { |
186 | return getOrCreateStructType(Name: "dx.types.Handle" , EltTys: PointerType::getUnqual(C&: Ctx), |
187 | Ctx); |
188 | } |
189 | |
190 | static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) { |
191 | auto &Ctx = OverloadTy->getContext(); |
192 | switch (Kind) { |
193 | case ParameterKind::Void: |
194 | return Type::getVoidTy(C&: Ctx); |
195 | case ParameterKind::Half: |
196 | return Type::getHalfTy(C&: Ctx); |
197 | case ParameterKind::Float: |
198 | return Type::getFloatTy(C&: Ctx); |
199 | case ParameterKind::Double: |
200 | return Type::getDoubleTy(C&: Ctx); |
201 | case ParameterKind::I1: |
202 | return Type::getInt1Ty(C&: Ctx); |
203 | case ParameterKind::I8: |
204 | return Type::getInt8Ty(C&: Ctx); |
205 | case ParameterKind::I16: |
206 | return Type::getInt16Ty(C&: Ctx); |
207 | case ParameterKind::I32: |
208 | return Type::getInt32Ty(C&: Ctx); |
209 | case ParameterKind::I64: |
210 | return Type::getInt64Ty(C&: Ctx); |
211 | case ParameterKind::Overload: |
212 | return OverloadTy; |
213 | case ParameterKind::ResourceRet: |
214 | return getResRetType(OverloadTy, Ctx); |
215 | case ParameterKind::DXILHandle: |
216 | return getHandleType(Ctx); |
217 | default: |
218 | break; |
219 | } |
220 | llvm_unreachable("Invalid parameter kind" ); |
221 | return nullptr; |
222 | } |
223 | |
224 | /// Construct DXIL function type. This is the type of a function with |
225 | /// the following prototype |
226 | /// OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>) |
227 | /// <param-types> are constructed from types in Prop. |
228 | /// \param Prop Structure containing DXIL Operation properties based on |
229 | /// its specification in DXIL.td. |
230 | /// \param OverloadTy Return type to be used to construct DXIL function type. |
231 | static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, |
232 | Type *ReturnTy, Type *OverloadTy) { |
233 | SmallVector<Type *> ArgTys; |
234 | |
235 | auto ParamKinds = getOpCodeParameterKind(*Prop); |
236 | |
237 | // Add ReturnTy as return type of the function |
238 | ArgTys.emplace_back(Args&: ReturnTy); |
239 | |
240 | // Add DXIL Opcode value type viz., Int32 as first argument |
241 | ArgTys.emplace_back(Args: Type::getInt32Ty(C&: OverloadTy->getContext())); |
242 | |
243 | // Add DXIL Operation parameter types as specified in DXIL properties |
244 | for (unsigned I = 0; I < Prop->NumOfParameters; ++I) { |
245 | ParameterKind Kind = ParamKinds[I]; |
246 | ArgTys.emplace_back(Args: getTypeFromParameterKind(Kind, OverloadTy)); |
247 | } |
248 | return FunctionType::get( |
249 | Result: ArgTys[0], Params: ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), isVarArg: false); |
250 | } |
251 | |
252 | namespace llvm { |
253 | namespace dxil { |
254 | |
255 | CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, |
256 | Type *OverloadTy, |
257 | SmallVector<Value *> Args) { |
258 | const OpCodeProperty *Prop = getOpCodeProperty(OpCode); |
259 | |
260 | OverloadKind Kind = getOverloadKind(Ty: OverloadTy); |
261 | if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { |
262 | report_fatal_error(reason: "Invalid Overload Type" , /* gen_crash_diag=*/false); |
263 | } |
264 | |
265 | std::string DXILFnName = constructOverloadName(Kind, Ty: OverloadTy, Prop: *Prop); |
266 | FunctionCallee DXILFn; |
267 | // Get the function with name DXILFnName, if one exists |
268 | if (auto *Func = M.getFunction(DXILFnName)) { |
269 | DXILFn = FunctionCallee(Func); |
270 | } else { |
271 | // Construct and add a function with name DXILFnName |
272 | FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy); |
273 | DXILFn = M.getOrInsertFunction(Name: DXILFnName, T: DXILOpFT); |
274 | } |
275 | |
276 | return B.CreateCall(Callee: DXILFn, Args); |
277 | } |
278 | |
279 | Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) { |
280 | |
281 | const OpCodeProperty *Prop = getOpCodeProperty(OpCode); |
282 | // If DXIL Op has no overload parameter, just return the |
283 | // precise return type specified. |
284 | if (Prop->OverloadParamIndex < 0) { |
285 | auto &Ctx = FT->getContext(); |
286 | switch (Prop->OverloadTys) { |
287 | case OverloadKind::VOID: |
288 | return Type::getVoidTy(C&: Ctx); |
289 | case OverloadKind::HALF: |
290 | return Type::getHalfTy(C&: Ctx); |
291 | case OverloadKind::FLOAT: |
292 | return Type::getFloatTy(C&: Ctx); |
293 | case OverloadKind::DOUBLE: |
294 | return Type::getDoubleTy(C&: Ctx); |
295 | case OverloadKind::I1: |
296 | return Type::getInt1Ty(C&: Ctx); |
297 | case OverloadKind::I8: |
298 | return Type::getInt8Ty(C&: Ctx); |
299 | case OverloadKind::I16: |
300 | return Type::getInt16Ty(C&: Ctx); |
301 | case OverloadKind::I32: |
302 | return Type::getInt32Ty(C&: Ctx); |
303 | case OverloadKind::I64: |
304 | return Type::getInt64Ty(C&: Ctx); |
305 | default: |
306 | llvm_unreachable("invalid overload type" ); |
307 | return nullptr; |
308 | } |
309 | } |
310 | |
311 | // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType(). |
312 | Type *OverloadType = FT->getReturnType(); |
313 | if (Prop->OverloadParamIndex != 0) { |
314 | // Skip Return Type. |
315 | OverloadType = FT->getParamType(i: Prop->OverloadParamIndex - 1); |
316 | } |
317 | |
318 | auto ParamKinds = getOpCodeParameterKind(*Prop); |
319 | auto Kind = ParamKinds[Prop->OverloadParamIndex]; |
320 | // For ResRet and CBufferRet, OverloadTy is in field of StructType. |
321 | if (Kind == ParameterKind::CBufferRet || |
322 | Kind == ParameterKind::ResourceRet) { |
323 | auto *ST = cast<StructType>(Val: OverloadType); |
324 | OverloadType = ST->getElementType(N: 0); |
325 | } |
326 | return OverloadType; |
327 | } |
328 | |
329 | const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) { |
330 | return ::getOpCodeName(DXILOp); |
331 | } |
332 | } // namespace dxil |
333 | } // namespace llvm |
334 | |