1//
2// Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6
7// Analysis of the AST needed for HLSL generation
8
9#include "compiler/translator/ASTMetadataHLSL.h"
10
11#include "compiler/translator/CallDAG.h"
12#include "compiler/translator/SymbolTable.h"
13
14namespace
15{
16
17// Class used to traverse the AST of a function definition, checking if the
18// function uses a gradient, and writing the set of control flow using gradients.
19// It assumes that the analysis has already been made for the function's
20// callees.
21class PullGradient : public TIntermTraverser
22{
23 public:
24 PullGradient(MetadataList *metadataList, size_t index, const CallDAG &dag)
25 : TIntermTraverser(true, false, true),
26 mMetadataList(metadataList),
27 mMetadata(&(*metadataList)[index]),
28 mIndex(index),
29 mDag(dag)
30 {
31 ASSERT(index < metadataList->size());
32 }
33
34 void traverse(TIntermAggregate *node)
35 {
36 node->traverse(this);
37 ASSERT(mParents.empty());
38 }
39
40 // Called when a gradient operation or a call to a function using a gradient is found.
41 void onGradient()
42 {
43 mMetadata->mUsesGradient = true;
44 // Mark the latest control flow as using a gradient.
45 if (!mParents.empty())
46 {
47 mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
48 }
49 }
50
51 void visitControlFlow(Visit visit, TIntermNode *node)
52 {
53 if (visit == PreVisit)
54 {
55 mParents.push_back(node);
56 }
57 else if (visit == PostVisit)
58 {
59 ASSERT(mParents.back() == node);
60 mParents.pop_back();
61 // A control flow's using a gradient means its parents are too.
62 if (mMetadata->mControlFlowsContainingGradient.count(node)> 0 && !mParents.empty())
63 {
64 mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
65 }
66 }
67 }
68
69 bool visitLoop(Visit visit, TIntermLoop *loop) override
70 {
71 visitControlFlow(visit, loop);
72 return true;
73 }
74
75 bool visitSelection(Visit visit, TIntermSelection *selection) override
76 {
77 visitControlFlow(visit, selection);
78 return true;
79 }
80
81 bool visitUnary(Visit visit, TIntermUnary *node) override
82 {
83 if (visit == PreVisit)
84 {
85 switch (node->getOp())
86 {
87 case EOpDFdx:
88 case EOpDFdy:
89 onGradient();
90 default:
91 break;
92 }
93 }
94
95 return true;
96 }
97
98 bool visitAggregate(Visit visit, TIntermAggregate *node) override
99 {
100 if (visit == PreVisit)
101 {
102 if (node->getOp() == EOpFunctionCall)
103 {
104 if (node->isUserDefined())
105 {
106 size_t calleeIndex = mDag.findIndex(node);
107 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
108
109 if ((*mMetadataList)[calleeIndex].mUsesGradient) {
110 onGradient();
111 }
112 }
113 else
114 {
115 TString name = TFunction::unmangleName(node->getName());
116
117 if (name == "texture2D" ||
118 name == "texture2DProj" ||
119 name == "textureCube")
120 {
121 onGradient();
122 }
123 }
124 }
125 }
126
127 return true;
128 }
129
130 private:
131 MetadataList *mMetadataList;
132 ASTMetadataHLSL *mMetadata;
133 size_t mIndex;
134 const CallDAG &mDag;
135
136 // Contains a stack of the control flow nodes that are parents of the node being
137 // currently visited. It is used to mark control flows using a gradient.
138 std::vector<TIntermNode*> mParents;
139};
140
141// Traverses the AST of a function definition, assuming it has already been used to
142// traverse the callees of that function; computes the discontinuous loops and the if
143// statements that contain a discontinuous loop in their call graph.
144class PullComputeDiscontinuousLoops : public TIntermTraverser
145{
146 public:
147 PullComputeDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
148 : TIntermTraverser(true, false, true),
149 mMetadataList(metadataList),
150 mMetadata(&(*metadataList)[index]),
151 mIndex(index),
152 mDag(dag)
153 {
154 }
155
156 void traverse(TIntermAggregate *node)
157 {
158 node->traverse(this);
159 ASSERT(mLoopsAndSwitches.empty());
160 ASSERT(mIfs.empty());
161 }
162
163 // Called when a discontinuous loop or a call to a function with a discontinuous loop
164 // in its call graph is found.
165 void onDiscontinuousLoop()
166 {
167 mMetadata->mHasDiscontinuousLoopInCallGraph = true;
168 // Mark the latest if as using a discontinuous loop.
169 if (!mIfs.empty())
170 {
171 mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back());
172 }
173 }
174
175 bool visitLoop(Visit visit, TIntermLoop *loop) override
176 {
177 if (visit == PreVisit)
178 {
179 mLoopsAndSwitches.push_back(loop);
180 }
181 else if (visit == PostVisit)
182 {
183 ASSERT(mLoopsAndSwitches.back() == loop);
184 mLoopsAndSwitches.pop_back();
185 }
186
187 return true;
188 }
189
190 bool visitSelection(Visit visit, TIntermSelection *node) override
191 {
192 if (visit == PreVisit)
193 {
194 mIfs.push_back(node);
195 }
196 else if (visit == PostVisit)
197 {
198 ASSERT(mIfs.back() == node);
199 mIfs.pop_back();
200 // An if using a discontinuous loop means its parents ifs are also discontinuous.
201 if (mMetadata->mIfsContainingDiscontinuousLoop.count(node) > 0 && !mIfs.empty())
202 {
203 mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back());
204 }
205 }
206
207 return true;
208 }
209
210 bool visitBranch(Visit visit, TIntermBranch *node) override
211 {
212 if (visit == PreVisit)
213 {
214 switch (node->getFlowOp())
215 {
216 case EOpBreak:
217 {
218 ASSERT(!mLoopsAndSwitches.empty());
219 TIntermLoop *loop = mLoopsAndSwitches.back()->getAsLoopNode();
220 if (loop != nullptr)
221 {
222 mMetadata->mDiscontinuousLoops.insert(loop);
223 onDiscontinuousLoop();
224 }
225 }
226 break;
227 case EOpContinue:
228 {
229 ASSERT(!mLoopsAndSwitches.empty());
230 TIntermLoop *loop = nullptr;
231 size_t i = mLoopsAndSwitches.size();
232 while (loop == nullptr && i > 0)
233 {
234 --i;
235 loop = mLoopsAndSwitches.at(i)->getAsLoopNode();
236 }
237 ASSERT(loop != nullptr);
238 mMetadata->mDiscontinuousLoops.insert(loop);
239 onDiscontinuousLoop();
240 }
241 break;
242 case EOpKill:
243 case EOpReturn:
244 // A return or discard jumps out of all the enclosing loops
245 if (!mLoopsAndSwitches.empty())
246 {
247 for (TIntermNode* node : mLoopsAndSwitches)
248 {
249 TIntermLoop *loop = node->getAsLoopNode();
250 if (loop)
251 {
252 mMetadata->mDiscontinuousLoops.insert(loop);
253 }
254 }
255 onDiscontinuousLoop();
256 }
257 break;
258 default:
259 UNREACHABLE();
260 }
261 }
262
263 return true;
264 }
265
266 bool visitAggregate(Visit visit, TIntermAggregate *node) override
267 {
268 if (visit == PreVisit && node->getOp() == EOpFunctionCall)
269 {
270 if (node->isUserDefined())
271 {
272 size_t calleeIndex = mDag.findIndex(node);
273 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
274
275 if ((*mMetadataList)[calleeIndex].mHasDiscontinuousLoopInCallGraph)
276 {
277 onDiscontinuousLoop();
278 }
279 }
280 }
281
282 return true;
283 }
284
285 bool visitSwitch(Visit visit, TIntermSwitch *node) override
286 {
287 if (visit == PreVisit)
288 {
289 mLoopsAndSwitches.push_back(node);
290 }
291 else if (visit == PostVisit)
292 {
293 ASSERT(mLoopsAndSwitches.back() == node);
294 mLoopsAndSwitches.pop_back();
295 }
296 return true;
297 }
298
299 private:
300 MetadataList *mMetadataList;
301 ASTMetadataHLSL *mMetadata;
302 size_t mIndex;
303 const CallDAG &mDag;
304
305 std::vector<TIntermNode*> mLoopsAndSwitches;
306 std::vector<TIntermSelection*> mIfs;
307};
308
309// Tags all the functions called in a discontinuous loop
310class PushDiscontinuousLoops : public TIntermTraverser
311{
312 public:
313 PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
314 : TIntermTraverser(true, true, true),
315 mMetadataList(metadataList),
316 mMetadata(&(*metadataList)[index]),
317 mIndex(index),
318 mDag(dag),
319 mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)
320 {
321 }
322
323 void traverse(TIntermAggregate *node)
324 {
325 node->traverse(this);
326 ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
327 }
328
329 bool visitLoop(Visit visit, TIntermLoop *loop) override
330 {
331 bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0;
332
333 if (visit == PreVisit && isDiscontinuous)
334 {
335 mNestedDiscont++;
336 }
337 else if (visit == PostVisit && isDiscontinuous)
338 {
339 mNestedDiscont--;
340 }
341
342 return true;
343 }
344
345 bool visitAggregate(Visit visit, TIntermAggregate *node) override
346 {
347 switch (node->getOp())
348 {
349 case EOpFunctionCall:
350 if (visit == PreVisit && node->isUserDefined() && mNestedDiscont > 0)
351 {
352 size_t calleeIndex = mDag.findIndex(node);
353 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
354
355 (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true;
356 }
357 break;
358 default:
359 break;
360 }
361 return true;
362 }
363
364 private:
365 MetadataList *mMetadataList;
366 ASTMetadataHLSL *mMetadata;
367 size_t mIndex;
368 const CallDAG &mDag;
369
370 int mNestedDiscont;
371};
372
373}
374
375bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermSelection *node)
376{
377 return mControlFlowsContainingGradient.count(node) > 0;
378}
379
380bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node)
381{
382 return mControlFlowsContainingGradient.count(node) > 0;
383}
384
385bool ASTMetadataHLSL::hasDiscontinuousLoop(TIntermSelection *node)
386{
387 return mIfsContainingDiscontinuousLoop.count(node) > 0;
388}
389
390MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
391{
392 MetadataList metadataList(callDag.size());
393
394 // Compute all the information related to when gradient operations are used.
395 // We want to know for each function and control flow operation if they have
396 // a gradient operation in their call graph (shortened to "using a gradient"
397 // in the rest of the file).
398 //
399 // This computation is logically split in three steps:
400 // 1 - For each function compute if it uses a gradient in its body, ignoring
401 // calls to other user-defined functions.
402 // 2 - For each function determine if it uses a gradient in its call graph,
403 // using the result of step 1 and the CallDAG to know its callees.
404 // 3 - For each control flow statement of each function, check if it uses a
405 // gradient in the function's body, or if it calls a user-defined function that
406 // uses a gradient.
407 //
408 // We take advantage of the call graph being a DAG and instead compute 1, 2 and 3
409 // for leaves first, then going down the tree. This is correct because 1 doesn't
410 // depend on other functions, and 2 and 3 depend only on callees.
411 for (size_t i = 0; i < callDag.size(); i++)
412 {
413 PullGradient pull(&metadataList, i, callDag);
414 pull.traverse(callDag.getRecordFromIndex(i).node);
415 }
416
417 // Compute which loops are discontinuous and which function are called in
418 // these loops. The same way computing gradient usage is a "pull" process,
419 // computing "bing used in a discont. loop" is a push process. However we also
420 // need to know what ifs have a discontinuous loop inside so we do the same type
421 // of callgraph analysis as for the gradient.
422
423 // First compute which loops are discontinuous (no specific order) and pull
424 // the ifs and functions using a discontinuous loop.
425 for (size_t i = 0; i < callDag.size(); i++)
426 {
427 PullComputeDiscontinuousLoops pull(&metadataList, i, callDag);
428 pull.traverse(callDag.getRecordFromIndex(i).node);
429 }
430
431 // Then push the information to callees, either from the a local discontinuous
432 // loop or from the caller being called in a discontinuous loop already
433 for (size_t i = callDag.size(); i-- > 0;)
434 {
435 PushDiscontinuousLoops push(&metadataList, i, callDag);
436 push.traverse(callDag.getRecordFromIndex(i).node);
437 }
438
439 // We create "Lod0" version of functions with the gradient operations replaced
440 // by non-gradient operations so that the D3D compiler is happier with discont
441 // loops.
442 for (auto &metadata : metadataList)
443 {
444 metadata.mNeedsLod0 = metadata.mCalledInDiscontinuousLoop && metadata.mUsesGradient;
445 }
446
447 return metadataList;
448}
449