1 | //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===// |
---|---|

2 | // |

3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |

4 | // See https://llvm.org/LICENSE.txt for license information. |

5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |

6 | // |

7 | //===----------------------------------------------------------------------===// |

8 | // |

9 | // Definition of BranchProbability shared by IR and Machine Instructions. |

10 | // |

11 | //===----------------------------------------------------------------------===// |

12 | |

13 | #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H |

14 | #define LLVM_SUPPORT_BRANCHPROBABILITY_H |

15 | |

16 | #include "llvm/Support/DataTypes.h" |

17 | #include <algorithm> |

18 | #include <cassert> |

19 | #include <climits> |

20 | #include <numeric> |

21 | |

22 | namespace llvm { |

23 | |

24 | class raw_ostream; |

25 | |

26 | // This class represents Branch Probability as a non-negative fraction that is |

27 | // no greater than 1. It uses a fixed-point-like implementation, in which the |

28 | // denominator is always a constant value (here we use 1<<31 for maximum |

29 | // precision). |

30 | class BranchProbability { |

31 | // Numerator |

32 | uint32_t N; |

33 | |

34 | // Denominator, which is a constant value. |

35 | static constexpr uint32_t D = 1u << 31; |

36 | static constexpr uint32_t UnknownN = UINT32_MAX; |

37 | |

38 | // Construct a BranchProbability with only numerator assuming the denominator |

39 | // is 1<<31. For internal use only. |

40 | explicit BranchProbability(uint32_t n) : N(n) {} |

41 | |

42 | public: |

43 | BranchProbability() : N(UnknownN) {} |

44 | BranchProbability(uint32_t Numerator, uint32_t Denominator); |

45 | |

46 | bool isZero() const { return N == 0; } |

47 | bool isUnknown() const { return N == UnknownN; } |

48 | |

49 | static BranchProbability getZero() { return BranchProbability(0); } |

50 | static BranchProbability getOne() { return BranchProbability(D); } |

51 | static BranchProbability getUnknown() { return BranchProbability(UnknownN); } |

52 | // Create a BranchProbability object with the given numerator and 1<<31 |

53 | // as denominator. |

54 | static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); } |

55 | // Create a BranchProbability object from 64-bit integers. |

56 | static BranchProbability getBranchProbability(uint64_t Numerator, |

57 | uint64_t Denominator); |

58 | |

59 | // Normalize given probabilties so that the sum of them becomes approximate |

60 | // one. |

61 | template <class ProbabilityIter> |

62 | static void normalizeProbabilities(ProbabilityIter Begin, |

63 | ProbabilityIter End); |

64 | |

65 | uint32_t getNumerator() const { return N; } |

66 | static uint32_t getDenominator() { return D; } |

67 | |

68 | // Return (1 - Probability). |

69 | BranchProbability getCompl() const { return BranchProbability(D - N); } |

70 | |

71 | raw_ostream &print(raw_ostream &OS) const; |

72 | |

73 | void dump() const; |

74 | |

75 | /// Scale a large integer. |

76 | /// |

77 | /// Scales \c Num. Guarantees full precision. Returns the floor of the |

78 | /// result. |

79 | /// |

80 | /// \return \c Num times \c this. |

81 | uint64_t scale(uint64_t Num) const; |

82 | |

83 | /// Scale a large integer by the inverse. |

84 | /// |

85 | /// Scales \c Num by the inverse of \c this. Guarantees full precision. |

86 | /// Returns the floor of the result. |

87 | /// |

88 | /// \return \c Num divided by \c this. |

89 | uint64_t scaleByInverse(uint64_t Num) const; |

90 | |

91 | BranchProbability &operator+=(BranchProbability RHS) { |

92 | assert(N != UnknownN && RHS.N != UnknownN && |

93 | "Unknown probability cannot participate in arithmetics."); |

94 | // Saturate the result in case of overflow. |

95 | N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N; |

96 | return *this; |

97 | } |

98 | |

99 | BranchProbability &operator-=(BranchProbability RHS) { |

100 | assert(N != UnknownN && RHS.N != UnknownN && |

101 | "Unknown probability cannot participate in arithmetics."); |

102 | // Saturate the result in case of underflow. |

103 | N = N < RHS.N ? 0 : N - RHS.N; |

104 | return *this; |

105 | } |

106 | |

107 | BranchProbability &operator*=(BranchProbability RHS) { |

108 | assert(N != UnknownN && RHS.N != UnknownN && |

109 | "Unknown probability cannot participate in arithmetics."); |

110 | N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D; |

111 | return *this; |

112 | } |

113 | |

114 | BranchProbability &operator*=(uint32_t RHS) { |

115 | assert(N != UnknownN && |

116 | "Unknown probability cannot participate in arithmetics."); |

117 | N = (uint64_t(N) * RHS > D) ? D : N * RHS; |

118 | return *this; |

119 | } |

120 | |

121 | BranchProbability &operator/=(BranchProbability RHS) { |

122 | assert(N != UnknownN && RHS.N != UnknownN && |

123 | "Unknown probability cannot participate in arithmetics."); |

124 | N = (static_cast<uint64_t>(N) * D + RHS.N / 2) / RHS.N; |

125 | return *this; |

126 | } |

127 | |

128 | BranchProbability &operator/=(uint32_t RHS) { |

129 | assert(N != UnknownN && |

130 | "Unknown probability cannot participate in arithmetics."); |

131 | assert(RHS > 0 && "The divider cannot be zero."); |

132 | N /= RHS; |

133 | return *this; |

134 | } |

135 | |

136 | BranchProbability operator+(BranchProbability RHS) const { |

137 | BranchProbability Prob(*this); |

138 | Prob += RHS; |

139 | return Prob; |

140 | } |

141 | |

142 | BranchProbability operator-(BranchProbability RHS) const { |

143 | BranchProbability Prob(*this); |

144 | Prob -= RHS; |

145 | return Prob; |

146 | } |

147 | |

148 | BranchProbability operator*(BranchProbability RHS) const { |

149 | BranchProbability Prob(*this); |

150 | Prob *= RHS; |

151 | return Prob; |

152 | } |

153 | |

154 | BranchProbability operator*(uint32_t RHS) const { |

155 | BranchProbability Prob(*this); |

156 | Prob *= RHS; |

157 | return Prob; |

158 | } |

159 | |

160 | BranchProbability operator/(BranchProbability RHS) const { |

161 | BranchProbability Prob(*this); |

162 | Prob /= RHS; |

163 | return Prob; |

164 | } |

165 | |

166 | BranchProbability operator/(uint32_t RHS) const { |

167 | BranchProbability Prob(*this); |

168 | Prob /= RHS; |

169 | return Prob; |

170 | } |

171 | |

172 | bool operator==(BranchProbability RHS) const { return N == RHS.N; } |

173 | bool operator!=(BranchProbability RHS) const { return !(*this == RHS); } |

174 | |

175 | bool operator<(BranchProbability RHS) const { |

176 | assert(N != UnknownN && RHS.N != UnknownN && |

177 | "Unknown probability cannot participate in comparisons."); |

178 | return N < RHS.N; |

179 | } |

180 | |

181 | bool operator>(BranchProbability RHS) const { |

182 | assert(N != UnknownN && RHS.N != UnknownN && |

183 | "Unknown probability cannot participate in comparisons."); |

184 | return RHS < *this; |

185 | } |

186 | |

187 | bool operator<=(BranchProbability RHS) const { |

188 | assert(N != UnknownN && RHS.N != UnknownN && |

189 | "Unknown probability cannot participate in comparisons."); |

190 | return !(RHS < *this); |

191 | } |

192 | |

193 | bool operator>=(BranchProbability RHS) const { |

194 | assert(N != UnknownN && RHS.N != UnknownN && |

195 | "Unknown probability cannot participate in comparisons."); |

196 | return !(*this < RHS); |

197 | } |

198 | }; |

199 | |

200 | inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) { |

201 | return Prob.print(OS); |

202 | } |

203 | |

204 | template <class ProbabilityIter> |

205 | void BranchProbability::normalizeProbabilities(ProbabilityIter Begin, |

206 | ProbabilityIter End) { |

207 | if (Begin == End) |

208 | return; |

209 | |

210 | unsigned UnknownProbCount = 0; |

211 | uint64_t Sum = std::accumulate(Begin, End, uint64_t(0), |

212 | [&](uint64_t S, const BranchProbability &BP) { |

213 | if (!BP.isUnknown()) |

214 | return S + BP.N; |

215 | UnknownProbCount++; |

216 | return S; |

217 | }); |

218 | |

219 | if (UnknownProbCount > 0) { |

220 | BranchProbability ProbForUnknown = BranchProbability::getZero(); |

221 | // If the sum of all known probabilities is less than one, evenly distribute |

222 | // the complement of sum to unknown probabilities. Otherwise, set unknown |

223 | // probabilities to zeros and continue to normalize known probabilities. |

224 | if (Sum < BranchProbability::getDenominator()) |

225 | ProbForUnknown = BranchProbability::getRaw( |

226 | (BranchProbability::getDenominator() - Sum) / UnknownProbCount); |

227 | |

228 | std::replace_if(Begin, End, |

229 | [](const BranchProbability &BP) { return BP.isUnknown(); }, |

230 | ProbForUnknown); |

231 | |

232 | if (Sum <= BranchProbability::getDenominator()) |

233 | return; |

234 | } |

235 | |

236 | if (Sum == 0) { |

237 | BranchProbability BP(1, std::distance(Begin, End)); |

238 | std::fill(Begin, End, BP); |

239 | return; |

240 | } |

241 | |

242 | for (auto I = Begin; I != End; ++I) |

243 | I->N = (I->N * uint64_t(D) + Sum / 2) / Sum; |

244 | } |

245 | |

246 | } |

247 | |

248 | #endif |

249 |