1//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Instrumentation-based profile-guided optimization
10//
11//===----------------------------------------------------------------------===//
12
13#include "CodeGenPGO.h"
14#include "CodeGenFunction.h"
15#include "CoverageMappingGen.h"
16#include "clang/AST/RecursiveASTVisitor.h"
17#include "clang/AST/StmtVisitor.h"
18#include "llvm/IR/Intrinsics.h"
19#include "llvm/IR/MDBuilder.h"
20#include "llvm/Support/CommandLine.h"
21#include "llvm/Support/Endian.h"
22#include "llvm/Support/FileSystem.h"
23#include "llvm/Support/MD5.h"
24
// Hidden cc1-level switch: when set, valueProfile() below emits value-profiling
// instrumentation (or metadata) in addition to the plain region counters.
static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));
29
30using namespace clang;
31using namespace CodeGen;
32
33void CodeGenPGO::setFuncName(StringRef Name,
34 llvm::GlobalValue::LinkageTypes Linkage) {
35 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36 FuncName = llvm::getPGOFuncName(
37 Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39
40 // If we're generating a profile, create a variable for the name.
41 if (CGM.getCodeGenOpts().hasProfileClangInstr())
42 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43}
44
45void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46 setFuncName(Fn->getName(), Fn->getLinkage());
47 // Create PGOFuncName meta data.
48 llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49}
50
/// The version of the PGO hash algorithm.
///
/// Each value corresponds to a range of indexed-profile format versions (see
/// getPGOHashVersion below); a reader must hash with the same scheme that
/// produced the profile it is consuming.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,
  PGO_HASH_V3,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V3
};
60
61namespace {
62/// Stable hasher for PGO region counters.
63///
64/// PGOHash produces a stable hash of a given function's control flow.
65///
66/// Changing the output of this hash will invalidate all previously generated
67/// profiles -- i.e., don't do it.
68///
69/// \note When this hash does eventually change (years?), we still need to
70/// support old hashes. We'll need to pull in the version number from the
71/// profile data format and use the matching hash function.
class PGOHash {
  // Pending hash types packed NumBitsPerType bits each, newest in the low
  // bits; flushed into the MD5 once NumTypesPerWord types accumulate.
  uint64_t Working;
  // Total number of types combined so far, including ones already flushed.
  unsigned Count;
  PGOHashVersion HashVersion;
  llvm::MD5 MD5;

  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable. All new members must be added at the end,
  /// and no members should be removed. Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available since PGO_HASH_V2.

    // Keep this last. It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  void combine(HashType Type);
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
// Out-of-line definitions for the static const members (required when they
// are ODR-used, pre-C++17 inline-variable rules).
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
142
143/// Get the PGO hash version used in the given indexed profile.
144static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
145 CodeGenModule &CGM) {
146 if (PGOReader->getVersion() <= 4)
147 return PGO_HASH_V1;
148 if (PGOReader->getVersion() <= 5)
149 return PGO_HASH_V2;
150 return PGO_HASH_V3;
151}
152
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  /// The profile version.
  uint64_t ProfileVersion;

  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        ProfileVersion(ProfileVersion) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  // Function-like declarations get the entry counter: counter 0 is always
  // the body of the decl being instrumented.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    // Counter assignment always follows the V1 node set so that counter
    // indices stay stable across hash versions.
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// The RHS of all logical operators gets a fresh counter in order to count
  /// how many times the RHS evaluates to true or false, depending on the
  /// semantics of the operator. This is only valid for ">= v7" of the profile
  /// version so that we facilitate backward compatibility.
  bool VisitBinaryOperator(BinaryOperator *S) {
    if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
      if (S->isLogicalOp() &&
          CodeGenFunction::isInstrumentedCondition(S->getRHS()))
        CounterMap[S->getRHS()] = NextCounter++;
    return Base::VisitBinaryOperator(S);
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    // For V2+ hashes, re-derive the (richer) hash type for this node; counter
    // assignment above still used the V1 set.
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  ///
  /// Returns PGOHash::None for nodes that do not participate in the hash at
  /// that version. The mapping here is part of the stable hash: do not
  /// reorder or change any case.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion >= PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Node kinds that only contribute to the hash from V2 onwards.
    if (HashVersion >= PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};
359
/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  // Record CurrentCount on \p S if a preceding statement (return, break,
  // continue, goto, throw, loop/switch exit) requested it.
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  // Default: the count flows unchanged through the statement and into its
  // children.
  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Terminators: after a return/throw/goto/break/continue, no count falls
  // through, so zero CurrentCount and record the (new) count on whatever
  // statement follows.
  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    // Accumulate the outgoing count into the innermost loop/switch so it can
    // be added back at that construct's exit.
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // Loop exit count: condition-false edges plus breaks.
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // There is no explicit condition; the exit count is everything flowing
    // into the loop header minus one body execution, plus breaks.
    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    // Control never falls directly into the body; each case restores its own
    // count.
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    // Exit count: paths that short-circuited (ParentCount - RHSCount) plus
    // paths that completed the RHS (RHSCount adjusted by any early exits,
    // i.e. RHSCount - (RHSCount - CurrentCount)).
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
739} // end anonymous namespace
740
// Fold one AST hash type into the running function hash. The packing scheme
// (six bits per type, MD5 flush every NumTypesPerWord types) is part of the
// stable hash: do not change it.
void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    // Byte-swap to little-endian so the digest is host-endianness-independent.
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}
758
// Produce the final 64-bit function hash, flushing any pending packed types.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working) {
    // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
    // is buggy because it converts a uint64_t into an array of uint8_t.
    if (HashVersion < PGO_HASH_V3) {
      MD5.update({(uint8_t)Working});
    } else {
      using namespace llvm::support;
      uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
      MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    }
  }

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  return Result.low();
}
785
// Entry point for per-function PGO setup: decide whether this function should
// be instrumented/profiled at all, then build the counter map and (if a
// profile is available) the computed region counts.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  // Skip CUDA/HIP kernel launch stub functions.
  if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
      D->hasAttr<CUDAGlobalAttr>())
    return;

  // Nothing to do unless we are emitting instrumentation or consuming an
  // existing indexed profile.
  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  // Compiler-synthesized declarations have no user-written body to profile.
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in IR.
  // If so, instrument only base variant, others are implemented by delegation
  // to the base one, it would be counted twice otherwise.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  CGM.ClearUnusedCoverageMapping(D);
  // Honor the no_profile function attribute.
  if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
    return;

  setFuncName(Fn);

  // Assign counter indices (and compute the function hash) for this body.
  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}
830
831void CodeGenPGO::mapRegionCounters(const Decl *D) {
832 // Use the latest hash version when inserting instrumentation, but use the
833 // version in the indexed profile if we're reading PGO data.
834 PGOHashVersion HashVersion = PGO_HASH_LATEST;
835 uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
836 if (auto *PGOReader = CGM.getPGOReader()) {
837 HashVersion = getPGOHashVersion(PGOReader, CGM);
838 ProfileVersion = PGOReader->getVersion();
839 }
840
841 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
842 MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
843 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
844 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
845 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
846 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
847 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
848 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
849 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
850 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
851 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
852 NumRegionCounters = Walker.NextCounter;
853 FunctionHash = Walker.Hash.finalize();
854}
855
856bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
857 if (!D->getBody())
858 return true;
859
860 // Skip host-only functions in the CUDA device compilation and device-only
861 // functions in the host compilation. Just roughly filter them out based on
862 // the function attributes. If there are effectively host-only or device-only
863 // ones, their coverage mapping may still be generated.
864 if (CGM.getLangOpts().CUDA &&
865 ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
866 !D->hasAttr<CUDAGlobalAttr>()) ||
867 (!CGM.getLangOpts().CUDAIsDevice &&
868 (D->hasAttr<CUDAGlobalAttr>() ||
869 (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
870 return true;
871
872 // Don't map the functions in system headers.
873 const auto &SM = CGM.getContext().getSourceManager();
874 auto Loc = D->getBody()->getBeginLoc();
875 return SM.isInSystemHeader(Loc);
876}
877
878void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
879 if (skipRegionMappingForDecl(D))
880 return;
881
882 std::string CoverageMapping;
883 llvm::raw_string_ostream OS(CoverageMapping);
884 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
885 CGM.getContext().getSourceManager(),
886 CGM.getLangOpts(), RegionCounterMap.get());
887 MappingGen.emitCounterMapping(D, OS);
888 OS.flush();
889
890 if (CoverageMapping.empty())
891 return;
892
893 CGM.getCoverageMapping()->addFunctionMappingRecord(
894 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
895}
896
897void
898CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
899 llvm::GlobalValue::LinkageTypes Linkage) {
900 if (skipRegionMappingForDecl(D))
901 return;
902
903 std::string CoverageMapping;
904 llvm::raw_string_ostream OS(CoverageMapping);
905 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
906 CGM.getContext().getSourceManager(),
907 CGM.getLangOpts());
908 MappingGen.emitEmptyMapping(D, OS);
909 OS.flush();
910
911 if (CoverageMapping.empty())
912 return;
913
914 setFuncName(Name, Linkage);
915 CGM.getCoverageMapping()->addFunctionMappingRecord(
916 FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
917}
918
919void CodeGenPGO::computeRegionCounts(const Decl *D) {
920 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
921 ComputeRegionCounts Walker(*StmtCountMap, *this);
922 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
923 Walker.VisitFunctionDecl(FD);
924 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
925 Walker.VisitObjCMethodDecl(MD);
926 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
927 Walker.VisitBlockDecl(BD);
928 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
929 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
930}
931
932void
933CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
934 llvm::Function *Fn) {
935 if (!haveRegionCounts())
936 return;
937
938 uint64_t FunctionCount = getRegionCount(nullptr);
939 Fn->setEntryCount(FunctionCount);
940}
941
942void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
943 llvm::Value *StepV) {
944 if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
945 return;
946 if (!Builder.GetInsertBlock())
947 return;
948
949 unsigned Counter = (*RegionCounterMap)[S];
950 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
951
952 llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
953 Builder.getInt64(FunctionHash),
954 Builder.getInt32(NumRegionCounters),
955 Builder.getInt32(Counter), StepV};
956 if (!StepV)
957 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
958 makeArrayRef(Args, 4));
959 else
960 Builder.CreateCall(
961 CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
962 makeArrayRef(Args));
963}
964
// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
    llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {

  // Value profiling is gated behind the -enable-value-profiling cl::opt.
  if (!EnableValueProfiling)
    return;

  // Need a value, a site to annotate, and a live insertion point.
  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
    return;

  // A constant value is the same at every execution; nothing to profile.
  if (isa<llvm::Constant>(ValuePtr))
    return;

  // Instrumentation path: emit a call to llvm.instrprof.value.profile just
  // before the value site.
  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  if (InstrumentValueSites && RegionCounterMap) {
    // Temporarily move the insertion point to the value site, then restore.
    auto BuilderInsertPoint = Builder.saveIP();
    Builder.SetInsertPoint(ValueSite);
    // Operands: name var, function hash, the value (as i64), the value kind,
    // and this site's index (NumValueSites is bumped as a side effect here).
    llvm::Value *Args[5] = {
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
        Builder.getInt64(FunctionHash),
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
        Builder.getInt32(ValueKind),
        Builder.getInt32(NumValueSites[ValueKind]++)
    };
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
    Builder.restoreIP(BuilderInsertPoint);
    return;
  }

  // PGO-use path: translate the loaded profile record into "VP" metadata on
  // the value site instruction.
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (PGOReader && haveRegionCounts()) {
    // We record the top most called three functions at each call site.
    // Profile metadata contains "VP" string identifying this metadata
    // as value profiling data, then a uint32_t value for the value profiling
    // kind, a uint64_t value for the total number of times the call is
    // executed, followed by the function hash and execution count (uint64_t)
    // pairs for each function.
    // Guard against walking past the number of sites the record knows about.
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
      return;

    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
                            (llvm::InstrProfValueKind)ValueKind,
                            NumValueSites[ValueKind]);

    // Advance to the next site index for this value kind.
    NumValueSites[ValueKind]++;
  }
}
1014
1015void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1016 bool IsInMainFile) {
1017 CGM.getPGOStats().addVisited(IsInMainFile);
1018 RegionCounts.clear();
1019 llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1020 PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1021 if (auto E = RecordExpected.takeError()) {
1022 auto IPE = llvm::InstrProfError::take(std::move(E));
1023 if (IPE == llvm::instrprof_error::unknown_function)
1024 CGM.getPGOStats().addMissing(IsInMainFile);
1025 else if (IPE == llvm::instrprof_error::hash_mismatch)
1026 CGM.getPGOStats().addMismatched(IsInMainFile);
1027 else if (IPE == llvm::instrprof_error::malformed)
1028 // TODO: Consider a more specific warning for this case.
1029 CGM.getPGOStats().addMismatched(IsInMainFile);
1030 return;
1031 }
1032 ProfRecord =
1033 std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1034 RegionCounts = ProfRecord->Counts;
1035}
1036
/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  // Already fits in 32 bits: no scaling needed.
  if (MaxWeight < UINT32_MAX)
    return 1;
  // The +1 guarantees MaxWeight / Scale < UINT32_MAX.
  return MaxWeight / UINT32_MAX + 1;
}
1044
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  // Explicit cast: the asserts above guarantee the value fits, and this
  // avoids an implicit-narrowing (uint64_t -> uint32_t) conversion warning.
  return static_cast<uint32_t>(Scaled);
}
1060
1061llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1062 uint64_t FalseCount) const {
1063 // Check for empty weights.
1064 if (!TrueCount && !FalseCount)
1065 return nullptr;
1066
1067 // Calculate how to scale down to 32-bits.
1068 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1069
1070 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1071 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1072 scaleBranchWeight(FalseCount, Scale));
1073}
1074
1075llvm::MDNode *
1076CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1077 // We need at least two elements to create meaningful weights.
1078 if (Weights.size() < 2)
1079 return nullptr;
1080
1081 // Check for empty weights.
1082 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1083 if (MaxWeight == 0)
1084 return nullptr;
1085
1086 // Calculate how to scale down to 32-bits.
1087 uint64_t Scale = calculateWeightScale(MaxWeight);
1088
1089 SmallVector<uint32_t, 16> ScaledWeights;
1090 ScaledWeights.reserve(Weights.size());
1091 for (uint64_t W : Weights)
1092 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1093
1094 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1095 return MDHelper.createBranchWeights(ScaledWeights);
1096}
1097
1098llvm::MDNode *
1099CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1100 uint64_t LoopCount) const {
1101 if (!PGO.haveRegionCounts())
1102 return nullptr;
1103 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1104 if (!CondCount || *CondCount == 0)
1105 return nullptr;
1106 return createProfileWeights(LoopCount,
1107 std::max(*CondCount, LoopCount) - LoopCount);
1108}
1109