1 | // SPDX-License-Identifier: GPL-2.0 |
2 | #include <kunit/test.h> |
3 | |
4 | #include "mean_and_variance.h" |
5 | |
/*
 * Square of SQRT_U64_MAX (presumably provided by mean_and_variance.h).
 * Not referenced by the tests in this file.
 */
#define MAX_SQR	((SQRT_U64_MAX) * (SQRT_U64_MAX))
7 | |
8 | static void mean_and_variance_basic_test(struct kunit *test) |
9 | { |
10 | struct mean_and_variance s = {}; |
11 | |
12 | mean_and_variance_update(s: &s, v: 2); |
13 | mean_and_variance_update(s: &s, v: 2); |
14 | |
15 | KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(s), 2); |
16 | KUNIT_EXPECT_EQ(test, mean_and_variance_get_variance(s), 0); |
17 | KUNIT_EXPECT_EQ(test, s.n, 2); |
18 | |
19 | mean_and_variance_update(s: &s, v: 4); |
20 | mean_and_variance_update(s: &s, v: 4); |
21 | |
22 | KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(s), 3); |
23 | KUNIT_EXPECT_EQ(test, mean_and_variance_get_variance(s), 1); |
24 | KUNIT_EXPECT_EQ(test, s.n, 4); |
25 | } |
26 | |
/*
 * Test values computed using a spreadsheet from the pseudocode at the bottom of:
 * https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf
 */
31 | |
32 | static void mean_and_variance_weighted_test(struct kunit *test) |
33 | { |
34 | struct mean_and_variance_weighted s = { }; |
35 | |
36 | mean_and_variance_weighted_update(s: &s, v: 10, initted: false, weight: 2); |
37 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 2), 10); |
38 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 2), 0); |
39 | |
40 | mean_and_variance_weighted_update(s: &s, v: 20, initted: true, weight: 2); |
41 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 2), 12); |
42 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 2), 18); |
43 | |
44 | mean_and_variance_weighted_update(s: &s, v: 30, initted: true, weight: 2); |
45 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 2), 16); |
46 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 2), 72); |
47 | |
48 | s = (struct mean_and_variance_weighted) { }; |
49 | |
50 | mean_and_variance_weighted_update(s: &s, v: -10, initted: false, weight: 2); |
51 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 2), -10); |
52 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 2), 0); |
53 | |
54 | mean_and_variance_weighted_update(s: &s, v: -20, initted: true, weight: 2); |
55 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 2), -12); |
56 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 2), 18); |
57 | |
58 | mean_and_variance_weighted_update(s: &s, v: -30, initted: true, weight: 2); |
59 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 2), -16); |
60 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 2), 72); |
61 | } |
62 | |
63 | static void mean_and_variance_weighted_advanced_test(struct kunit *test) |
64 | { |
65 | struct mean_and_variance_weighted s = { }; |
66 | bool initted = false; |
67 | s64 i; |
68 | |
69 | for (i = 10; i <= 100; i += 10) { |
70 | mean_and_variance_weighted_update(s: &s, v: i, initted, weight: 8); |
71 | initted = true; |
72 | } |
73 | |
74 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 8), 11); |
75 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 8), 107); |
76 | |
77 | s = (struct mean_and_variance_weighted) { }; |
78 | initted = false; |
79 | |
80 | for (i = -10; i >= -100; i -= 10) { |
81 | mean_and_variance_weighted_update(s: &s, v: i, initted, weight: 8); |
82 | initted = true; |
83 | } |
84 | |
85 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 8), -11); |
86 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 8), 107); |
87 | } |
88 | |
89 | static void do_mean_and_variance_test(struct kunit *test, |
90 | s64 initial_value, |
91 | s64 initial_n, |
92 | s64 n, |
93 | unsigned weight, |
94 | s64 *data, |
95 | s64 *mean, |
96 | s64 *stddev, |
97 | s64 *weighted_mean, |
98 | s64 *weighted_stddev) |
99 | { |
100 | struct mean_and_variance mv = {}; |
101 | struct mean_and_variance_weighted vw = { }; |
102 | |
103 | for (unsigned i = 0; i < initial_n; i++) { |
104 | mean_and_variance_update(s: &mv, v: initial_value); |
105 | mean_and_variance_weighted_update(s: &vw, v: initial_value, initted: false, weight); |
106 | |
107 | KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(mv), initial_value); |
108 | KUNIT_EXPECT_EQ(test, mean_and_variance_get_stddev(mv), 0); |
109 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(vw, weight), initial_value); |
110 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_stddev(vw, weight),0); |
111 | } |
112 | |
113 | for (unsigned i = 0; i < n; i++) { |
114 | mean_and_variance_update(s: &mv, v: data[i]); |
115 | mean_and_variance_weighted_update(s: &vw, v: data[i], initted: true, weight); |
116 | |
117 | KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(mv), mean[i]); |
118 | KUNIT_EXPECT_EQ(test, mean_and_variance_get_stddev(mv), stddev[i]); |
119 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(vw, weight), weighted_mean[i]); |
120 | KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_stddev(vw, weight),weighted_stddev[i]); |
121 | } |
122 | |
123 | KUNIT_EXPECT_EQ(test, mv.n, initial_n + n); |
124 | } |
125 | |
126 | /* Test behaviour with a single outlier, then back to steady state: */ |
127 | static void mean_and_variance_test_1(struct kunit *test) |
128 | { |
129 | s64 d[] = { 100, 10, 10, 10, 10, 10, 10 }; |
130 | s64 mean[] = { 22, 21, 20, 19, 18, 17, 16 }; |
131 | s64 stddev[] = { 32, 29, 28, 27, 26, 25, 24 }; |
132 | s64 weighted_mean[] = { 32, 27, 22, 19, 17, 15, 14 }; |
133 | s64 weighted_stddev[] = { 38, 35, 31, 27, 24, 21, 18 }; |
134 | |
135 | do_mean_and_variance_test(test, initial_value: 10, initial_n: 6, ARRAY_SIZE(d), weight: 2, |
136 | data: d, mean, stddev, weighted_mean, weighted_stddev); |
137 | } |
138 | |
139 | /* Test behaviour where we switch from one steady state to another: */ |
140 | static void mean_and_variance_test_2(struct kunit *test) |
141 | { |
142 | s64 d[] = { 100, 100, 100, 100, 100 }; |
143 | s64 mean[] = { 22, 32, 40, 46, 50 }; |
144 | s64 stddev[] = { 32, 39, 42, 44, 45 }; |
145 | s64 weighted_mean[] = { 32, 49, 61, 71, 78 }; |
146 | s64 weighted_stddev[] = { 38, 44, 44, 41, 38 }; |
147 | |
148 | do_mean_and_variance_test(test, initial_value: 10, initial_n: 6, ARRAY_SIZE(d), weight: 2, |
149 | data: d, mean, stddev, weighted_mean, weighted_stddev); |
150 | } |
151 | |
152 | static void mean_and_variance_fast_divpow2(struct kunit *test) |
153 | { |
154 | s64 i; |
155 | u8 d; |
156 | |
157 | for (i = 0; i < 100; i++) { |
158 | d = 0; |
159 | KUNIT_EXPECT_EQ(test, fast_divpow2(i, d), div_u64(i, 1LLU << d)); |
160 | KUNIT_EXPECT_EQ(test, abs(fast_divpow2(-i, d)), div_u64(i, 1LLU << d)); |
161 | for (d = 1; d < 32; d++) { |
162 | KUNIT_EXPECT_EQ_MSG(test, abs(fast_divpow2(i, d)), |
163 | div_u64(i, 1 << d), "%lld %u" , i, d); |
164 | KUNIT_EXPECT_EQ_MSG(test, abs(fast_divpow2(-i, d)), |
165 | div_u64(i, 1 << d), "%lld %u" , -i, d); |
166 | } |
167 | } |
168 | } |
169 | |
170 | static void mean_and_variance_u128_basic_test(struct kunit *test) |
171 | { |
172 | u128_u a = u64s_to_u128(hi: 0, U64_MAX); |
173 | u128_u a1 = u64s_to_u128(hi: 0, lo: 1); |
174 | u128_u b = u64s_to_u128(hi: 1, lo: 0); |
175 | u128_u c = u64s_to_u128(hi: 0, lo: 1LLU << 63); |
176 | u128_u c2 = u64s_to_u128(U64_MAX, U64_MAX); |
177 | |
178 | KUNIT_EXPECT_EQ(test, u128_hi(u128_add(a, a1)), 1); |
179 | KUNIT_EXPECT_EQ(test, u128_lo(u128_add(a, a1)), 0); |
180 | KUNIT_EXPECT_EQ(test, u128_hi(u128_add(a1, a)), 1); |
181 | KUNIT_EXPECT_EQ(test, u128_lo(u128_add(a1, a)), 0); |
182 | |
183 | KUNIT_EXPECT_EQ(test, u128_lo(u128_sub(b, a1)), U64_MAX); |
184 | KUNIT_EXPECT_EQ(test, u128_hi(u128_sub(b, a1)), 0); |
185 | |
186 | KUNIT_EXPECT_EQ(test, u128_hi(u128_shl(c, 1)), 1); |
187 | KUNIT_EXPECT_EQ(test, u128_lo(u128_shl(c, 1)), 0); |
188 | |
189 | KUNIT_EXPECT_EQ(test, u128_hi(u128_square(U64_MAX)), U64_MAX - 1); |
190 | KUNIT_EXPECT_EQ(test, u128_lo(u128_square(U64_MAX)), 1); |
191 | |
192 | KUNIT_EXPECT_EQ(test, u128_lo(u128_div(b, 2)), 1LLU << 63); |
193 | |
194 | KUNIT_EXPECT_EQ(test, u128_hi(u128_div(c2, 2)), U64_MAX >> 1); |
195 | KUNIT_EXPECT_EQ(test, u128_lo(u128_div(c2, 2)), U64_MAX); |
196 | |
197 | KUNIT_EXPECT_EQ(test, u128_hi(u128_div(u128_shl(u64_to_u128(U64_MAX), 32), 2)), U32_MAX >> 1); |
198 | KUNIT_EXPECT_EQ(test, u128_lo(u128_div(u128_shl(u64_to_u128(U64_MAX), 32), 2)), U64_MAX << 31); |
199 | } |
200 | |
201 | static struct kunit_case mean_and_variance_test_cases[] = { |
202 | KUNIT_CASE(mean_and_variance_fast_divpow2), |
203 | KUNIT_CASE(mean_and_variance_u128_basic_test), |
204 | KUNIT_CASE(mean_and_variance_basic_test), |
205 | KUNIT_CASE(mean_and_variance_weighted_test), |
206 | KUNIT_CASE(mean_and_variance_weighted_advanced_test), |
207 | KUNIT_CASE(mean_and_variance_test_1), |
208 | KUNIT_CASE(mean_and_variance_test_2), |
209 | {} |
210 | }; |
211 | |
212 | static struct kunit_suite mean_and_variance_test_suite = { |
213 | .name = "mean and variance tests" , |
214 | .test_cases = mean_and_variance_test_cases |
215 | }; |
216 | |
217 | kunit_test_suite(mean_and_variance_test_suite); |
218 | |
219 | MODULE_AUTHOR("Daniel B. Hill" ); |
220 | MODULE_LICENSE("GPL" ); |
221 | |