1 | //===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | /// |
/// \file This file contains expansions for intrinsics that don't have
/// opcodes in DirectX Intermediate Language (DXIL).
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "DXILIntrinsicExpansion.h" |
14 | #include "DirectX.h" |
15 | #include "llvm/ADT/STLExtras.h" |
16 | #include "llvm/ADT/SmallVector.h" |
17 | #include "llvm/CodeGen/Passes.h" |
18 | #include "llvm/IR/IRBuilder.h" |
19 | #include "llvm/IR/Instruction.h" |
20 | #include "llvm/IR/Instructions.h" |
21 | #include "llvm/IR/Intrinsics.h" |
22 | #include "llvm/IR/IntrinsicsDirectX.h" |
23 | #include "llvm/IR/Module.h" |
24 | #include "llvm/IR/PassManager.h" |
25 | #include "llvm/IR/Type.h" |
26 | #include "llvm/Pass.h" |
27 | #include "llvm/Support/ErrorHandling.h" |
28 | #include "llvm/Support/MathExtras.h" |
29 | |
30 | #define DEBUG_TYPE "dxil-intrinsic-expansion" |
31 | |
32 | using namespace llvm; |
33 | |
34 | static bool isIntrinsicExpansion(Function &F) { |
35 | switch (F.getIntrinsicID()) { |
36 | case Intrinsic::abs: |
37 | case Intrinsic::exp: |
38 | case Intrinsic::log: |
39 | case Intrinsic::log10: |
40 | case Intrinsic::pow: |
41 | case Intrinsic::dx_any: |
42 | case Intrinsic::dx_clamp: |
43 | case Intrinsic::dx_uclamp: |
44 | case Intrinsic::dx_lerp: |
45 | case Intrinsic::dx_sdot: |
46 | case Intrinsic::dx_udot: |
47 | return true; |
48 | } |
49 | return false; |
50 | } |
51 | |
52 | static bool expandAbs(CallInst *Orig) { |
53 | Value *X = Orig->getOperand(i_nocapture: 0); |
54 | IRBuilder<> Builder(Orig->getParent()); |
55 | Builder.SetInsertPoint(Orig); |
56 | Type *Ty = X->getType(); |
57 | Type *EltTy = Ty->getScalarType(); |
58 | Constant *Zero = Ty->isVectorTy() |
59 | ? ConstantVector::getSplat( |
60 | EC: ElementCount::getFixed( |
61 | MinVal: cast<FixedVectorType>(Val: Ty)->getNumElements()), |
62 | Elt: ConstantInt::get(Ty: EltTy, V: 0)) |
63 | : ConstantInt::get(Ty: EltTy, V: 0); |
64 | auto *V = Builder.CreateSub(LHS: Zero, RHS: X); |
65 | auto *MaxCall = |
66 | Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max" ); |
67 | Orig->replaceAllUsesWith(V: MaxCall); |
68 | Orig->eraseFromParent(); |
69 | return true; |
70 | } |
71 | |
72 | static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) { |
73 | assert(DotIntrinsic == Intrinsic::dx_sdot || |
74 | DotIntrinsic == Intrinsic::dx_udot); |
75 | Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot |
76 | ? Intrinsic::dx_imad |
77 | : Intrinsic::dx_umad; |
78 | Value *A = Orig->getOperand(i_nocapture: 0); |
79 | Value *B = Orig->getOperand(i_nocapture: 1); |
80 | Type *ATy = A->getType(); |
81 | Type *BTy = B->getType(); |
82 | assert(ATy->isVectorTy() && BTy->isVectorTy()); |
83 | |
84 | IRBuilder<> Builder(Orig->getParent()); |
85 | Builder.SetInsertPoint(Orig); |
86 | |
87 | auto *AVec = dyn_cast<FixedVectorType>(Val: A->getType()); |
88 | Value *Elt0 = Builder.CreateExtractElement(Vec: A, Idx: (uint64_t)0); |
89 | Value *Elt1 = Builder.CreateExtractElement(Vec: B, Idx: (uint64_t)0); |
90 | Value *Result = Builder.CreateMul(LHS: Elt0, RHS: Elt1); |
91 | for (unsigned I = 1; I < AVec->getNumElements(); I++) { |
92 | Elt0 = Builder.CreateExtractElement(Vec: A, Idx: I); |
93 | Elt1 = Builder.CreateExtractElement(Vec: B, Idx: I); |
94 | Result = Builder.CreateIntrinsic(RetTy: Result->getType(), ID: MadIntrinsic, |
95 | Args: ArrayRef<Value *>{Elt0, Elt1, Result}, |
96 | FMFSource: nullptr, Name: "dx.mad" ); |
97 | } |
98 | Orig->replaceAllUsesWith(V: Result); |
99 | Orig->eraseFromParent(); |
100 | return true; |
101 | } |
102 | |
103 | static bool expandExpIntrinsic(CallInst *Orig) { |
104 | Value *X = Orig->getOperand(i_nocapture: 0); |
105 | IRBuilder<> Builder(Orig->getParent()); |
106 | Builder.SetInsertPoint(Orig); |
107 | Type *Ty = X->getType(); |
108 | Type *EltTy = Ty->getScalarType(); |
109 | Constant *Log2eConst = |
110 | Ty->isVectorTy() ? ConstantVector::getSplat( |
111 | EC: ElementCount::getFixed( |
112 | MinVal: cast<FixedVectorType>(Val: Ty)->getNumElements()), |
113 | Elt: ConstantFP::get(Ty: EltTy, V: numbers::log2ef)) |
114 | : ConstantFP::get(Ty: EltTy, V: numbers::log2ef); |
115 | Value *NewX = Builder.CreateFMul(L: Log2eConst, R: X); |
116 | auto *Exp2Call = |
117 | Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2" ); |
118 | Exp2Call->setTailCall(Orig->isTailCall()); |
119 | Exp2Call->setAttributes(Orig->getAttributes()); |
120 | Orig->replaceAllUsesWith(V: Exp2Call); |
121 | Orig->eraseFromParent(); |
122 | return true; |
123 | } |
124 | |
125 | static bool expandAnyIntrinsic(CallInst *Orig) { |
126 | Value *X = Orig->getOperand(i_nocapture: 0); |
127 | IRBuilder<> Builder(Orig->getParent()); |
128 | Builder.SetInsertPoint(Orig); |
129 | Type *Ty = X->getType(); |
130 | Type *EltTy = Ty->getScalarType(); |
131 | |
132 | if (!Ty->isVectorTy()) { |
133 | Value *Cond = EltTy->isFloatingPointTy() |
134 | ? Builder.CreateFCmpUNE(LHS: X, RHS: ConstantFP::get(Ty: EltTy, V: 0)) |
135 | : Builder.CreateICmpNE(LHS: X, RHS: ConstantInt::get(Ty: EltTy, V: 0)); |
136 | Orig->replaceAllUsesWith(V: Cond); |
137 | } else { |
138 | auto *XVec = dyn_cast<FixedVectorType>(Val: Ty); |
139 | Value *Cond = |
140 | EltTy->isFloatingPointTy() |
141 | ? Builder.CreateFCmpUNE( |
142 | LHS: X, RHS: ConstantVector::getSplat( |
143 | EC: ElementCount::getFixed(MinVal: XVec->getNumElements()), |
144 | Elt: ConstantFP::get(Ty: EltTy, V: 0))) |
145 | : Builder.CreateICmpNE( |
146 | LHS: X, RHS: ConstantVector::getSplat( |
147 | EC: ElementCount::getFixed(MinVal: XVec->getNumElements()), |
148 | Elt: ConstantInt::get(Ty: EltTy, V: 0))); |
149 | Value *Result = Builder.CreateExtractElement(Vec: Cond, Idx: (uint64_t)0); |
150 | for (unsigned I = 1; I < XVec->getNumElements(); I++) { |
151 | Value *Elt = Builder.CreateExtractElement(Vec: Cond, Idx: I); |
152 | Result = Builder.CreateOr(LHS: Result, RHS: Elt); |
153 | } |
154 | Orig->replaceAllUsesWith(V: Result); |
155 | } |
156 | Orig->eraseFromParent(); |
157 | return true; |
158 | } |
159 | |
160 | static bool expandLerpIntrinsic(CallInst *Orig) { |
161 | Value *X = Orig->getOperand(i_nocapture: 0); |
162 | Value *Y = Orig->getOperand(i_nocapture: 1); |
163 | Value *S = Orig->getOperand(i_nocapture: 2); |
164 | IRBuilder<> Builder(Orig->getParent()); |
165 | Builder.SetInsertPoint(Orig); |
166 | auto *V = Builder.CreateFSub(L: Y, R: X); |
167 | V = Builder.CreateFMul(L: S, R: V); |
168 | auto *Result = Builder.CreateFAdd(L: X, R: V, Name: "dx.lerp" ); |
169 | Orig->replaceAllUsesWith(V: Result); |
170 | Orig->eraseFromParent(); |
171 | return true; |
172 | } |
173 | |
174 | static bool expandLogIntrinsic(CallInst *Orig, |
175 | float LogConstVal = numbers::ln2f) { |
176 | Value *X = Orig->getOperand(i_nocapture: 0); |
177 | IRBuilder<> Builder(Orig->getParent()); |
178 | Builder.SetInsertPoint(Orig); |
179 | Type *Ty = X->getType(); |
180 | Type *EltTy = Ty->getScalarType(); |
181 | Constant *Ln2Const = |
182 | Ty->isVectorTy() ? ConstantVector::getSplat( |
183 | EC: ElementCount::getFixed( |
184 | MinVal: cast<FixedVectorType>(Val: Ty)->getNumElements()), |
185 | Elt: ConstantFP::get(Ty: EltTy, V: LogConstVal)) |
186 | : ConstantFP::get(Ty: EltTy, V: LogConstVal); |
187 | auto *Log2Call = |
188 | Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2" ); |
189 | Log2Call->setTailCall(Orig->isTailCall()); |
190 | Log2Call->setAttributes(Orig->getAttributes()); |
191 | auto *Result = Builder.CreateFMul(L: Ln2Const, R: Log2Call); |
192 | Orig->replaceAllUsesWith(V: Result); |
193 | Orig->eraseFromParent(); |
194 | return true; |
195 | } |
196 | static bool expandLog10Intrinsic(CallInst *Orig) { |
197 | return expandLogIntrinsic(Orig, LogConstVal: numbers::ln2f / numbers::ln10f); |
198 | } |
199 | |
200 | static bool expandPowIntrinsic(CallInst *Orig) { |
201 | |
202 | Value *X = Orig->getOperand(i_nocapture: 0); |
203 | Value *Y = Orig->getOperand(i_nocapture: 1); |
204 | Type *Ty = X->getType(); |
205 | IRBuilder<> Builder(Orig->getParent()); |
206 | Builder.SetInsertPoint(Orig); |
207 | |
208 | auto *Log2Call = |
209 | Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2" ); |
210 | auto *Mul = Builder.CreateFMul(L: Log2Call, R: Y); |
211 | auto *Exp2Call = |
212 | Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2" ); |
213 | Exp2Call->setTailCall(Orig->isTailCall()); |
214 | Exp2Call->setAttributes(Orig->getAttributes()); |
215 | Orig->replaceAllUsesWith(V: Exp2Call); |
216 | Orig->eraseFromParent(); |
217 | return true; |
218 | } |
219 | |
220 | static Intrinsic::ID getMaxForClamp(Type *ElemTy, |
221 | Intrinsic::ID ClampIntrinsic) { |
222 | if (ClampIntrinsic == Intrinsic::dx_uclamp) |
223 | return Intrinsic::umax; |
224 | assert(ClampIntrinsic == Intrinsic::dx_clamp); |
225 | if (ElemTy->isVectorTy()) |
226 | ElemTy = ElemTy->getScalarType(); |
227 | if (ElemTy->isIntegerTy()) |
228 | return Intrinsic::smax; |
229 | assert(ElemTy->isFloatingPointTy()); |
230 | return Intrinsic::maxnum; |
231 | } |
232 | |
233 | static Intrinsic::ID getMinForClamp(Type *ElemTy, |
234 | Intrinsic::ID ClampIntrinsic) { |
235 | if (ClampIntrinsic == Intrinsic::dx_uclamp) |
236 | return Intrinsic::umin; |
237 | assert(ClampIntrinsic == Intrinsic::dx_clamp); |
238 | if (ElemTy->isVectorTy()) |
239 | ElemTy = ElemTy->getScalarType(); |
240 | if (ElemTy->isIntegerTy()) |
241 | return Intrinsic::smin; |
242 | assert(ElemTy->isFloatingPointTy()); |
243 | return Intrinsic::minnum; |
244 | } |
245 | |
246 | static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) { |
247 | Value *X = Orig->getOperand(i_nocapture: 0); |
248 | Value *Min = Orig->getOperand(i_nocapture: 1); |
249 | Value *Max = Orig->getOperand(i_nocapture: 2); |
250 | Type *Ty = X->getType(); |
251 | IRBuilder<> Builder(Orig->getParent()); |
252 | Builder.SetInsertPoint(Orig); |
253 | auto *MaxCall = Builder.CreateIntrinsic( |
254 | RetTy: Ty, ID: getMaxForClamp(ElemTy: Ty, ClampIntrinsic), Args: {X, Min}, FMFSource: nullptr, Name: "dx.max" ); |
255 | auto *MinCall = |
256 | Builder.CreateIntrinsic(RetTy: Ty, ID: getMinForClamp(ElemTy: Ty, ClampIntrinsic), |
257 | Args: {MaxCall, Max}, FMFSource: nullptr, Name: "dx.min" ); |
258 | |
259 | Orig->replaceAllUsesWith(V: MinCall); |
260 | Orig->eraseFromParent(); |
261 | return true; |
262 | } |
263 | |
264 | static bool expandIntrinsic(Function &F, CallInst *Orig) { |
265 | switch (F.getIntrinsicID()) { |
266 | case Intrinsic::abs: |
267 | return expandAbs(Orig); |
268 | case Intrinsic::exp: |
269 | return expandExpIntrinsic(Orig); |
270 | case Intrinsic::log: |
271 | return expandLogIntrinsic(Orig); |
272 | case Intrinsic::log10: |
273 | return expandLog10Intrinsic(Orig); |
274 | case Intrinsic::pow: |
275 | return expandPowIntrinsic(Orig); |
276 | case Intrinsic::dx_any: |
277 | return expandAnyIntrinsic(Orig); |
278 | case Intrinsic::dx_uclamp: |
279 | case Intrinsic::dx_clamp: |
280 | return expandClampIntrinsic(Orig, ClampIntrinsic: F.getIntrinsicID()); |
281 | case Intrinsic::dx_lerp: |
282 | return expandLerpIntrinsic(Orig); |
283 | case Intrinsic::dx_sdot: |
284 | case Intrinsic::dx_udot: |
285 | return expandIntegerDot(Orig, DotIntrinsic: F.getIntrinsicID()); |
286 | } |
287 | return false; |
288 | } |
289 | |
290 | static bool expansionIntrinsics(Module &M) { |
291 | for (auto &F : make_early_inc_range(Range: M.functions())) { |
292 | if (!isIntrinsicExpansion(F)) |
293 | continue; |
294 | bool IntrinsicExpanded = false; |
295 | for (User *U : make_early_inc_range(Range: F.users())) { |
296 | auto *IntrinsicCall = dyn_cast<CallInst>(Val: U); |
297 | if (!IntrinsicCall) |
298 | continue; |
299 | IntrinsicExpanded = expandIntrinsic(F, Orig: IntrinsicCall); |
300 | } |
301 | if (F.user_empty() && IntrinsicExpanded) |
302 | F.eraseFromParent(); |
303 | } |
304 | return true; |
305 | } |
306 | |
307 | PreservedAnalyses DXILIntrinsicExpansion::run(Module &M, |
308 | ModuleAnalysisManager &) { |
309 | if (expansionIntrinsics(M)) |
310 | return PreservedAnalyses::none(); |
311 | return PreservedAnalyses::all(); |
312 | } |
313 | |
314 | bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) { |
315 | return expansionIntrinsics(M); |
316 | } |
317 | |
318 | char DXILIntrinsicExpansionLegacy::ID = 0; |
319 | |
320 | INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE, |
321 | "DXIL Intrinsic Expansion" , false, false) |
322 | INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE, |
323 | "DXIL Intrinsic Expansion" , false, false) |
324 | |
325 | ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() { |
326 | return new DXILIntrinsicExpansionLegacy(); |
327 | } |
328 | |