1 | /* SPDX-License-Identifier: GPL-2.0 */ |
2 | #ifndef MEAN_AND_VARIANCE_H_ |
3 | #define MEAN_AND_VARIANCE_H_ |
4 | |
5 | #include <linux/types.h> |
6 | #include <linux/limits.h> |
7 | #include <linux/math.h> |
8 | #include <linux/math64.h> |
9 | |
/* 2^32 - 1, i.e. floor(sqrt(U64_MAX)): the largest value whose square still fits in a u64 */
#define SQRT_U64_MAX 4294967295ULL
11 | |
12 | /* |
13 | * u128_u: u128 user mode, because not all architectures support a real int128 |
14 | * type |
15 | * |
16 | * We don't use this version in userspace, because in userspace we link with |
17 | * Rust and rustc has issues with u128. |
18 | */ |
19 | |
20 | #if defined(__SIZEOF_INT128__) && defined(__KERNEL__) && !defined(CONFIG_PARISC) |
21 | |
/*
 * 128-bit unsigned value backed by the compiler's native unsigned __int128,
 * wrapped in a struct and aligned to 16 bytes.
 */
typedef struct {
	unsigned __int128 v;
} __aligned(16) u128_u;
25 | |
26 | static inline u128_u u64_to_u128(u64 a) |
27 | { |
28 | return (u128_u) { .v = a }; |
29 | } |
30 | |
31 | static inline u64 u128_lo(u128_u a) |
32 | { |
33 | return a.v; |
34 | } |
35 | |
36 | static inline u64 u128_hi(u128_u a) |
37 | { |
38 | return a.v >> 64; |
39 | } |
40 | |
41 | static inline u128_u u128_add(u128_u a, u128_u b) |
42 | { |
43 | a.v += b.v; |
44 | return a; |
45 | } |
46 | |
47 | static inline u128_u u128_sub(u128_u a, u128_u b) |
48 | { |
49 | a.v -= b.v; |
50 | return a; |
51 | } |
52 | |
53 | static inline u128_u u128_shl(u128_u a, s8 shift) |
54 | { |
55 | a.v <<= shift; |
56 | return a; |
57 | } |
58 | |
59 | static inline u128_u u128_square(u64 a) |
60 | { |
61 | u128_u b = u64_to_u128(a); |
62 | |
63 | b.v *= b.v; |
64 | return b; |
65 | } |
66 | |
67 | #else |
68 | |
69 | typedef struct { |
70 | u64 hi, lo; |
71 | } __aligned(16) u128_u; |
72 | |
73 | /* conversions */ |
74 | |
75 | static inline u128_u u64_to_u128(u64 a) |
76 | { |
77 | return (u128_u) { .lo = a }; |
78 | } |
79 | |
80 | static inline u64 u128_lo(u128_u a) |
81 | { |
82 | return a.lo; |
83 | } |
84 | |
85 | static inline u64 u128_hi(u128_u a) |
86 | { |
87 | return a.hi; |
88 | } |
89 | |
90 | /* arithmetic */ |
91 | |
92 | static inline u128_u u128_add(u128_u a, u128_u b) |
93 | { |
94 | u128_u c; |
95 | |
96 | c.lo = a.lo + b.lo; |
97 | c.hi = a.hi + b.hi + (c.lo < a.lo); |
98 | return c; |
99 | } |
100 | |
101 | static inline u128_u u128_sub(u128_u a, u128_u b) |
102 | { |
103 | u128_u c; |
104 | |
105 | c.lo = a.lo - b.lo; |
106 | c.hi = a.hi - b.hi - (c.lo > a.lo); |
107 | return c; |
108 | } |
109 | |
110 | static inline u128_u u128_shl(u128_u i, s8 shift) |
111 | { |
112 | u128_u r; |
113 | |
114 | r.lo = i.lo << shift; |
115 | if (shift < 64) |
116 | r.hi = (i.hi << shift) | (i.lo >> (64 - shift)); |
117 | else { |
118 | r.hi = i.lo << (shift - 64); |
119 | r.lo = 0; |
120 | } |
121 | return r; |
122 | } |
123 | |
124 | static inline u128_u u128_square(u64 i) |
125 | { |
126 | u128_u r; |
127 | u64 h = i >> 32, l = i & U32_MAX; |
128 | |
129 | r = u128_shl(u64_to_u128(h*h), 64); |
130 | r = u128_add(r, u128_shl(u64_to_u128(h*l), 32)); |
131 | r = u128_add(r, u128_shl(u64_to_u128(l*h), 32)); |
132 | r = u128_add(r, u64_to_u128(l*l)); |
133 | return r; |
134 | } |
135 | |
136 | #endif |
137 | |
138 | static inline u128_u u64s_to_u128(u64 hi, u64 lo) |
139 | { |
140 | u128_u c = u64_to_u128(a: hi); |
141 | |
142 | c = u128_shl(a: c, shift: 64); |
143 | c = u128_add(a: c, b: u64_to_u128(a: lo)); |
144 | return c; |
145 | } |
146 | |
/*
 * Prototype only; the definition is not in this header.
 * NOTE(review): presumably returns the 128-bit quotient n / d - confirm
 * against the implementation, including the behaviour for d == 0.
 */
u128_u u128_div(u128_u n, u64 d);
148 | |
/*
 * Running accumulator for the unweighted statistics; updated one sample at a
 * time by mean_and_variance_update().
 */
struct mean_and_variance {
	/* number of samples recorded so far */
	s64 n;
	/* sum of all samples */
	s64 sum;
	/* 128-bit sum of the squares of all samples */
	u128_u sum_squares;
};
154 | |
/*
 * Exponentially weighted variant.
 *
 * NOTE(review): the fields are presumably maintained by
 * mean_and_variance_weighted_update() and read back through the
 * mean_and_variance_weighted_get_*() helpers, which all take a weight -
 * confirm the stored representation against the implementation.
 */
struct mean_and_variance_weighted {
	s64 mean;
	u64 variance;
};
160 | |
161 | /** |
162 | * fast_divpow2() - fast approximation for n / (1 << d) |
163 | * @n: numerator |
164 | * @d: the power of 2 denominator. |
165 | * |
166 | * note: this rounds towards 0. |
167 | */ |
168 | static inline s64 fast_divpow2(s64 n, u8 d) |
169 | { |
170 | return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; |
171 | } |
172 | |
173 | /** |
174 | * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1 |
175 | * and return it. |
176 | * @s1: the mean_and_variance to update. |
177 | * @v1: the new sample. |
178 | * |
179 | * see linked pdf equation 12. |
180 | */ |
181 | static inline void |
182 | mean_and_variance_update(struct mean_and_variance *s, s64 v) |
183 | { |
184 | s->n++; |
185 | s->sum += v; |
186 | s->sum_squares = u128_add(a: s->sum_squares, b: u128_square(abs(v))); |
187 | } |
188 | |
/*
 * Accessors for the unweighted accumulator. Prototypes only; the definitions
 * are not in this header. Each takes the struct by value.
 * NOTE(review): going by the names these return the mean, variance and
 * standard deviation of the samples recorded so far - confirm the rounding
 * and the n == 0 behaviour against the implementation.
 */
s64 mean_and_variance_get_mean(struct mean_and_variance s);
u64 mean_and_variance_get_variance(struct mean_and_variance s1);
u32 mean_and_variance_get_stddev(struct mean_and_variance s);
192 | |
/*
 * Exponentially weighted counterparts. Prototypes only; the definitions are
 * not in this header. The update takes the struct by pointer, the accessors
 * take it by value, and every function takes the weight as a u8.
 * NOTE(review): the meaning of @initted and the scale of @weight (it looks
 * like a power-of-two exponent, given fast_divpow2() in this header) are not
 * visible here - confirm against the implementation.
 */
void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s,
		s64 v, bool initted, u8 weight);

s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s,
		u8 weight);
u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s,
		u8 weight);
u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s,
		u8 weight);
202 | |
203 | #endif // MEAN_AND_VAIRANCE_H_ |
204 | |