1 | //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements utilities for working with Profiling Metadata. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "llvm/IR/ProfDataUtils.h" |
14 | #include "llvm/ADT/SmallVector.h" |
15 | #include "llvm/ADT/Twine.h" |
16 | #include "llvm/IR/Constants.h" |
17 | #include "llvm/IR/Function.h" |
18 | #include "llvm/IR/Instructions.h" |
19 | #include "llvm/IR/LLVMContext.h" |
20 | #include "llvm/IR/MDBuilder.h" |
21 | #include "llvm/IR/Metadata.h" |
22 | #include "llvm/Support/BranchProbability.h" |
23 | #include "llvm/Support/CommandLine.h" |
24 | |
25 | using namespace llvm; |
26 | |
27 | namespace { |
28 | |
// MD_prof nodes have the following layout:
//
// In general:
// { String name, Array of i32 }
//
// In terms of types:
// { MDString, [i32, i32, ...] }
//
// Concretely, for branch weights:
// { "branch_weights", [i32 1, i32 10000] }
//
// We maintain some constants here to ensure that we access the branch weights
// correctly, and so that the behavior can be changed in one place if the
// layout changes in the future.
42 | |
// Operand 0 of an MD_prof node is its name string, so the weight operands
// begin at this index.
constexpr unsigned WeightsIdx = 1;

// Fewest operands an MD_prof node may have and still be accepted as branch
// weights by isBranchWeightMD(): the name plus two weights.
constexpr unsigned MinBWOps = 3;
48 | |
49 | // We may want to add support for other MD_prof types, so provide an abstraction |
50 | // for checking the metadata type. |
51 | bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { |
52 | // TODO: This routine may be simplified if MD_prof used an enum instead of a |
53 | // string to differentiate the types of MD_prof nodes. |
54 | if (!ProfData || !Name || MinOps < 2) |
55 | return false; |
56 | |
57 | unsigned NOps = ProfData->getNumOperands(); |
58 | if (NOps < MinOps) |
59 | return false; |
60 | |
61 | auto *ProfDataName = dyn_cast<MDString>(Val: ProfData->getOperand(I: 0)); |
62 | if (!ProfDataName) |
63 | return false; |
64 | |
65 | return ProfDataName->getString().equals(RHS: Name); |
66 | } |
67 | |
68 | } // namespace |
69 | |
70 | namespace llvm { |
71 | |
72 | bool hasProfMD(const Instruction &I) { |
73 | return I.hasMetadata(KindID: LLVMContext::MD_prof); |
74 | } |
75 | |
76 | bool isBranchWeightMD(const MDNode *ProfileData) { |
77 | return isTargetMD(ProfData: ProfileData, Name: "branch_weights" , MinOps: MinBWOps); |
78 | } |
79 | |
80 | bool hasBranchWeightMD(const Instruction &I) { |
81 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
82 | return isBranchWeightMD(ProfileData); |
83 | } |
84 | |
85 | bool hasValidBranchWeightMD(const Instruction &I) { |
86 | return getValidBranchWeightMDNode(I); |
87 | } |
88 | |
89 | MDNode *getBranchWeightMDNode(const Instruction &I) { |
90 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
91 | if (!isBranchWeightMD(ProfileData)) |
92 | return nullptr; |
93 | return ProfileData; |
94 | } |
95 | |
96 | MDNode *getValidBranchWeightMDNode(const Instruction &I) { |
97 | auto *ProfileData = getBranchWeightMDNode(I); |
98 | if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors()) |
99 | return ProfileData; |
100 | return nullptr; |
101 | } |
102 | |
103 | void (const MDNode *ProfileData, |
104 | SmallVectorImpl<uint32_t> &Weights) { |
105 | assert(isBranchWeightMD(ProfileData) && "wrong metadata" ); |
106 | |
107 | unsigned NOps = ProfileData->getNumOperands(); |
108 | assert(WeightsIdx < NOps && "Weights Index must be less than NOps." ); |
109 | Weights.resize(N: NOps - WeightsIdx); |
110 | |
111 | for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { |
112 | ConstantInt *Weight = |
113 | mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx)); |
114 | assert(Weight && "Malformed branch_weight in MD_prof node" ); |
115 | assert(Weight->getValue().getActiveBits() <= 32 && |
116 | "Too many bits for uint32_t" ); |
117 | Weights[Idx - WeightsIdx] = Weight->getZExtValue(); |
118 | } |
119 | } |
120 | |
121 | bool (const MDNode *ProfileData, |
122 | SmallVectorImpl<uint32_t> &Weights) { |
123 | if (!isBranchWeightMD(ProfileData)) |
124 | return false; |
125 | extractFromBranchWeightMD(ProfileData, Weights); |
126 | return true; |
127 | } |
128 | |
129 | bool (const Instruction &I, |
130 | SmallVectorImpl<uint32_t> &Weights) { |
131 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
132 | return extractBranchWeights(ProfileData, Weights); |
133 | } |
134 | |
135 | bool (const Instruction &I, uint64_t &TrueVal, |
136 | uint64_t &FalseVal) { |
137 | assert((I.getOpcode() == Instruction::Br || |
138 | I.getOpcode() == Instruction::Select) && |
139 | "Looking for branch weights on something besides branch, select, or " |
140 | "switch" ); |
141 | |
142 | SmallVector<uint32_t, 2> Weights; |
143 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
144 | if (!extractBranchWeights(ProfileData, Weights)) |
145 | return false; |
146 | |
147 | if (Weights.size() > 2) |
148 | return false; |
149 | |
150 | TrueVal = Weights[0]; |
151 | FalseVal = Weights[1]; |
152 | return true; |
153 | } |
154 | |
155 | bool (const MDNode *ProfileData, uint64_t &TotalVal) { |
156 | TotalVal = 0; |
157 | if (!ProfileData) |
158 | return false; |
159 | |
160 | auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0)); |
161 | if (!ProfDataName) |
162 | return false; |
163 | |
164 | if (ProfDataName->getString().equals(RHS: "branch_weights" )) { |
165 | for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx++) { |
166 | auto *V = mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx)); |
167 | assert(V && "Malformed branch_weight in MD_prof node" ); |
168 | TotalVal += V->getValue().getZExtValue(); |
169 | } |
170 | return true; |
171 | } |
172 | |
173 | if (ProfDataName->getString().equals(RHS: "VP" ) && |
174 | ProfileData->getNumOperands() > 3) { |
175 | TotalVal = mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: 2)) |
176 | ->getValue() |
177 | .getZExtValue(); |
178 | return true; |
179 | } |
180 | return false; |
181 | } |
182 | |
183 | bool (const Instruction &I, uint64_t &TotalVal) { |
184 | return extractProfTotalWeight(ProfileData: I.getMetadata(KindID: LLVMContext::MD_prof), TotalVal); |
185 | } |
186 | |
187 | void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) { |
188 | MDBuilder MDB(I.getContext()); |
189 | MDNode *BranchWeights = MDB.createBranchWeights(Weights); |
190 | I.setMetadata(KindID: LLVMContext::MD_prof, Node: BranchWeights); |
191 | } |
192 | |
193 | void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { |
194 | assert(T != 0 && "Caller should guarantee" ); |
195 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
196 | if (ProfileData == nullptr) |
197 | return; |
198 | |
199 | auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0)); |
200 | if (!ProfDataName || (!ProfDataName->getString().equals(RHS: "branch_weights" ) && |
201 | !ProfDataName->getString().equals(RHS: "VP" ))) |
202 | return; |
203 | |
204 | LLVMContext &C = I.getContext(); |
205 | |
206 | MDBuilder MDB(C); |
207 | SmallVector<Metadata *, 3> Vals; |
208 | Vals.push_back(Elt: ProfileData->getOperand(I: 0)); |
209 | APInt APS(128, S), APT(128, T); |
210 | if (ProfDataName->getString().equals(RHS: "branch_weights" ) && |
211 | ProfileData->getNumOperands() > 0) { |
212 | // Using APInt::div may be expensive, but most cases should fit 64 bits. |
213 | APInt Val(128, mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: 1)) |
214 | ->getValue() |
215 | .getZExtValue()); |
216 | Val *= APS; |
217 | Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get( |
218 | Ty: Type::getInt32Ty(C), V: Val.udiv(RHS: APT).getLimitedValue(UINT32_MAX)))); |
219 | } else if (ProfDataName->getString().equals(RHS: "VP" )) |
220 | for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) { |
221 | // The first value is the key of the value profile, which will not change. |
222 | Vals.push_back(Elt: ProfileData->getOperand(I: i)); |
223 | uint64_t Count = |
224 | mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: i + 1)) |
225 | ->getValue() |
226 | .getZExtValue(); |
227 | // Don't scale the magic number. |
228 | if (Count == NOMORE_ICP_MAGICNUM) { |
229 | Vals.push_back(Elt: ProfileData->getOperand(I: i + 1)); |
230 | continue; |
231 | } |
232 | // Using APInt::div may be expensive, but most cases should fit 64 bits. |
233 | APInt Val(128, Count); |
234 | Val *= APS; |
235 | Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get( |
236 | Ty: Type::getInt64Ty(C), V: Val.udiv(RHS: APT).getLimitedValue()))); |
237 | } |
238 | I.setMetadata(KindID: LLVMContext::MD_prof, Node: MDNode::get(Context&: C, MDs: Vals)); |
239 | } |
240 | |
241 | } // namespace llvm |
242 | |