Skip to content

Commit

Permalink
Merge pull request #76 from byeongkeunahn/linear-recurrence
Browse files Browse the repository at this point in the history
Add linear recurrence solver (currently Kitamasa)
  • Loading branch information
kiwiyou authored Feb 6, 2024
2 parents 5c79333 + be26bad commit 7de7ebe
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 6 deletions.
67 changes: 61 additions & 6 deletions basm-std/src/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ pub trait ModOps<T>:
fn one() -> T;
fn two() -> T;
fn my_wrapping_sub(&self, other: T) -> T;
fn modadd(x: T, y: T, modulo: T) -> T;
fn modsub(x: T, y: T, modulo: T) -> T;
fn modinv(&self, modulo: T) -> Option<T>;
fn modmul(x: T, y: T, modulo: T) -> T;
}
Expand All @@ -146,6 +148,17 @@ macro_rules! impl_mod_ops_signed {
None
}
}
fn modadd(x: $t, y: $t, modulo: $t) -> $t {
Self::modsub(x, Self::modsub(0, y, modulo), modulo)
}
fn modsub(x: $t, y: $t, modulo: $t) -> $t {
debug_assert!(modulo > 0);
let (mut x, mut y) = (x % modulo, y % modulo);
if x < 0 { x += modulo; }
if y < 0 { y += modulo; }
let out = x - y;
if out < 0 { out + modulo } else { out }
}
fn modmul(x: $t, y: $t, modulo: $t) -> $t {
debug_assert!(modulo > 0);
if <$t>::BITS <= 16 {
Expand Down Expand Up @@ -176,15 +189,19 @@ macro_rules! impl_mod_ops_unsigned {
fn one() -> $t { 1 }
fn two() -> $t { 2 }
fn my_wrapping_sub(&self, other: $t) -> $t { self.wrapping_sub(other) }
fn modadd(x: $t, y: $t, modulo: $t) -> $t {
Self::modsub(x, Self::modsub(0, y, modulo), modulo)
}
fn modsub(x: $t, y: $t, modulo: $t) -> $t {
debug_assert!(modulo > 0);
let (x, y) = (x % modulo, y % modulo);
let (out, overflow) = x.overflowing_sub(y);
if overflow { out.wrapping_add(modulo) } else { out }
}
fn modinv(&self, modulo: $t) -> Option<$t> {
if modulo <= 1 {
return None;
}
fn modsub(x: $t, mut y: $t, modulo: $t) -> $t {
y %= modulo;
let (out, overflow) = x.overflowing_sub(y);
if overflow { out.wrapping_add(modulo) } else { out }
}
let (mut a, mut b) = (*self, modulo);
let mut c: [$t; 4] = if a > b {
(a, b) = (b, a);
Expand Down Expand Up @@ -225,7 +242,21 @@ macro_rules! impl_mod_ops_unsigned {
}
impl_mod_ops_unsigned!(u8, u16, u32, u64, u128, usize);

/// Computes the modular multiplication of `x` and `y`.
/// Computes the modular addition `x + y`.
///
/// This function will panic if `modulo` is zero or negative.
pub fn modadd<T: ModOps<T>>(x: T, y: T, modulo: T) -> T {
T::modadd(x, y, modulo)
}

/// Computes the modular subtraction `x - y`.
///
/// This function will panic if `modulo` is zero or negative.
pub fn modsub<T: ModOps<T>>(x: T, y: T, modulo: T) -> T {
T::modsub(x, y, modulo)
}

/// Computes the modular multiplication `x * y`.
///
/// This function will panic if `modulo` is zero or negative.
pub fn modmul<T: ModOps<T>>(x: T, y: T, modulo: T) -> T {
Expand Down Expand Up @@ -304,6 +335,30 @@ mod test {
assert_eq!(a * s + b * t, g);
}

#[test]
fn modadd_returns_modadd() {
assert_eq!(6i64, modadd(-3, -5, 7));
assert_eq!(6u64, modadd(3, 10, 7));
assert_eq!(2_147_483_643i32, modadd(-1_073_741_824, -1_073_741_827, 2_147_483_647));
assert_eq!(2_147_483_643i64, modadd(-1_073_741_824, -1_073_741_827, 2_147_483_647));
assert_eq!(4i32, modadd(1_073_741_824, 1_073_741_827, 2_147_483_647));
assert_eq!(4i64, modadd(1_073_741_824, 1_073_741_827, 2_147_483_647));
assert_eq!(2_147_483_645i32, modadd(-2_147_483_648, -2_147_483_648, 2_147_483_647));
assert_eq!(2_147_483_645i64, modadd(-2_147_483_648, -2_147_483_648, 2_147_483_647));
}

#[test]
fn modsub_returns_modsub() {
assert_eq!(6i64, modsub(-3, 5, 7));
assert_eq!(6u64, modsub(3, 4, 7));
assert_eq!(2_147_483_643i32, modsub(-1_073_741_824, 1_073_741_827, 2_147_483_647));
assert_eq!(2_147_483_643i64, modsub(-1_073_741_824, 1_073_741_827, 2_147_483_647));
assert_eq!(4i32, modsub(1_073_741_824, -1_073_741_827, 2_147_483_647));
assert_eq!(4i64, modsub(1_073_741_824, -1_073_741_827, 2_147_483_647));
assert_eq!(4i32, modsub(-2_147_483_648, -5, 2_147_483_647));
assert_eq!(4i64, modsub(-2_147_483_648, -5, 2_147_483_647));
}

#[test]
fn modinv_returns_modinv() {
assert_eq!(None, modinv(4i64, 16i64));
Expand Down
2 changes: 2 additions & 0 deletions basm-std/src/math/ntt.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub mod linear_recurrence;
pub use linear_recurrence::linear_nth;
pub mod nttcore;
pub mod multiply;
pub use multiply::multiply_u64;
Expand Down
50 changes: 50 additions & 0 deletions basm-std/src/math/ntt/linear_recurrence.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use alloc::vec;
use crate::math::{modadd, modsub, modmul};
use super::{polymul_u64, polymod_u64};

/// Computes the `n`-th term `a[n]` of a linear recurrence specified by `first_terms` and `coeff`.
/// The recurrence is `a[k] = coeff[0] * a[k-1] + coeff[1] * a[k-2] + ... + coeff[m-1] * a[k-m-1]`
/// where `m` is the length of the `coeff` slice. Also, `a[i] = first_terms[i]` for `0 <= i < m`.
///
/// Checks are done to ensure that `first_terms.len() == coeff.len()` and that both are nonempty.
///
/// The result is computed in modulo `modulo`.
/// If `modulo` equals 0, it is treated as `2**64`.
/// Note that `modulo` does not need to be a prime.
///
/// Current implementation uses the Kitamasa algorithm along with the O(n lg n) NTT division.
/// This is subject to change (e.g., Bostan-Mori).
pub fn linear_nth(first_terms: &[u64], coeff: &[u64], mut n: u128, modulo: u64) -> u64 {
let m = first_terms.len();
assert!(m == coeff.len());
assert!(m > 0);
if modulo == 1 {
0
} else {
let mut p_base = vec![]; // The modulo base polynomial of Kitamasa
for x in coeff.iter().rev() {
p_base.push(if modulo == 0 { 0u64.wrapping_sub(modulo) } else { modsub(0, *x, modulo) });
}
p_base.push(1);
let mut p_pow2 = vec![0, 1];
let mut p_out = vec![1];
while n > 0 {
if (n & 1) != 0 {
p_out = polymod_u64(&polymul_u64(&p_pow2, &p_out, modulo), &p_base, modulo).unwrap();
}
p_pow2 = polymod_u64(&polymul_u64(&p_pow2, &p_pow2, modulo), &p_base, modulo).unwrap();
n >>= 1;
}
let mut ans = 0u64;
for i in 0..m {
if i >= p_out.len() { break; }
let term = if modulo == 0 {
first_terms[i].wrapping_mul(p_out[i])
} else {
modmul(first_terms[i], p_out[i], modulo)
};
ans = if modulo == 0 { ans.wrapping_add(term) } else { modadd(ans, term, modulo) };
}
ans
}
}

0 comments on commit 7de7ebe

Please sign in to comment.