Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement sqrt #9

Merged
merged 1 commit into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion proptests/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions proptests/tests/equivalence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,13 @@ proptest! {
assert_eq!(expected, actual);
}
}

#[test]
fn wrapping_sqrt(a in uint()) {
let a_bi = to_biguint(&a);
let expected = to_uint(a_bi.sqrt());
let actual = a.wrapping_sqrt();

assert_eq!(expected, actual);
}
}
11 changes: 10 additions & 1 deletion src/limb/cmp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Limb comparisons

use super::{Inner, Limb, SignedInner, HI_BIT};
use super::{Inner, Limb, SignedInner, SignedWide, BIT_SIZE, HI_BIT};
use core::cmp::Ordering;
use subtle::{Choice, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess};

Expand Down Expand Up @@ -38,6 +38,15 @@ impl Limb {
let inner = self.0 as SignedInner;
((inner | inner.saturating_neg()) >> HI_BIT) as Inner
}

#[inline]
pub(crate) const fn ct_cmp(lhs: Self, rhs: Self) -> SignedInner {
let a = lhs.0 as SignedWide;
let b = rhs.0 as SignedWide;
let gt = ((b - a) >> BIT_SIZE) & 1;
let lt = ((a - b) >> BIT_SIZE) & 1 & !gt;
(gt as SignedInner) - (lt as SignedInner)
}
}

impl ConstantTimeEq for Limb {
Expand Down
1 change: 1 addition & 0 deletions src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod mul;
mod neg_mod;
mod shl;
mod shr;
mod sqrt;
mod sub;
mod sub_mod;

Expand Down
8 changes: 4 additions & 4 deletions src/uint/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,31 +101,31 @@ impl<const LIMBS: usize> UInt<LIMBS> {
/// Wrapped division is just normal division i.e. `self` / `rhs`
/// There’s no way wrapping could ever happen.
/// This function exists, so that all operations are accounted for in the wrapping operations.
pub const fn wrapping_div(self, rhs: &Self) -> Self {
pub const fn wrapping_div(&self, rhs: &Self) -> Self {
let (q, _, c) = self.ct_div_rem(rhs);
const_assert!(c == 1, "divide by zero");
q
}

/// Perform checked division, returning a [`CtOption`] which `is_some`
/// only if the rhs != 0
pub fn checked_div(self, rhs: &Self) -> CtOption<Self> {
pub fn checked_div(&self, rhs: &Self) -> CtOption<Self> {
let (q, _, c) = self.ct_div_rem(rhs);
CtOption::new(q, Choice::from(c))
}

/// Wrapped (modular) remainder calculation is just `self` % `rhs`.
/// There’s no way wrapping could ever happen.
/// This function exists, so that all operations are accounted for in the wrapping operations.
pub const fn wrapping_rem(self, rhs: &Self) -> Self {
pub const fn wrapping_rem(&self, rhs: &Self) -> Self {
let (r, c) = self.ct_reduce(rhs);
const_assert!(c == 1, "modulo zero");
r
}

/// Perform checked reduction, returning a [`CtOption`] which `is_some`
/// only if the rhs != 0
pub fn checked_rem(self, rhs: &Self) -> CtOption<Self> {
pub fn checked_rem(&self, rhs: &Self) -> CtOption<Self> {
let (r, c) = self.ct_reduce(rhs);
CtOption::new(r, Choice::from(c))
}
Expand Down
141 changes: 141 additions & 0 deletions src/uint/sqrt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
//! [`UInt`] square root operations.

use super::UInt;
use crate::limb::Inner;
use crate::Limb;
use subtle::{ConstantTimeEq, CtOption};

impl<const LIMBS: usize> UInt<LIMBS> {
/// Computes √(`self`)
/// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
///
/// Callers can check if `self` is a square by squaring the result
pub const fn sqrt(&self) -> Self {
let max_bits = ((self.bits() + 1) >> 1) as usize;
let cap = Self::ONE.shl_vartime(max_bits);
let mut guess = cap; // ≥ √(`self`)
let mut xn = {
let q = self.wrapping_div(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};

// If guess increased, the initial guess was low.
// Repeat until reverse course.
while guess.ct_cmp(&xn) == -1 {
// Sometimes an increase is too far, especially with large
// powers, and then takes a long time to walk back. The upper
// bound is based on bit size, so saturate on that.
let res = Limb::ct_cmp(Limb(xn.bits() as Inner), Limb(max_bits as Inner)) - 1;
let le = Limb::is_nonzero(Limb(res as Inner));
guess = Self::ct_select(cap, xn, le);
xn = {
let q = self.wrapping_div(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
}

// Repeat while guess decreases.
while guess.ct_cmp(&xn) == 1 && xn.ct_is_nonzero() == Inner::MAX {
guess = xn;
xn = {
let q = self.wrapping_div(&guess);
let t = guess.wrapping_add(&q);
t.shr_vartime(1)
};
}

Self::ct_select(Self::ZERO, guess, self.ct_is_nonzero())
}

/// Wrapped sqrt is just normal √(`self`)
/// There’s no way wrapping could ever happen.
/// This function exists, so that all operations are accounted for in the wrapping operations.
pub const fn wrapping_sqrt(&self) -> Self {
self.sqrt()
}

/// Perform checked sqrt, returning a [`CtOption`] which `is_some`
/// only if the √(`self`)² == self
pub fn checked_sqrt(&self) -> CtOption<Self> {
let r = self.sqrt();
let s = r.wrapping_mul(&r);
CtOption::new(r, self.ct_eq(&s))
}
}

#[cfg(test)]
mod tests {
use crate::U512;
use crate::{Limb, U256};
use rand_chacha::ChaChaRng;
use rand_core::{RngCore, SeedableRng};

#[test]
fn edge() {
assert_eq!(U256::ZERO.sqrt(), U256::ZERO);
assert_eq!(U256::ONE.sqrt(), U256::ONE);
let mut half = U256::ZERO;
for i in 0..half.limbs.len() / 2 {
half.limbs[i] = Limb::MAX;
}
assert_eq!(U256::MAX.sqrt(), half,);
}

#[test]
fn simple() {
let tests = [
(4u8, 2u8),
(9, 3),
(16, 4),
(25, 5),
(36, 6),
(49, 7),
(64, 8),
(81, 9),
(100, 10),
(121, 11),
(144, 12),
(169, 13),
];
for (a, e) in &tests {
let l = U256::from(*a);
let r = U256::from(*e);
assert_eq!(l.sqrt(), r);
assert_eq!(l.checked_sqrt().is_some().unwrap_u8(), 1u8);
}
}

#[test]
fn nonsquares() {
assert_eq!(U256::from(2u8).sqrt(), U256::from(1u8));
assert_eq!(U256::from(2u8).checked_sqrt().is_some().unwrap_u8(), 0);
assert_eq!(U256::from(3u8).sqrt(), U256::from(1u8));
assert_eq!(U256::from(3u8).checked_sqrt().is_some().unwrap_u8(), 0);
assert_eq!(U256::from(5u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(6u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(7u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(8u8).sqrt(), U256::from(2u8));
assert_eq!(U256::from(10u8).sqrt(), U256::from(3u8));
}

#[test]
fn fuzz() {
let mut rng = ChaChaRng::from_seed([7u8; 32]);
for _ in 0..50 {
let t = rng.next_u32() as u64;
let s = U256::from(t);
let s2 = s.checked_mul(&s).unwrap();
assert_eq!(s2.sqrt(), s);
assert_eq!(s2.checked_sqrt().is_some().unwrap_u8(), 1);
}

for _ in 0..50 {
let s = U256::random(&mut rng);
let mut s2 = U512::ZERO;
s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs);
assert_eq!(s.square().sqrt(), s2);
}
}
}