Skip to content

Commit

Permalink
Add constant-time Uint::shr() and Uint::shl()
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Aug 28, 2023
1 parent 783603c commit 708e0c8
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 11 deletions.
38 changes: 36 additions & 2 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use criterion::{
};
use crypto_bigint::{
modular::runtime_mod::{DynResidue, DynResidueParams},
Limb, NonZero, Random, Reciprocal, U128, U256,
Limb, NonZero, Random, Reciprocal, U128, U2048, U256,
};
use rand_core::OsRng;

Expand Down Expand Up @@ -131,6 +131,28 @@ fn bench_montgomery_conversion<M: Measurement>(group: &mut BenchmarkGroup<'_, M>
});
}

fn bench_shifts<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
group.bench_function("shl_vartime, small, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shl_vartime(10), BatchSize::SmallInput)
});

group.bench_function("shl_vartime, large, U2048", |b| {
b.iter_batched(
|| U2048::ONE,
|x| x.shl_vartime(1024 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shl, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shl(1024 + 10), BatchSize::SmallInput)
});

group.bench_function("shr, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shr(1024 + 10), BatchSize::SmallInput)
});
}

fn bench_wrapping_ops(c: &mut Criterion) {
let mut group = c.benchmark_group("wrapping ops");
bench_division(&mut group);
Expand All @@ -144,5 +166,17 @@ fn bench_montgomery(c: &mut Criterion) {
group.finish();
}

criterion_group!(benches, bench_wrapping_ops, bench_montgomery);
fn bench_modular_ops(c: &mut Criterion) {
let mut group = c.benchmark_group("modular ops");
bench_shifts(&mut group);
group.finish();
}

criterion_group!(
benches,
bench_wrapping_ops,
bench_montgomery,
bench_modular_ops
);

criterion_main!(benches);
6 changes: 6 additions & 0 deletions src/ct_choice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ impl CtChoice {
Self(value.wrapping_neg())
}

/// Returns the truthy value if `x < y`, and the falsy value otherwise.
pub(crate) const fn from_usize_lt(x: usize, y: usize) -> Self {
let bit = (((!x) & y) | (((!x) | y) & (x.wrapping_sub(y)))) >> (usize::BITS - 1);
Self::from_lsb(bit as Word)
}

pub(crate) const fn not(&self) -> Self {
Self(!self.0)
}
Expand Down
1 change: 1 addition & 0 deletions src/limb/shl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::ops::{Shl, ShlAssign};

impl Limb {
/// Computes `self << rhs`.
/// Panics if `rhs` overflows `Limb::BITS`.
#[inline(always)]
pub const fn shl(self, rhs: Self) -> Self {
Limb(self.0 << rhs.0)
Expand Down
1 change: 1 addition & 0 deletions src/limb/shr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::ops::{Shr, ShrAssign};

impl Limb {
/// Computes `self >> rhs`.
/// Panics if `rhs` overflows `Limb::BITS`.
#[inline(always)]
pub const fn shr(self, rhs: Self) -> Self {
Limb(self.0 >> rhs.0)
Expand Down
4 changes: 4 additions & 0 deletions src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ impl<const LIMBS: usize> Uint<LIMBS> {
/// Total size of the represented integer in bits.
pub const BITS: usize = LIMBS * Limb::BITS;

/// Bit size of `BITS`.
// Note: assumes the type of `BITS` is `usize`. Any way to assert that?
pub(crate) const LOG2_BITS: usize = (usize::BITS - Self::BITS.leading_zeros()) as usize;

/// Total size of the represented integer in bytes.
pub const BYTES: usize = LIMBS * Limb::BYTES;

Expand Down
24 changes: 20 additions & 4 deletions src/uint/shl.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! [`Uint`] bitwise left shift operations.
use crate::{Limb, Uint, Word};
use crate::{CtChoice, Limb, Uint, Word};
use core::ops::{Shl, ShlAssign};

impl<const LIMBS: usize> Uint<LIMBS> {
Expand Down Expand Up @@ -78,6 +78,22 @@ impl<const LIMBS: usize> Uint<LIMBS> {

(new_lower, upper)
}

/// Computes `self << n`.
/// Returns zero if `n >= Self::BITS`.
pub const fn shl(&self, shift: usize) -> Self {
let overflow = CtChoice::from_usize_lt(shift, Self::BITS).not();
let shift = shift % Self::BITS;
let mut result = *self;
let mut i = 0;
while i < Self::LOG2_BITS {
let bit = CtChoice::from_lsb((shift as Word >> i) & 1);
result = Uint::ct_select(&result, &result.shl_vartime(1 << i), bit);
i += 1;
}

Uint::ct_select(&result, &Self::ZERO, overflow)
}
}

impl<const LIMBS: usize> Shl<usize> for Uint<LIMBS> {
Expand All @@ -88,7 +104,7 @@ impl<const LIMBS: usize> Shl<usize> for Uint<LIMBS> {
/// When used with a fixed `rhs`, this function is constant-time with respect
/// to `self`.
fn shl(self, rhs: usize) -> Uint<LIMBS> {
self.shl_vartime(rhs)
Uint::<LIMBS>::shl(&self, rhs)
}
}

Expand All @@ -100,7 +116,7 @@ impl<const LIMBS: usize> Shl<usize> for &Uint<LIMBS> {
/// When used with a fixed `rhs`, this function is constant-time with respect
/// to `self`.
fn shl(self, rhs: usize) -> Uint<LIMBS> {
self.shl_vartime(rhs)
self.shl(rhs)
}
}

Expand All @@ -110,7 +126,7 @@ impl<const LIMBS: usize> ShlAssign<usize> for Uint<LIMBS> {
/// When used with a fixed `rhs`, this function is constant-time with respect
/// to `self`.
fn shl_assign(&mut self, rhs: usize) {
*self = self.shl_vartime(rhs)
*self = self.shl(rhs)
}
}

Expand Down
24 changes: 20 additions & 4 deletions src/uint/shr.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! [`Uint`] bitwise right shift operations.
use super::Uint;
use crate::{limb::HI_BIT, CtChoice, Limb};
use crate::{limb::HI_BIT, CtChoice, Limb, Word};
use core::ops::{Shr, ShrAssign};

impl<const LIMBS: usize> Uint<LIMBS> {
Expand Down Expand Up @@ -97,6 +97,22 @@ impl<const LIMBS: usize> Uint<LIMBS> {

(lower, new_upper)
}

/// Computes `self << n`.
/// Returns zero if `n >= Self::BITS`.
pub const fn shr(&self, shift: usize) -> Self {
let overflow = CtChoice::from_usize_lt(shift, Self::BITS).not();
let shift = shift % Self::BITS;
let mut result = *self;
let mut i = 0;
while i < Self::LOG2_BITS {
let bit = CtChoice::from_lsb((shift as Word >> i) & 1);
result = Uint::ct_select(&result, &result.shr_vartime(1 << i), bit);
i += 1;
}

Uint::ct_select(&result, &Self::ZERO, overflow)
}
}

impl<const LIMBS: usize> Shr<usize> for Uint<LIMBS> {
Expand All @@ -107,7 +123,7 @@ impl<const LIMBS: usize> Shr<usize> for Uint<LIMBS> {
/// When used with a fixed `rhs`, this function is constant-time with respect
/// to `self`.
fn shr(self, rhs: usize) -> Uint<LIMBS> {
self.shr_vartime(rhs)
Uint::<LIMBS>::shr(&self, rhs)
}
}

Expand All @@ -119,13 +135,13 @@ impl<const LIMBS: usize> Shr<usize> for &Uint<LIMBS> {
/// When used with a fixed `rhs`, this function is constant-time with respect
/// to `self`.
fn shr(self, rhs: usize) -> Uint<LIMBS> {
self.shr_vartime(rhs)
self.shr(rhs)
}
}

impl<const LIMBS: usize> ShrAssign<usize> for Uint<LIMBS> {
fn shr_assign(&mut self, rhs: usize) {
*self = self.shr_vartime(rhs);
*self = self.shr(rhs);
}
}

Expand Down
28 changes: 27 additions & 1 deletion tests/proptests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crypto_bigint::{
};
use num_bigint::BigUint;
use num_integer::Integer;
use num_traits::identities::Zero;
use num_traits::identities::{One, Zero};
use proptest::prelude::*;
use std::mem;

Expand Down Expand Up @@ -59,6 +59,32 @@ proptest! {
assert_eq!(expected, actual);
}

#[test]
fn shl(a in uint(), shift in any::<u16>()) {
let a_bi = to_biguint(&a);

// Add a 50% probability of overflow.
let shift = (shift as usize) % (U256::BITS * 2);

let expected = to_uint((a_bi << shift) & ((BigUint::one() << U256::BITS) - BigUint::one()));
let actual = a.shl(shift);

assert_eq!(expected, actual);
}

#[test]
fn shr(a in uint(), shift in any::<u16>()) {
let a_bi = to_biguint(&a);

// Add a 50% probability of overflow.
let shift = (shift as usize) % (U256::BITS * 2);

let expected = to_uint(a_bi >> shift);
let actual = a.shr(shift);

assert_eq!(expected, actual);
}

#[test]
fn wrapping_add(a in uint(), b in uint()) {
let a_bi = to_biguint(&a);
Expand Down

0 comments on commit 708e0c8

Please sign in to comment.