diff --git a/src/uint/add.rs b/src/uint/add.rs index 21aa5d57..e4f7bfa4 100644 --- a/src/uint/add.rs +++ b/src/uint/add.rs @@ -24,12 +24,7 @@ impl Uint { /// Perform saturating addition, returning `MAX` on overflow. pub const fn saturating_add(&self, rhs: &Self) -> Self { let (res, overflow) = self.adc(rhs, Limb::ZERO); - - if overflow.0 == 0 { - res - } else { - Self::MAX - } + Self::ct_select(&res, &Self::MAX, CtChoice::from_lsb(overflow.0)) } /// Perform wrapping addition, discarding overflow. @@ -177,6 +172,16 @@ mod tests { assert_eq!(carry, Limb::ONE); } + #[test] + fn saturating_add_no_carry() { + assert_eq!(U128::ZERO.saturating_add(&U128::ONE), U128::ONE); + } + + #[test] + fn saturating_add_with_carry() { + assert_eq!(U128::MAX.saturating_add(&U128::ONE), U128::MAX); + } + #[test] fn wrapping_add_no_carry() { assert_eq!(U128::ZERO.wrapping_add(&U128::ONE), U128::ONE); diff --git a/src/uint/rand.rs b/src/uint/rand.rs index e283e224..c5f730b8 100644 --- a/src/uint/rand.rs +++ b/src/uint/rand.rs @@ -63,15 +63,17 @@ mod tests { let modulus = NonZero::new(U256::from(42u8)).unwrap(); let res = U256::random_mod(&mut rng, &modulus); - // Sanity check that the return value isn't zero - assert_ne!(res, U256::ZERO); + // Check that the value is in range + assert!(res >= U256::ZERO); + assert!(res < U256::from(42u8)); // Ensure `random_mod` runs in a reasonable amount of time // when the modulus is larger than 1 limb let modulus = NonZero::new(U256::from(0x10000000000000001u128)).unwrap(); let res = U256::random_mod(&mut rng, &modulus); - // Sanity check that the return value isn't zero - assert_ne!(res, U256::ZERO); + // Check that the value is in range + assert!(res >= U256::ZERO); + assert!(res < U256::from(0x10000000000000001u128)); } } diff --git a/src/uint/sqrt.rs b/src/uint/sqrt.rs index 56815e2d..5c96afb1 100644 --- a/src/uint/sqrt.rs +++ b/src/uint/sqrt.rs @@ -5,11 +5,20 @@ use crate::{Limb, Word}; use subtle::{ConstantTimeEq, CtOption}; impl Uint { + /// See [`Self::sqrt_vartime`]. + #[deprecated( + since = "0.5.3", + note = "This functionality will be moved to `sqrt_vartime` in a future release." + )] + pub const fn sqrt(&self) -> Self { + self.sqrt_vartime() + } + /// Computes √(`self`) /// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 /// /// Callers can check if `self` is a square by squaring the result - pub const fn sqrt(&self) -> Self { + pub const fn sqrt_vartime(&self) -> Self { let max_bits = (self.bits_vartime() + 1) >> 1; let cap = Self::ONE.shl_vartime(max_bits); let mut guess = cap; // ≥ √(`self`) @@ -47,17 +56,35 @@ impl Uint { Self::ct_select(&Self::ZERO, &guess, self.ct_is_nonzero()) } + /// See [`Self::wrapping_sqrt_vartime`]. + #[deprecated( + since = "0.5.3", + note = "This functionality will be moved to `wrapping_sqrt_vartime` in a future release." + )] + pub const fn wrapping_sqrt(&self) -> Self { + self.wrapping_sqrt_vartime() + } + /// Wrapped sqrt is just normal √(`self`) /// There’s no way wrapping could ever happen. /// This function exists, so that all operations are accounted for in the wrapping operations. - pub const fn wrapping_sqrt(&self) -> Self { - self.sqrt() + pub const fn wrapping_sqrt_vartime(&self) -> Self { + self.sqrt_vartime() + } + + /// See [`Self::checked_sqrt_vartime`]. + #[deprecated( + since = "0.5.3", + note = "This functionality will be moved to `checked_sqrt_vartime` in a future release." + )] + pub fn checked_sqrt(&self) -> CtOption { + self.checked_sqrt_vartime() } /// Perform checked sqrt, returning a [`CtOption`] which `is_some` /// only if the √(`self`)² == self - pub fn checked_sqrt(&self) -> CtOption { - let r = self.sqrt(); + pub fn checked_sqrt_vartime(&self) -> CtOption { + let r = self.sqrt_vartime(); let s = r.wrapping_mul(&r); CtOption::new(r, ConstantTimeEq::ct_eq(self, &s)) } @@ -76,13 +103,13 @@ mod tests { #[test] fn edge() { - assert_eq!(U256::ZERO.sqrt(), U256::ZERO); - assert_eq!(U256::ONE.sqrt(), U256::ONE); + assert_eq!(U256::ZERO.sqrt_vartime(), U256::ZERO); + assert_eq!(U256::ONE.sqrt_vartime(), U256::ONE); let mut half = U256::ZERO; for i in 0..half.limbs.len() / 2 { half.limbs[i] = Limb::MAX; } - assert_eq!(U256::MAX.sqrt(), half,); + assert_eq!(U256::MAX.sqrt_vartime(), half,); } #[test] @@ -104,22 +131,28 @@ mod tests { for (a, e) in &tests { let l = U256::from(*a); let r = U256::from(*e); - assert_eq!(l.sqrt(), r); - assert_eq!(l.checked_sqrt().is_some().unwrap_u8(), 1u8); + assert_eq!(l.sqrt_vartime(), r); + assert_eq!(l.checked_sqrt_vartime().is_some().unwrap_u8(), 1u8); } } #[test] fn nonsquares() { - assert_eq!(U256::from(2u8).sqrt(), U256::from(1u8)); - assert_eq!(U256::from(2u8).checked_sqrt().is_some().unwrap_u8(), 0); - assert_eq!(U256::from(3u8).sqrt(), U256::from(1u8)); - assert_eq!(U256::from(3u8).checked_sqrt().is_some().unwrap_u8(), 0); - assert_eq!(U256::from(5u8).sqrt(), U256::from(2u8)); - assert_eq!(U256::from(6u8).sqrt(), U256::from(2u8)); - assert_eq!(U256::from(7u8).sqrt(), U256::from(2u8)); - assert_eq!(U256::from(8u8).sqrt(), U256::from(2u8)); - assert_eq!(U256::from(10u8).sqrt(), U256::from(3u8)); + assert_eq!(U256::from(2u8).sqrt_vartime(), U256::from(1u8)); + assert_eq!( + U256::from(2u8).checked_sqrt_vartime().is_some().unwrap_u8(), + 0 + ); + assert_eq!(U256::from(3u8).sqrt_vartime(), U256::from(1u8)); + assert_eq!( + U256::from(3u8).checked_sqrt_vartime().is_some().unwrap_u8(), + 0 + ); + assert_eq!(U256::from(5u8).sqrt_vartime(), U256::from(2u8)); + assert_eq!(U256::from(6u8).sqrt_vartime(), U256::from(2u8)); + assert_eq!(U256::from(7u8).sqrt_vartime(), U256::from(2u8)); + assert_eq!(U256::from(8u8).sqrt_vartime(), U256::from(2u8)); + assert_eq!(U256::from(10u8).sqrt_vartime(), U256::from(3u8)); } #[cfg(feature = "rand")] @@ -130,15 +163,15 @@ mod tests { let t = rng.next_u32() as u64; let s = U256::from(t); let s2 = s.checked_mul(&s).unwrap(); - assert_eq!(s2.sqrt(), s); - assert_eq!(s2.checked_sqrt().is_some().unwrap_u8(), 1); + assert_eq!(s2.sqrt_vartime(), s); + assert_eq!(s2.checked_sqrt_vartime().is_some().unwrap_u8(), 1); } for _ in 0..50 { let s = U256::random(&mut rng); let mut s2 = U512::ZERO; s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs); - assert_eq!(s.square().sqrt(), s2); + assert_eq!(s.square().sqrt_vartime(), s2); } } } diff --git a/src/uint/sub.rs b/src/uint/sub.rs index c39e5492..571dd6aa 100644 --- a/src/uint/sub.rs +++ b/src/uint/sub.rs @@ -25,12 +25,7 @@ impl Uint { /// Perform saturating subtraction, returning `ZERO` on underflow. pub const fn saturating_sub(&self, rhs: &Self) -> Self { let (res, underflow) = self.sbb(rhs, Limb::ZERO); - - if underflow.0 == 0 { - res - } else { - Self::ZERO - } + Self::ct_select(&res, &Self::ZERO, CtChoice::from_mask(underflow.0)) } /// Perform wrapping subtraction, discarding underflow and wrapping around @@ -180,6 +175,22 @@ mod tests { assert_eq!(borrow, Limb::MAX); } + #[test] + fn saturating_sub_no_borrow() { + assert_eq!( + U128::from(5u64).saturating_sub(&U128::ONE), + U128::from(4u64) + ); + } + + #[test] + fn saturating_sub_with_borrow() { + assert_eq!( + U128::from(4u64).saturating_sub(&U128::from(5u64)), + U128::ZERO + ); + } + #[test] fn wrapping_sub_no_borrow() { assert_eq!(U128::ONE.wrapping_sub(&U128::ONE), U128::ZERO); diff --git a/tests/proptests.rs b/tests/proptests.rs index 572d990d..1136c633 100644 --- a/tests/proptests.rs +++ b/tests/proptests.rs @@ -190,7 +190,7 @@ proptest! { fn wrapping_sqrt(a in uint()) { let a_bi = to_biguint(&a); let expected = to_uint(a_bi.sqrt()); - let actual = a.wrapping_sqrt(); + let actual = a.wrapping_sqrt_vartime(); assert_eq!(expected, actual); }