diff options
Diffstat (limited to 'fs/bcachefs/mean_and_variance.h')
-rw-r--r-- | fs/bcachefs/mean_and_variance.h | 198 |
1 files changed, 198 insertions, 0 deletions
diff --git a/fs/bcachefs/mean_and_variance.h b/fs/bcachefs/mean_and_variance.h new file mode 100644 index 000000000000..647505010b39 --- /dev/null +++ b/fs/bcachefs/mean_and_variance.h @@ -0,0 +1,198 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef MEAN_AND_VARIANCE_H_ +#define MEAN_AND_VARIANCE_H_ + +#include <linux/types.h> +#include <linux/limits.h> +#include <linux/math.h> +#include <linux/math64.h> + +#define SQRT_U64_MAX 4294967295ULL + +/* + * u128_u: u128 user mode, because not all architectures support a real int128 + * type + */ + +#ifdef __SIZEOF_INT128__ + +typedef struct { + unsigned __int128 v; +} __aligned(16) u128_u; + +static inline u128_u u64_to_u128(u64 a) +{ + return (u128_u) { .v = a }; +} + +static inline u64 u128_lo(u128_u a) +{ + return a.v; +} + +static inline u64 u128_hi(u128_u a) +{ + return a.v >> 64; +} + +static inline u128_u u128_add(u128_u a, u128_u b) +{ + a.v += b.v; + return a; +} + +static inline u128_u u128_sub(u128_u a, u128_u b) +{ + a.v -= b.v; + return a; +} + +static inline u128_u u128_shl(u128_u a, s8 shift) +{ + a.v <<= shift; + return a; +} + +static inline u128_u u128_square(u64 a) +{ + u128_u b = u64_to_u128(a); + + b.v *= b.v; + return b; +} + +#else + +typedef struct { + u64 hi, lo; +} __aligned(16) u128_u; + +/* conversions */ + +static inline u128_u u64_to_u128(u64 a) +{ + return (u128_u) { .lo = a }; +} + +static inline u64 u128_lo(u128_u a) +{ + return a.lo; +} + +static inline u64 u128_hi(u128_u a) +{ + return a.hi; +} + +/* arithmetic */ + +static inline u128_u u128_add(u128_u a, u128_u b) +{ + u128_u c; + + c.lo = a.lo + b.lo; + c.hi = a.hi + b.hi + (c.lo < a.lo); + return c; +} + +static inline u128_u u128_sub(u128_u a, u128_u b) +{ + u128_u c; + + c.lo = a.lo - b.lo; + c.hi = a.hi - b.hi - (c.lo > a.lo); + return c; +} + +static inline u128_u u128_shl(u128_u i, s8 shift) +{ + u128_u r; + + r.lo = i.lo << shift; + if (shift < 64) + r.hi = (i.hi << shift) | (i.lo >> (64 - shift)); + else { + r.hi = i.lo << (shift - 64); + r.lo = 0; + } + return r; +} + +static inline u128_u u128_square(u64 i) +{ + u128_u r; + u64 h = i >> 32, l = i & U32_MAX; + + r = u128_shl(u64_to_u128(h*h), 64); + r = u128_add(r, u128_shl(u64_to_u128(h*l), 32)); + r = u128_add(r, u128_shl(u64_to_u128(l*h), 32)); + r = u128_add(r, u64_to_u128(l*l)); + return r; +} + +#endif + +static inline u128_u u64s_to_u128(u64 hi, u64 lo) +{ + u128_u c = u64_to_u128(hi); + + c = u128_shl(c, 64); + c = u128_add(c, u64_to_u128(lo)); + return c; +} + +u128_u u128_div(u128_u n, u64 d); + +struct mean_and_variance { + s64 n; + s64 sum; + u128_u sum_squares; +}; + +/* expontentially weighted variant */ +struct mean_and_variance_weighted { + bool init; + u8 weight; /* base 2 logarithim */ + s64 mean; + u64 variance; +}; + +/** + * fast_divpow2() - fast approximation for n / (1 << d) + * @n: numerator + * @d: the power of 2 denominator. + * + * note: this rounds towards 0. + */ +static inline s64 fast_divpow2(s64 n, u8 d) +{ + return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; +} + +/** + * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1 + * and return it. + * @s1: the mean_and_variance to update. + * @v1: the new sample. + * + * see linked pdf equation 12. + */ +static inline void +mean_and_variance_update(struct mean_and_variance *s, s64 v) +{ + s->n++; + s->sum += v; + s->sum_squares = u128_add(s->sum_squares, u128_square(abs(v))); +} + +s64 mean_and_variance_get_mean(struct mean_and_variance s); +u64 mean_and_variance_get_variance(struct mean_and_variance s1); +u32 mean_and_variance_get_stddev(struct mean_and_variance s); + +void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v); + +s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s); +u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s); +u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s); + +#endif // MEAN_AND_VAIRANCE_H_ |