1 | // |
2 | // Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved. |
3 | // Use of this source code is governed by a BSD-style license that can be |
4 | // found in the LICENSE file. |
5 | // |
6 | |
7 | // Analysis of the AST needed for HLSL generation |
8 | |
9 | #include "compiler/translator/ASTMetadataHLSL.h" |
10 | |
11 | #include "compiler/translator/CallDAG.h" |
12 | #include "compiler/translator/SymbolTable.h" |
13 | |
14 | namespace |
15 | { |
16 | |
17 | // Class used to traverse the AST of a function definition, checking if the |
18 | // function uses a gradient, and writing the set of control flow using gradients. |
19 | // It assumes that the analysis has already been made for the function's |
20 | // callees. |
21 | class PullGradient : public TIntermTraverser |
22 | { |
23 | public: |
24 | PullGradient(MetadataList *metadataList, size_t index, const CallDAG &dag) |
25 | : TIntermTraverser(true, false, true), |
26 | mMetadataList(metadataList), |
27 | mMetadata(&(*metadataList)[index]), |
28 | mIndex(index), |
29 | mDag(dag) |
30 | { |
31 | ASSERT(index < metadataList->size()); |
32 | } |
33 | |
34 | void traverse(TIntermAggregate *node) |
35 | { |
36 | node->traverse(this); |
37 | ASSERT(mParents.empty()); |
38 | } |
39 | |
40 | // Called when a gradient operation or a call to a function using a gradient is found. |
41 | void onGradient() |
42 | { |
43 | mMetadata->mUsesGradient = true; |
44 | // Mark the latest control flow as using a gradient. |
45 | if (!mParents.empty()) |
46 | { |
47 | mMetadata->mControlFlowsContainingGradient.insert(mParents.back()); |
48 | } |
49 | } |
50 | |
51 | void visitControlFlow(Visit visit, TIntermNode *node) |
52 | { |
53 | if (visit == PreVisit) |
54 | { |
55 | mParents.push_back(node); |
56 | } |
57 | else if (visit == PostVisit) |
58 | { |
59 | ASSERT(mParents.back() == node); |
60 | mParents.pop_back(); |
61 | // A control flow's using a gradient means its parents are too. |
62 | if (mMetadata->mControlFlowsContainingGradient.count(node)> 0 && !mParents.empty()) |
63 | { |
64 | mMetadata->mControlFlowsContainingGradient.insert(mParents.back()); |
65 | } |
66 | } |
67 | } |
68 | |
69 | bool visitLoop(Visit visit, TIntermLoop *loop) override |
70 | { |
71 | visitControlFlow(visit, loop); |
72 | return true; |
73 | } |
74 | |
75 | bool visitSelection(Visit visit, TIntermSelection *selection) override |
76 | { |
77 | visitControlFlow(visit, selection); |
78 | return true; |
79 | } |
80 | |
81 | bool visitUnary(Visit visit, TIntermUnary *node) override |
82 | { |
83 | if (visit == PreVisit) |
84 | { |
85 | switch (node->getOp()) |
86 | { |
87 | case EOpDFdx: |
88 | case EOpDFdy: |
89 | onGradient(); |
90 | default: |
91 | break; |
92 | } |
93 | } |
94 | |
95 | return true; |
96 | } |
97 | |
98 | bool visitAggregate(Visit visit, TIntermAggregate *node) override |
99 | { |
100 | if (visit == PreVisit) |
101 | { |
102 | if (node->getOp() == EOpFunctionCall) |
103 | { |
104 | if (node->isUserDefined()) |
105 | { |
106 | size_t calleeIndex = mDag.findIndex(node); |
107 | ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex); |
108 | |
109 | if ((*mMetadataList)[calleeIndex].mUsesGradient) { |
110 | onGradient(); |
111 | } |
112 | } |
113 | else |
114 | { |
115 | TString name = TFunction::unmangleName(node->getName()); |
116 | |
117 | if (name == "texture2D" || |
118 | name == "texture2DProj" || |
119 | name == "textureCube" ) |
120 | { |
121 | onGradient(); |
122 | } |
123 | } |
124 | } |
125 | } |
126 | |
127 | return true; |
128 | } |
129 | |
130 | private: |
131 | MetadataList *mMetadataList; |
132 | ASTMetadataHLSL *mMetadata; |
133 | size_t mIndex; |
134 | const CallDAG &mDag; |
135 | |
136 | // Contains a stack of the control flow nodes that are parents of the node being |
137 | // currently visited. It is used to mark control flows using a gradient. |
138 | std::vector<TIntermNode*> mParents; |
139 | }; |
140 | |
141 | // Traverses the AST of a function definition, assuming it has already been used to |
142 | // traverse the callees of that function; computes the discontinuous loops and the if |
143 | // statements that contain a discontinuous loop in their call graph. |
144 | class PullComputeDiscontinuousLoops : public TIntermTraverser |
145 | { |
146 | public: |
147 | PullComputeDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag) |
148 | : TIntermTraverser(true, false, true), |
149 | mMetadataList(metadataList), |
150 | mMetadata(&(*metadataList)[index]), |
151 | mIndex(index), |
152 | mDag(dag) |
153 | { |
154 | } |
155 | |
156 | void traverse(TIntermAggregate *node) |
157 | { |
158 | node->traverse(this); |
159 | ASSERT(mLoopsAndSwitches.empty()); |
160 | ASSERT(mIfs.empty()); |
161 | } |
162 | |
163 | // Called when a discontinuous loop or a call to a function with a discontinuous loop |
164 | // in its call graph is found. |
165 | void onDiscontinuousLoop() |
166 | { |
167 | mMetadata->mHasDiscontinuousLoopInCallGraph = true; |
168 | // Mark the latest if as using a discontinuous loop. |
169 | if (!mIfs.empty()) |
170 | { |
171 | mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back()); |
172 | } |
173 | } |
174 | |
175 | bool visitLoop(Visit visit, TIntermLoop *loop) override |
176 | { |
177 | if (visit == PreVisit) |
178 | { |
179 | mLoopsAndSwitches.push_back(loop); |
180 | } |
181 | else if (visit == PostVisit) |
182 | { |
183 | ASSERT(mLoopsAndSwitches.back() == loop); |
184 | mLoopsAndSwitches.pop_back(); |
185 | } |
186 | |
187 | return true; |
188 | } |
189 | |
190 | bool visitSelection(Visit visit, TIntermSelection *node) override |
191 | { |
192 | if (visit == PreVisit) |
193 | { |
194 | mIfs.push_back(node); |
195 | } |
196 | else if (visit == PostVisit) |
197 | { |
198 | ASSERT(mIfs.back() == node); |
199 | mIfs.pop_back(); |
200 | // An if using a discontinuous loop means its parents ifs are also discontinuous. |
201 | if (mMetadata->mIfsContainingDiscontinuousLoop.count(node) > 0 && !mIfs.empty()) |
202 | { |
203 | mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back()); |
204 | } |
205 | } |
206 | |
207 | return true; |
208 | } |
209 | |
210 | bool visitBranch(Visit visit, TIntermBranch *node) override |
211 | { |
212 | if (visit == PreVisit) |
213 | { |
214 | switch (node->getFlowOp()) |
215 | { |
216 | case EOpBreak: |
217 | { |
218 | ASSERT(!mLoopsAndSwitches.empty()); |
219 | TIntermLoop *loop = mLoopsAndSwitches.back()->getAsLoopNode(); |
220 | if (loop != nullptr) |
221 | { |
222 | mMetadata->mDiscontinuousLoops.insert(loop); |
223 | onDiscontinuousLoop(); |
224 | } |
225 | } |
226 | break; |
227 | case EOpContinue: |
228 | { |
229 | ASSERT(!mLoopsAndSwitches.empty()); |
230 | TIntermLoop *loop = nullptr; |
231 | size_t i = mLoopsAndSwitches.size(); |
232 | while (loop == nullptr && i > 0) |
233 | { |
234 | --i; |
235 | loop = mLoopsAndSwitches.at(i)->getAsLoopNode(); |
236 | } |
237 | ASSERT(loop != nullptr); |
238 | mMetadata->mDiscontinuousLoops.insert(loop); |
239 | onDiscontinuousLoop(); |
240 | } |
241 | break; |
242 | case EOpKill: |
243 | case EOpReturn: |
244 | // A return or discard jumps out of all the enclosing loops |
245 | if (!mLoopsAndSwitches.empty()) |
246 | { |
247 | for (TIntermNode* node : mLoopsAndSwitches) |
248 | { |
249 | TIntermLoop *loop = node->getAsLoopNode(); |
250 | if (loop) |
251 | { |
252 | mMetadata->mDiscontinuousLoops.insert(loop); |
253 | } |
254 | } |
255 | onDiscontinuousLoop(); |
256 | } |
257 | break; |
258 | default: |
259 | UNREACHABLE(); |
260 | } |
261 | } |
262 | |
263 | return true; |
264 | } |
265 | |
266 | bool visitAggregate(Visit visit, TIntermAggregate *node) override |
267 | { |
268 | if (visit == PreVisit && node->getOp() == EOpFunctionCall) |
269 | { |
270 | if (node->isUserDefined()) |
271 | { |
272 | size_t calleeIndex = mDag.findIndex(node); |
273 | ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex); |
274 | |
275 | if ((*mMetadataList)[calleeIndex].mHasDiscontinuousLoopInCallGraph) |
276 | { |
277 | onDiscontinuousLoop(); |
278 | } |
279 | } |
280 | } |
281 | |
282 | return true; |
283 | } |
284 | |
285 | bool visitSwitch(Visit visit, TIntermSwitch *node) override |
286 | { |
287 | if (visit == PreVisit) |
288 | { |
289 | mLoopsAndSwitches.push_back(node); |
290 | } |
291 | else if (visit == PostVisit) |
292 | { |
293 | ASSERT(mLoopsAndSwitches.back() == node); |
294 | mLoopsAndSwitches.pop_back(); |
295 | } |
296 | return true; |
297 | } |
298 | |
299 | private: |
300 | MetadataList *mMetadataList; |
301 | ASTMetadataHLSL *mMetadata; |
302 | size_t mIndex; |
303 | const CallDAG &mDag; |
304 | |
305 | std::vector<TIntermNode*> mLoopsAndSwitches; |
306 | std::vector<TIntermSelection*> mIfs; |
307 | }; |
308 | |
309 | // Tags all the functions called in a discontinuous loop |
310 | class PushDiscontinuousLoops : public TIntermTraverser |
311 | { |
312 | public: |
313 | PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag) |
314 | : TIntermTraverser(true, true, true), |
315 | mMetadataList(metadataList), |
316 | mMetadata(&(*metadataList)[index]), |
317 | mIndex(index), |
318 | mDag(dag), |
319 | mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0) |
320 | { |
321 | } |
322 | |
323 | void traverse(TIntermAggregate *node) |
324 | { |
325 | node->traverse(this); |
326 | ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)); |
327 | } |
328 | |
329 | bool visitLoop(Visit visit, TIntermLoop *loop) override |
330 | { |
331 | bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0; |
332 | |
333 | if (visit == PreVisit && isDiscontinuous) |
334 | { |
335 | mNestedDiscont++; |
336 | } |
337 | else if (visit == PostVisit && isDiscontinuous) |
338 | { |
339 | mNestedDiscont--; |
340 | } |
341 | |
342 | return true; |
343 | } |
344 | |
345 | bool visitAggregate(Visit visit, TIntermAggregate *node) override |
346 | { |
347 | switch (node->getOp()) |
348 | { |
349 | case EOpFunctionCall: |
350 | if (visit == PreVisit && node->isUserDefined() && mNestedDiscont > 0) |
351 | { |
352 | size_t calleeIndex = mDag.findIndex(node); |
353 | ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex); |
354 | |
355 | (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true; |
356 | } |
357 | break; |
358 | default: |
359 | break; |
360 | } |
361 | return true; |
362 | } |
363 | |
364 | private: |
365 | MetadataList *mMetadataList; |
366 | ASTMetadataHLSL *mMetadata; |
367 | size_t mIndex; |
368 | const CallDAG &mDag; |
369 | |
370 | int mNestedDiscont; |
371 | }; |
372 | |
373 | } |
374 | |
375 | bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermSelection *node) |
376 | { |
377 | return mControlFlowsContainingGradient.count(node) > 0; |
378 | } |
379 | |
380 | bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node) |
381 | { |
382 | return mControlFlowsContainingGradient.count(node) > 0; |
383 | } |
384 | |
385 | bool ASTMetadataHLSL::hasDiscontinuousLoop(TIntermSelection *node) |
386 | { |
387 | return mIfsContainingDiscontinuousLoop.count(node) > 0; |
388 | } |
389 | |
390 | MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag) |
391 | { |
392 | MetadataList metadataList(callDag.size()); |
393 | |
394 | // Compute all the information related to when gradient operations are used. |
395 | // We want to know for each function and control flow operation if they have |
396 | // a gradient operation in their call graph (shortened to "using a gradient" |
397 | // in the rest of the file). |
398 | // |
399 | // This computation is logically split in three steps: |
400 | // 1 - For each function compute if it uses a gradient in its body, ignoring |
401 | // calls to other user-defined functions. |
402 | // 2 - For each function determine if it uses a gradient in its call graph, |
403 | // using the result of step 1 and the CallDAG to know its callees. |
404 | // 3 - For each control flow statement of each function, check if it uses a |
405 | // gradient in the function's body, or if it calls a user-defined function that |
406 | // uses a gradient. |
407 | // |
408 | // We take advantage of the call graph being a DAG and instead compute 1, 2 and 3 |
409 | // for leaves first, then going down the tree. This is correct because 1 doesn't |
410 | // depend on other functions, and 2 and 3 depend only on callees. |
411 | for (size_t i = 0; i < callDag.size(); i++) |
412 | { |
413 | PullGradient pull(&metadataList, i, callDag); |
414 | pull.traverse(callDag.getRecordFromIndex(i).node); |
415 | } |
416 | |
417 | // Compute which loops are discontinuous and which function are called in |
418 | // these loops. The same way computing gradient usage is a "pull" process, |
419 | // computing "bing used in a discont. loop" is a push process. However we also |
420 | // need to know what ifs have a discontinuous loop inside so we do the same type |
421 | // of callgraph analysis as for the gradient. |
422 | |
423 | // First compute which loops are discontinuous (no specific order) and pull |
424 | // the ifs and functions using a discontinuous loop. |
425 | for (size_t i = 0; i < callDag.size(); i++) |
426 | { |
427 | PullComputeDiscontinuousLoops pull(&metadataList, i, callDag); |
428 | pull.traverse(callDag.getRecordFromIndex(i).node); |
429 | } |
430 | |
431 | // Then push the information to callees, either from the a local discontinuous |
432 | // loop or from the caller being called in a discontinuous loop already |
433 | for (size_t i = callDag.size(); i-- > 0;) |
434 | { |
435 | PushDiscontinuousLoops push(&metadataList, i, callDag); |
436 | push.traverse(callDag.getRecordFromIndex(i).node); |
437 | } |
438 | |
439 | // We create "Lod0" version of functions with the gradient operations replaced |
440 | // by non-gradient operations so that the D3D compiler is happier with discont |
441 | // loops. |
442 | for (auto &metadata : metadataList) |
443 | { |
444 | metadata.mNeedsLod0 = metadata.mCalledInDiscontinuousLoop && metadata.mUsesGradient; |
445 | } |
446 | |
447 | return metadataList; |
448 | } |
449 | |