1 | // SPDX-License-Identifier: GPL-2.0 |
2 | /* |
3 | * Functions for incremental mean and variance. |
4 | * |
5 | * This program is free software; you can redistribute it and/or modify it |
6 | * under the terms of the GNU General Public License version 2 as published by |
7 | * the Free Software Foundation. |
8 | * |
9 | * This program is distributed in the hope that it will be useful, but WITHOUT |
10 | * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or |
11 | * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for |
12 | * more details. |
13 | * |
14 | * Copyright © 2022 Daniel B. Hill |
15 | * |
16 | * Author: Daniel B. Hill <daniel@gluo.nz> |
17 | * |
18 | * Description: |
19 | * |
 * This includes some incremental algorithms for mean and variance calculation.
21 | * |
22 | * Derived from the paper: https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf |
23 | * |
 * Create a struct. For the weighted variant, also pick a weight exponent k
 * (weight = 2^k) and pass that same value to every mean_and_variance_weighted_*()
 * call made on the struct.
25 | * |
 * Use mean_and_variance[_weighted]_update() on the struct to update its state.
27 | * |
28 | * Use the mean_and_variance[_weighted]_get_* functions to calculate the mean and variance, some computation |
29 | * is deferred to these functions for performance reasons. |
30 | * |
31 | * see lib/math/mean_and_variance_test.c for examples of usage. |
32 | * |
33 | * DO NOT access the mean and variance fields of the weighted variants directly. |
34 | * DO NOT change the weight after calling update. |
35 | */ |
36 | |
37 | #include <linux/bug.h> |
38 | #include <linux/compiler.h> |
39 | #include <linux/export.h> |
40 | #include <linux/limits.h> |
41 | #include <linux/math.h> |
42 | #include <linux/math64.h> |
43 | #include <linux/module.h> |
44 | |
45 | #include "mean_and_variance.h" |
46 | |
47 | u128_u u128_div(u128_u n, u64 d) |
48 | { |
49 | u128_u r; |
50 | u64 rem; |
51 | u64 hi = u128_hi(a: n); |
52 | u64 lo = u128_lo(a: n); |
53 | u64 h = hi & ((u64) U32_MAX << 32); |
54 | u64 l = (hi & (u64) U32_MAX) << 32; |
55 | |
56 | r = u128_shl(a: u64_to_u128(a: div64_u64_rem(dividend: h, divisor: d, remainder: &rem)), shift: 64); |
57 | r = u128_add(a: r, b: u128_shl(a: u64_to_u128(a: div64_u64_rem(dividend: l + (rem << 32), divisor: d, remainder: &rem)), shift: 32)); |
58 | r = u128_add(a: r, b: u64_to_u128(a: div64_u64_rem(dividend: lo + (rem << 32), divisor: d, remainder: &rem))); |
59 | return r; |
60 | } |
61 | EXPORT_SYMBOL_GPL(u128_div); |
62 | |
63 | /** |
64 | * mean_and_variance_get_mean() - get mean from @s |
65 | * @s: mean and variance number of samples and their sums |
66 | */ |
67 | s64 mean_and_variance_get_mean(struct mean_and_variance s) |
68 | { |
69 | return s.n ? div64_u64(dividend: s.sum, divisor: s.n) : 0; |
70 | } |
71 | EXPORT_SYMBOL_GPL(mean_and_variance_get_mean); |
72 | |
73 | /** |
74 | * mean_and_variance_get_variance() - get variance from @s1 |
75 | * @s1: mean and variance number of samples and sums |
76 | * |
77 | * see linked pdf equation 12. |
78 | */ |
79 | u64 mean_and_variance_get_variance(struct mean_and_variance s1) |
80 | { |
81 | if (s1.n) { |
82 | u128_u s2 = u128_div(s1.sum_squares, s1.n); |
83 | u64 s3 = abs(mean_and_variance_get_mean(s1)); |
84 | |
85 | return u128_lo(a: u128_sub(a: s2, b: u128_square(a: s3))); |
86 | } else { |
87 | return 0; |
88 | } |
89 | } |
90 | EXPORT_SYMBOL_GPL(mean_and_variance_get_variance); |
91 | |
92 | /** |
93 | * mean_and_variance_get_stddev() - get standard deviation from @s |
94 | * @s: mean and variance number of samples and their sums |
95 | */ |
96 | u32 mean_and_variance_get_stddev(struct mean_and_variance s) |
97 | { |
98 | return int_sqrt64(x: mean_and_variance_get_variance(s)); |
99 | } |
100 | EXPORT_SYMBOL_GPL(mean_and_variance_get_stddev); |
101 | |
102 | /** |
103 | * mean_and_variance_weighted_update() - exponentially weighted variant of mean_and_variance_update() |
104 | * @s: mean and variance number of samples and their sums |
105 | * @x: new value to include in the &mean_and_variance_weighted |
106 | * @initted: caller must track whether this is the first use or not |
107 | * @weight: ewma weight |
108 | * |
109 | * see linked pdf: function derived from equations 140-143 where alpha = 2^w. |
110 | * values are stored bitshifted for performance and added precision. |
111 | */ |
112 | void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, |
113 | s64 x, bool initted, u8 weight) |
114 | { |
115 | // previous weighted variance. |
116 | u8 w = weight; |
117 | u64 var_w0 = s->variance; |
118 | // new value weighted. |
119 | s64 x_w = x << w; |
120 | s64 diff_w = x_w - s->mean; |
121 | s64 diff = fast_divpow2(n: diff_w, d: w); |
122 | // new mean weighted. |
123 | s64 u_w1 = s->mean + diff; |
124 | |
125 | if (!initted) { |
126 | s->mean = x_w; |
127 | s->variance = 0; |
128 | } else { |
129 | s->mean = u_w1; |
130 | s->variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w; |
131 | } |
132 | } |
133 | EXPORT_SYMBOL_GPL(mean_and_variance_weighted_update); |
134 | |
135 | /** |
136 | * mean_and_variance_weighted_get_mean() - get mean from @s |
137 | * @s: mean and variance number of samples and their sums |
138 | * @weight: ewma weight |
139 | */ |
140 | s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s, |
141 | u8 weight) |
142 | { |
143 | return fast_divpow2(n: s.mean, d: weight); |
144 | } |
145 | EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_mean); |
146 | |
147 | /** |
148 | * mean_and_variance_weighted_get_variance() -- get variance from @s |
149 | * @s: mean and variance number of samples and their sums |
150 | * @weight: ewma weight |
151 | */ |
152 | u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s, |
153 | u8 weight) |
154 | { |
155 | // always positive don't need fast divpow2 |
156 | return s.variance >> weight; |
157 | } |
158 | EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_variance); |
159 | |
160 | /** |
161 | * mean_and_variance_weighted_get_stddev() - get standard deviation from @s |
162 | * @s: mean and variance number of samples and their sums |
163 | * @weight: ewma weight |
164 | */ |
165 | u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s, |
166 | u8 weight) |
167 | { |
168 | return int_sqrt64(x: mean_and_variance_weighted_get_variance(s, weight)); |
169 | } |
170 | EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_stddev); |
171 | |
172 | MODULE_AUTHOR("Daniel B. Hill" ); |
173 | MODULE_LICENSE("GPL" ); |
174 | |