1/* boost random/detail/polynomial.hpp header file
2 *
3 * Copyright Steven Watanabe 2014
4 * Distributed under the Boost Software License, Version 1.0. (See
5 * accompanying file LICENSE_1_0.txt or copy at
6 * http://www.boost.org/LICENSE_1_0.txt)
7 *
8 * See http://www.boost.org for most recent version including documentation.
9 *
10 * $Id$
11 */
12
13#ifndef BOOST_RANDOM_DETAIL_POLYNOMIAL_HPP
14#define BOOST_RANDOM_DETAIL_POLYNOMIAL_HPP
15
16#include <cstddef>
17#include <limits>
18#include <vector>
19#include <algorithm>
20#include <boost/assert.hpp>
21#include <boost/cstdint.hpp>
22
23namespace boost {
24namespace random {
25namespace detail {
26
/**
 * Low-level arithmetic on polynomials over GF(2).
 *
 * A polynomial is stored as an array of digit_t, least significant digit
 * first; bit k of digit d is the coefficient of x^(d*digits + k), where
 * digits == std::numeric_limits<digit_t>::digits.  Addition is XOR.
 * All functions operate on caller-supplied buffers and perform no bounds
 * checking; the required buffer sizes are listed per function.
 */
class polynomial_ops {
public:
    typedef unsigned long digit_t;

    /**
     * output = lhs + rhs (digit-wise XOR) over @c size digits.
     * output may alias lhs and/or rhs.
     */
    static void add(std::size_t size, const digit_t * lhs,
                    const digit_t * rhs, digit_t * output)
    {
        for(std::size_t i = 0; i < size; ++i) {
            output[i] = lhs[i] ^ rhs[i];
        }
    }

    /**
     * output += lhs * x^shift, where lhs has @c size digits.
     *
     * shift is expected to be smaller than the digit width (all callers in
     * this class pass a value already reduced modulo the digit width).
     * output[size] is only touched when bits actually carry out of the
     * top digit of lhs, so output needs size + 1 digits only in that case.
     */
    static void add_shifted_inplace(std::size_t size, const digit_t * lhs,
                                    digit_t * output, std::size_t shift)
    {
        if(shift == 0) {
            add(size, lhs, output, output);
            return;
        }
        const std::size_t bits = std::numeric_limits<digit_t>::digits;
        digit_t prev = 0;
        for(std::size_t i = 0; i < size; ++i) {
            digit_t tmp = lhs[i];
            output[i] ^= (tmp << shift) | (prev >> (bits-shift));
            prev = tmp;
        }
        // XOR with zero is a no-op, so skip it: this avoids reading and
        // writing one digit past the end of a buffer that is exactly large
        // enough for the (carry-free) result.
        const digit_t carry = prev >> (bits-shift);
        if(carry != 0) {
            output[size] ^= carry;
        }
    }

    /**
     * Schoolbook multiplication: output = lhs * rhs.
     * lhs and rhs have @c size digits each; output receives 2*size digits
     * and is overwritten.  output must not alias the inputs.
     */
    static void multiply_simple(std::size_t size, const digit_t * lhs,
                                const digit_t * rhs, digit_t * output)
    {
        std::size_t bits = std::numeric_limits<digit_t>::digits;
        for(std::size_t i = 0; i < 2*size; ++i) {
            output[i] = 0;
        }
        // For every set bit (i, j) of lhs add rhs * x^(i*bits + j).
        for(std::size_t i = 0; i < size; ++i) {
            for(std::size_t j = 0; j < bits; ++j) {
                if((lhs[i] & (digit_t(1) << j)) != 0) {
                    add_shifted_inplace(size, rhs, output + i, j);
                }
            }
        }
    }

    /**
     * Karatsuba multiplication: output = lhs * rhs.
     * lhs and rhs have @c size digits each; output receives 2*size digits
     * and is overwritten.  Falls back to multiply_simple below 64 digits.
     *
     * Temporary memory per level: (size - cutoff) * 4 digits (local1,
     * local2 and local3), plus whatever the next smaller level needs.
     */
    static void multiply_karatsuba(std::size_t size,
                                   const digit_t * lhs, const digit_t * rhs,
                                   digit_t * output)
    {
        if(size < 64) {
            multiply_simple(size, lhs, rhs, output);
            return;
        }
        // split in half; the high half gets the extra digit when size is odd
        std::size_t cutoff = size/2;
        // low * low goes to output[0, 2*cutoff)
        multiply_karatsuba(cutoff, lhs, rhs, output);
        // high * high goes to output[2*cutoff, 2*size)
        multiply_karatsuba(size - cutoff, lhs + cutoff, rhs + cutoff,
                           output + cutoff*2);
        std::vector<digit_t> local1(size - cutoff);
        std::vector<digit_t> local2(size - cutoff);
        // combine the digits for the inner multiply: (low + high) per operand
        add(cutoff, lhs, lhs + cutoff, &local1[0]);
        if(size & 1) local1[cutoff] = lhs[size - 1];
        add(cutoff, rhs + cutoff, rhs, &local2[0]);
        if(size & 1) local2[cutoff] = rhs[size - 1];
        std::vector<digit_t> local3((size - cutoff) * 2);
        multiply_karatsuba(size - cutoff, &local1[0], &local2[0], &local3[0]);
        // local3 = (low+high)*(low+high) + low*low + high*high, i.e. the
        // cross term of the product.
        add(cutoff * 2, output, &local3[0], &local3[0]);
        add((size - cutoff) * 2, output + cutoff*2, &local3[0], &local3[0]);
        // Finally, add the inner result
        add((size - cutoff) * 2, output + cutoff, &local3[0], output + cutoff);
    }

    /**
     * output += lhs * rhs, where both operands have @c size digits and
     * output has at least 2*size digits.  Does nothing when size is 0.
     */
    static void multiply_add_karatsuba(std::size_t size,
                                       const digit_t * lhs, const digit_t * rhs,
                                       digit_t * output)
    {
        // Nothing to add, and &buf[0] on an empty vector would be invalid.
        if(size == 0) {
            return;
        }
        std::vector<digit_t> buf(size * 2);
        multiply_karatsuba(size, lhs, rhs, &buf[0]);
        add(size * 2, &buf[0], output, output);
    }

    /**
     * output = lhs * rhs for operands of possibly different lengths.
     * output must hold lhs_size + rhs_size digits and is overwritten.
     */
    static void multiply(const digit_t * lhs, std::size_t lhs_size,
                         const digit_t * rhs, std::size_t rhs_size,
                         digit_t * output)
    {
        std::fill_n(output, lhs_size + rhs_size, digit_t(0));
        multiply_add(lhs, lhs_size, rhs, rhs_size, output);
    }

    /**
     * output += lhs * rhs for operands of possibly different lengths.
     * output must hold lhs_size + rhs_size digits.  If either operand has
     * zero digits the product is zero and output is left unchanged.
     */
    static void multiply_add(const digit_t * lhs, std::size_t lhs_size,
                             const digit_t * rhs, std::size_t rhs_size,
                             digit_t * output)
    {
        // split into pieces that can be passed to
        // karatsuba multiply.
        //
        // Each step consumes rhs_size digits of the longer operand, so the
        // loop must also stop when rhs_size is 0; otherwise it would never
        // make progress.
        while(lhs_size != 0 && rhs_size != 0) {
            if(lhs_size < rhs_size) {
                std::swap(lhs, rhs);
                std::swap(lhs_size, rhs_size);
            }

            multiply_add_karatsuba(rhs_size, lhs, rhs, output);

            lhs += rhs_size;
            lhs_size -= rhs_size;
            output += rhs_size;
        }
    }

    /**
     * Copies bits [low, high) of x into out, so that bit @c low of x
     * becomes bit 0 of out.  Writes ceil((high - low) / digit width)
     * digits; unused high bits of the last written digit are cleared.
     */
    static void copy_bits(const digit_t * x, std::size_t low, std::size_t high,
                          digit_t * out)
    {
        const std::size_t bits = std::numeric_limits<digit_t>::digits;
        // Skip whole digits below the range so that 0 <= low < bits.
        std::size_t offset = low/bits;
        x += offset;
        low -= offset*bits;
        high -= offset*bits;
        // Number of complete output digits.
        std::size_t n = (high-low)/bits;
        if(low == 0) {
            for(std::size_t i = 0; i < n; ++i) {
                out[i] = x[i];
            }
        } else {
            // Each output digit straddles two input digits.
            for(std::size_t i = 0; i < n; ++i) {
                out[i] = (x[i] >> low) | (x[i+1] << (bits-low));
            }
        }
        // Trailing partial digit.
        if((high-low)%bits) {
            digit_t low_mask = (digit_t(1) << ((high-low)%bits)) - 1;
            digit_t result = (x[n] >> low);
            // Only read x[n+1] when the range really extends into it.
            if(low != 0 && (n+1)*bits < high) {
                result |= (x[n+1] << (bits-low));
            }
            out[n] = (result & low_mask);
        }
    }

    /**
     * val *= x^shift in place over @c size digits.  Requires
     * 0 < shift < digit width.  Bits shifted out of the top digit are lost.
     */
    static void shift_left(digit_t * val, std::size_t size, std::size_t shift)
    {
        const std::size_t bits = std::numeric_limits<digit_t>::digits;
        BOOST_ASSERT(shift > 0);
        BOOST_ASSERT(shift < bits);
        digit_t prev = 0;
        for(std::size_t i = 0; i < size; ++i) {
            digit_t tmp = val[i];
            val[i] = (prev >> (bits - shift)) | (val[i] << shift);
            prev = tmp;
        }
    }

    /**
     * Squares the polynomial held in the low half of @c val: bit k of the
     * low half moves to bit 2k of the result and zeros are inserted in
     * between.  Bits in the upper half of val are shifted out in the first
     * step and do not contribute.
     */
    static digit_t sqr(digit_t val) {
        const std::size_t bits = std::numeric_limits<digit_t>::digits;
        digit_t mask = (digit_t(1) << bits/2) - 1;
        // Repeatedly move the upper half of each group of i/2 bits up by
        // i/2 positions, halving the group size each time.
        for(std::size_t i = bits; i > 1; i /= 2) {
            val = ((val & ~mask) << i/2) | (val & mask);
            mask = mask & (mask >> i/4);
            mask = mask | (mask << i/2);
        }
        return val;
    }

    /**
     * Squares a polynomial in place.  On entry val holds @c size digits;
     * on return it holds 2*size digits, so the buffer must have room for
     * 2*size digits.
     */
    static void sqr(digit_t * val, std::size_t size)
    {
        const std::size_t bits = std::numeric_limits<digit_t>::digits;
        digit_t mask = (digit_t(1) << bits/2) - 1;
        // Work from the most significant digit down so that each input
        // digit is read before its slot is overwritten.
        for(std::size_t i = 0; i < size; ++i) {
            digit_t x = val[size - i - 1];
            val[(size - i - 1) * 2] = sqr(x & mask);
            val[(size - i - 1) * 2 + 1] = sqr(x >> bits/2);
        }
    }

    /**
     * Reduction modulo a fixed divisor.
     * optimized for the case when the modulus has few bits set.
     */
    struct sparse_mod {
        /**
         * @param divisor       digits of the divisor polynomial
         * @param divisor_bits  number of bits in the divisor; bit
         *                      divisor_bits - 1 must be set (asserted).
         */
        sparse_mod(const digit_t * divisor, std::size_t divisor_bits)
        {
            const std::size_t bits = std::numeric_limits<digit_t>::digits;
            _remainder_bits = divisor_bits - 1;
            // Record the positions of all set bits, in increasing order.
            for(std::size_t i = 0; i < divisor_bits; ++i) {
                if(divisor[i/bits] & (digit_t(1) << i%bits)) {
                    _bit_indices.push_back(i);
                }
            }
            BOOST_ASSERT(_bit_indices.back() == divisor_bits - 1);
            // The leading bit is handled separately in operator().
            _bit_indices.pop_back();
            if(_bit_indices.empty()) {
                _block_bits = divisor_bits;
                _lower_bits = 0;
            } else {
                // Width of the gap between the leading bit and the next
                // lower set bit: that many quotient bits can be eliminated
                // in one step without feeding back into the same block.
                _block_bits = divisor_bits - _bit_indices.back() - 1;
                _lower_bits = _bit_indices.back() + 1;
            }

            _partial_quotient.resize((_block_bits + bits - 1)/bits);
        }
        /**
         * Reduces dividend (dividend_bits bits long) in place: every bit
         * at position >= _remainder_bits is cancelled by adding suitably
         * shifted copies of the divisor.
         */
        void operator()(digit_t * dividend, std::size_t dividend_bits)
        {
            const std::size_t bits = std::numeric_limits<digit_t>::digits;
            while(dividend_bits > _remainder_bits) {
                // Take the top block of at most _block_bits bits as a
                // partial quotient.
                std::size_t block_start = (std::max)(dividend_bits - _block_bits, _remainder_bits);
                std::size_t block_size = (dividend_bits - block_start + bits - 1) / bits;
                copy_bits(dividend, block_start, dividend_bits, &_partial_quotient[0]);
                // Add quotient * (lower bits of the divisor) ...
                for(std::size_t i = 0; i < _bit_indices.size(); ++i) {
                    std::size_t pos = _bit_indices[i] + block_start - _remainder_bits;
                    add_shifted_inplace(block_size, &_partial_quotient[0], dividend + pos/bits, pos%bits);
                }
                // ... and quotient * (leading bit), which clears the block.
                add_shifted_inplace(block_size, &_partial_quotient[0], dividend + block_start/bits, block_start%bits);
                dividend_bits = block_start;
            }
        }
        std::vector<digit_t> _partial_quotient;
        std::size_t _remainder_bits;
        std::size_t _block_bits;
        std::size_t _lower_bits;
        std::vector<std::size_t> _bit_indices;
    };

    /**
     * Computes x^exponent modulo @c mod and stores it in out.
     *
     * @param exponent  power of x to compute
     * @param mod       modulus; bit mod_bits - 1 must be set
     * @param mod_bits  number of bits in mod (degree + 1)
     * @param out       result buffer.  With n = ceil(mod_bits / digit
     *                  width), the result occupies the first n digits, but
     *                  out must provide 2*n digits because the intermediate
     *                  squares are formed in place.
     */
    static void mod_pow_x(boost::uintmax_t exponent, const digit_t * mod, std::size_t mod_bits, digit_t * out)
    {
        const std::size_t bits = std::numeric_limits<digit_t>::digits;
        const std::size_t n = (mod_bits + bits - 1) / bits;
        const std::size_t highbit = mod_bits - 1;
        if(exponent == 0) {
            out[0] = 1;
            std::fill_n(out + 1, n - 1, digit_t(0));
            return;
        }
        // Locate the most significant set bit of the exponent.
        boost::uintmax_t i = std::numeric_limits<boost::uintmax_t>::digits - 1;
        while(((boost::uintmax_t(1) << i) & exponent) == 0) {
            --i;
        }
        // Start from x^1, which accounts for that leading bit.
        out[0] = 2;
        std::fill_n(out + 1, n - 1, digit_t(0));
        // For a degree-1 modulus x itself is not yet reduced; without this
        // the result for exponent == 1 would be x instead of x mod `mod`.
        if(highbit == 1) {
            add(n, out, mod, out);
        }
        sparse_mod m(mod, mod_bits);
        // Square-and-multiply over the remaining exponent bits, where the
        // "multiply" is a multiplication by x (a one-bit shift).
        while(i--) {
            sqr(out, n);
            m(out, 2 * mod_bits - 1);
            if((boost::uintmax_t(1) << i) & exponent) {
                shift_left(out, n, 1);
                // The shift can set at most the leading bit; cancel it.
                if(out[highbit / bits] & (digit_t(1) << highbit%bits))
                    add(n, out, mod, out);
            }
        }
    }
};
277
278class polynomial
279{
280 typedef polynomial_ops::digit_t digit_t;
281public:
282 polynomial() : _size(0) {}
283 class reference {
284 public:
285 reference(digit_t &value, int idx)
286 : _value(value), _idx(idx) {}
287 operator bool() const { return (_value & (digit_t(1) << _idx)) != 0; }
288 reference& operator=(bool b)
289 {
290 if(b) {
291 _value |= (digit_t(1) << _idx);
292 } else {
293 _value &= ~(digit_t(1) << _idx);
294 }
295 return *this;
296 }
297 reference &operator^=(bool b)
298 {
299 _value ^= (digit_t(b) << _idx);
300 return *this;
301 }
302
303 reference &operator=(const reference &other)
304 {
305 return *this = static_cast<bool>(other);
306 }
307 private:
308 digit_t &_value;
309 int _idx;
310 };
311 reference operator[](std::size_t i)
312 {
313 static const std::size_t bits = std::numeric_limits<digit_t>::digits;
314 ensure_bit(i);
315 return reference(_storage[i/bits], i%bits);
316 }
317 bool operator[](std::size_t i) const
318 {
319 static const std::size_t bits = std::numeric_limits<digit_t>::digits;
320 if(i < size())
321 return (_storage[i/bits] & (digit_t(1) << (i%bits))) != 0;
322 else
323 return false;
324 }
325 std::size_t size() const
326 {
327 return _size;
328 }
329 void resize(std::size_t n)
330 {
331 static const std::size_t bits = std::numeric_limits<digit_t>::digits;
332 _storage.resize((n + bits - 1)/bits);
333 // clear the high order bits in case we're shrinking.
334 if(n%bits) {
335 _storage.back() &= ((digit_t(1) << (n%bits)) - 1);
336 }
337 _size = n;
338 }
339 friend polynomial operator*(const polynomial &lhs, const polynomial &rhs);
340 friend polynomial mod_pow_x(boost::uintmax_t exponent, polynomial mod);
341private:
342 std::vector<polynomial_ops::digit_t> _storage;
343 std::size_t _size;
344 void ensure_bit(std::size_t i)
345 {
346 if(i >= size()) {
347 resize(i + 1);
348 }
349 }
350 void normalize()
351 {
352 while(size() && (*this)[size() - 1] == 0)
353 resize(size() - 1);
354 }
355};
356
357inline polynomial operator*(const polynomial &lhs, const polynomial &rhs)
358{
359 polynomial result;
360 result._storage.resize(lhs._storage.size() + rhs._storage.size());
361 polynomial_ops::multiply(&lhs._storage[0], lhs._storage.size(),
362 &rhs._storage[0], rhs._storage.size(),
363 &result._storage[0]);
364 result._size = lhs._size + rhs._size;
365 return result;
366}
367
368inline polynomial mod_pow_x(boost::uintmax_t exponent, polynomial mod)
369{
370 polynomial result;
371 mod.normalize();
372 std::size_t mod_size = mod.size();
373 result._storage.resize(mod._storage.size() * 2);
374 result._size = mod.size() * 2;
375 polynomial_ops::mod_pow_x(exponent, &mod._storage[0], mod_size, &result._storage[0]);
376 result.resize(mod.size() - 1);
377 return result;
378}
379
380}
381}
382}
383
384#endif // BOOST_RANDOM_DETAIL_POLYNOMIAL_HPP
385