diff --git a/benches/bench.rs b/benches/bench.rs index 8be5f592..26ee63a4 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -3,7 +3,7 @@ use criterion::{ }; use crypto_bigint::{ modular::runtime_mod::{DynResidue, DynResidueParams}, - Limb, NonZero, Random, Reciprocal, U128, U256, + Limb, NonZero, Random, Reciprocal, U128, U2048, U256, }; use rand_core::OsRng; @@ -131,6 +131,28 @@ fn bench_montgomery_conversion(group: &mut BenchmarkGroup<'_, M> }); } +fn bench_shifts(group: &mut BenchmarkGroup<'_, M>) { + group.bench_function("shl_vartime, small, U2048", |b| { + b.iter_batched(|| U2048::ONE, |x| x.shl_vartime(10), BatchSize::SmallInput) + }); + + group.bench_function("shl_vartime, large, U2048", |b| { + b.iter_batched( + || U2048::ONE, + |x| x.shl_vartime(1024 + 10), + BatchSize::SmallInput, + ) + }); + + group.bench_function("shl, U2048", |b| { + b.iter_batched(|| U2048::ONE, |x| x.shl(1024 + 10), BatchSize::SmallInput) + }); + + group.bench_function("shr, U2048", |b| { + b.iter_batched(|| U2048::ONE, |x| x.shr(1024 + 10), BatchSize::SmallInput) + }); +} + fn bench_wrapping_ops(c: &mut Criterion) { let mut group = c.benchmark_group("wrapping ops"); bench_division(&mut group); @@ -144,5 +166,17 @@ fn bench_montgomery(c: &mut Criterion) { group.finish(); } -criterion_group!(benches, bench_wrapping_ops, bench_montgomery); +fn bench_modular_ops(c: &mut Criterion) { + let mut group = c.benchmark_group("modular ops"); + bench_shifts(&mut group); + group.finish(); +} + +criterion_group!( + benches, + bench_wrapping_ops, + bench_montgomery, + bench_modular_ops +); + criterion_main!(benches); diff --git a/src/ct_choice.rs b/src/ct_choice.rs index d0d31dc9..3ad2d232 100644 --- a/src/ct_choice.rs +++ b/src/ct_choice.rs @@ -29,6 +29,12 @@ impl CtChoice { Self(value.wrapping_neg()) } + /// Returns the truthy value if `x < y`, and the falsy value otherwise. + pub(crate) const fn from_usize_lt(x: usize, y: usize) -> Self { + let bit = (((!x) & y) | (((!x) | y) & (x.wrapping_sub(y)))) >> (usize::BITS - 1); + Self::from_lsb(bit as Word) + } + pub(crate) const fn not(&self) -> Self { Self(!self.0) } diff --git a/src/limb/shl.rs b/src/limb/shl.rs index c0003ddb..88e37f01 100644 --- a/src/limb/shl.rs +++ b/src/limb/shl.rs @@ -5,6 +5,7 @@ use core::ops::{Shl, ShlAssign}; impl Limb { /// Computes `self << rhs`. + /// Panics if `rhs` overflows `Limb::BITS`. #[inline(always)] pub const fn shl(self, rhs: Self) -> Self { Limb(self.0 << rhs.0) diff --git a/src/limb/shr.rs b/src/limb/shr.rs index 29ee7635..7c422e04 100644 --- a/src/limb/shr.rs +++ b/src/limb/shr.rs @@ -5,6 +5,7 @@ use core::ops::{Shr, ShrAssign}; impl Limb { /// Computes `self >> rhs`. + /// Panics if `rhs` overflows `Limb::BITS`. #[inline(always)] pub const fn shr(self, rhs: Self) -> Self { Limb(self.0 >> rhs.0) diff --git a/src/uint.rs b/src/uint.rs index 341685d6..21238cf2 100644 --- a/src/uint.rs +++ b/src/uint.rs @@ -92,6 +92,10 @@ impl Uint { /// Total size of the represented integer in bits. pub const BITS: usize = LIMBS * Limb::BITS; + /// Bit size of `BITS`. + // Note: assumes the type of `BITS` is `usize`. Any way to assert that? + pub(crate) const LOG2_BITS: usize = (usize::BITS - Self::BITS.leading_zeros()) as usize; + /// Total size of the represented integer in bytes. pub const BYTES: usize = LIMBS * Limb::BYTES; diff --git a/src/uint/shl.rs b/src/uint/shl.rs index 9fcf1c8a..1dbc40f0 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -1,6 +1,6 @@ //! [`Uint`] bitwise left shift operations. -use crate::{Limb, Uint, Word}; +use crate::{CtChoice, Limb, Uint, Word}; use core::ops::{Shl, ShlAssign}; impl Uint { @@ -78,6 +78,22 @@ impl Uint { (new_lower, upper) } + + /// Computes `self << n`. + /// Returns zero if `n >= Self::BITS`. + pub const fn shl(&self, shift: usize) -> Self { + let overflow = CtChoice::from_usize_lt(shift, Self::BITS).not(); + let shift = shift % Self::BITS; + let mut result = *self; + let mut i = 0; + while i < Self::LOG2_BITS { + let bit = CtChoice::from_lsb((shift as Word >> i) & 1); + result = Uint::ct_select(&result, &result.shl_vartime(1 << i), bit); + i += 1; + } + + Uint::ct_select(&result, &Self::ZERO, overflow) + } } impl Shl for Uint { @@ -88,7 +104,7 @@ impl Shl for Uint { /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. fn shl(self, rhs: usize) -> Uint { - self.shl_vartime(rhs) + Uint::::shl(&self, rhs) } } @@ -100,7 +116,7 @@ impl Shl for &Uint { /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. fn shl(self, rhs: usize) -> Uint { - self.shl_vartime(rhs) + self.shl(rhs) } } @@ -110,7 +126,7 @@ impl ShlAssign for Uint { /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. fn shl_assign(&mut self, rhs: usize) { - *self = self.shl_vartime(rhs) + *self = self.shl(rhs) } } diff --git a/src/uint/shr.rs b/src/uint/shr.rs index 3698b970..6a36fbe8 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -1,7 +1,7 @@ //! [`Uint`] bitwise right shift operations. use super::Uint; -use crate::{limb::HI_BIT, CtChoice, Limb}; +use crate::{limb::HI_BIT, CtChoice, Limb, Word}; use core::ops::{Shr, ShrAssign}; impl Uint { @@ -97,6 +97,22 @@ impl Uint { (lower, new_upper) } + + /// Computes `self << n`. + /// Returns zero if `n >= Self::BITS`. + pub const fn shr(&self, shift: usize) -> Self { + let overflow = CtChoice::from_usize_lt(shift, Self::BITS).not(); + let shift = shift % Self::BITS; + let mut result = *self; + let mut i = 0; + while i < Self::LOG2_BITS { + let bit = CtChoice::from_lsb((shift as Word >> i) & 1); + result = Uint::ct_select(&result, &result.shr_vartime(1 << i), bit); + i += 1; + } + + Uint::ct_select(&result, &Self::ZERO, overflow) + } } impl Shr for Uint { @@ -107,7 +123,7 @@ impl Shr for Uint { /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. fn shr(self, rhs: usize) -> Uint { - self.shr_vartime(rhs) + Uint::::shr(&self, rhs) } } @@ -119,13 +135,13 @@ impl Shr for &Uint { /// When used with a fixed `rhs`, this function is constant-time with respect /// to `self`. fn shr(self, rhs: usize) -> Uint { - self.shr_vartime(rhs) + self.shr(rhs) } } impl ShrAssign for Uint { fn shr_assign(&mut self, rhs: usize) { - *self = self.shr_vartime(rhs); + *self = self.shr(rhs); } } diff --git a/tests/proptests.rs b/tests/proptests.rs index 1136c633..b87c7011 100644 --- a/tests/proptests.rs +++ b/tests/proptests.rs @@ -6,7 +6,7 @@ use crypto_bigint::{ }; use num_bigint::BigUint; use num_integer::Integer; -use num_traits::identities::Zero; +use num_traits::identities::{One, Zero}; use proptest::prelude::*; use std::mem; @@ -59,6 +59,32 @@ proptest! { assert_eq!(expected, actual); } + #[test] + fn shl(a in uint(), shift in any::()) { + let a_bi = to_biguint(&a); + + // Add a 50% probability of overflow. + let shift = (shift as usize) % (U256::BITS * 2); + + let expected = to_uint((a_bi << shift) & ((BigUint::one() << U256::BITS) - BigUint::one())); + let actual = a.shl(shift); + + assert_eq!(expected, actual); + } + + #[test] + fn shr(a in uint(), shift in any::()) { + let a_bi = to_biguint(&a); + + // Add a 50% probability of overflow. + let shift = (shift as usize) % (U256::BITS * 2); + + let expected = to_uint(a_bi >> shift); + let actual = a.shr(shift); + + assert_eq!(expected, actual); + } + #[test] fn wrapping_add(a in uint(), b in uint()) { let a_bi = to_biguint(&a);