Skip to content

Commit

Permalink
mont perf
Browse files Browse the repository at this point in the history
  • Loading branch information
eschorn1 committed May 10, 2024
1 parent ec667f3 commit 63a2c94
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 142 deletions.
32 changes: 16 additions & 16 deletions benches/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,28 @@ Near-obvious uplift can be had with more careful modular multiplication & additi
using fewer reductions. Also, 'u16' arithmetic has a performance penalty.

~~~
May 5, 2024
May 10, 2024
Intel® Core™ i7-7700K CPU @ 4.20GHz × 8 Circa 2017 w/ Rust 1.77
$ RUSTFLAGS="-C target-cpu=native" cargo bench
ml_dsa_44 keygen time: [85.256 µs 85.275 µs 85.299 µs]
ml_dsa_65 keygen time: [160.99 µs 161.04 µs 161.10 µs]
ml_dsa_87 keygen time: [233.85 µs 233.92 µs 233.99 µs]
ml_dsa_44 keygen time: [85.502 µs 85.521 µs 85.543 µs]
ml_dsa_65 keygen time: [162.17 µs 162.23 µs 162.34 µs]
ml_dsa_87 keygen time: [232.24 µs 232.26 µs 232.28 µs]
ml_dsa_44 sk sign time: [306.79 µs 309.50 µs 312.14 µs]
ml_dsa_65 sk sign time: [519.89 µs 525.77 µs 531.68 µs]
ml_dsa_87 sk sign time: [638.39 µs 645.71 µs 653.05 µs]
ml_dsa_44 sk sign time: [301.10 µs 303.62 µs 306.12 µs]
ml_dsa_65 sk sign time: [486.54 µs 491.62 µs 496.69 µs]
ml_dsa_87 sk sign time: [593.29 µs 599.69 µs 606.20 µs]
ml_dsa_44 esk sign time: [247.20 µs 250.01 µs 252.94 µs]
ml_dsa_65 esk sign time: [423.54 µs 429.68 µs 435.97 µs]
ml_dsa_87 esk sign time: [453.37 µs 458.29 µs 463.27 µs]
ml_dsa_44 esk sign time: [233.84 µs 236.39 µs 239.00 µs]
ml_dsa_65 esk sign time: [375.86 µs 380.61 µs 385.48 µs]
ml_dsa_87 esk sign time: [401.26 µs 406.48 µs 411.66 µs]
ml_dsa 44 pk verify time: [75.202 µs 75.216 µs 75.231 µs]
ml_dsa 65 pk verify time: [135.17 µs 135.19 µs 135.22 µs]
ml_dsa 87 pk verify time: [224.04 µs 224.18 µs 224.35 µs]
ml_dsa 44 pk verify time: [78.619 µs 78.630 µs 78.640 µs]
ml_dsa 65 pk verify time: [130.59 µs 130.64 µs 130.69 µs]
ml_dsa 87 pk verify time: [219.01 µs 219.06 µs 219.12 µs]
ml_dsa 44 epk verify time: [22.837 µs 22.847 µs 22.856 µs]
ml_dsa 65 epk verify time: [38.911 µs 38.923 µs 38.934 µs]
ml_dsa 87 epk verify time: [56.317 µs 56.346 µs 56.374 µs]
ml_dsa 44 epk verify time: [20.677 µs 20.694 µs 20.712 µs]
ml_dsa 65 epk verify time: [26.972 µs 26.980 µs 26.987 µs]
ml_dsa 87 epk verify time: [36.188 µs 36.203 µs 36.218 µs]
~~~
25 changes: 0 additions & 25 deletions src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,29 +467,4 @@ mod tests {
simple_bit_pack(&r, (1 << 6) - 1, &mut random_bytes);
// no panic is good news
}

// #[test]
// #[should_panic]
// #[allow(clippy::should_panic_without_expect)]
// fn test_simple_bit_pack_validation2() {
// let mut random_bytes = [0u8; 32 * 7];
// rand::thread_rng().fill_bytes(&mut random_bytes);
// // wrong size r coeff
// let r = [1024i32; 256];
// simple_bit_pack(&r, (1 << 6) - 1, &mut random_bytes);
// // should have panicked by now...
// }

// TODO: reword to start with bit_pack..
// #[test]
// fn test_bit_pack_roundtrip() {
// // Round trip for 32 * 6(bitlen) bytes
// let random_bytes: Vec<u8> = (0..32 * 6).map(|_| rand::random::<u8>()).collect();
// let mut r = bit_unpack(&random_bytes, 1 << 2, (1 << 6) - (1 << 2) - 1).unwrap();
// let mut res = [0u8; 32 * 6];
// bit_pack(&r, 1 << 2, (1 << 6) - (1 << 2) - 1, &mut res);
// assert_eq!(random_bytes, res);
// }

// TODO test hint_bit_pack and hint_bit_unpack
}
13 changes: 2 additions & 11 deletions src/encodings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,8 @@ mod tests {

fn get_vec(max: u32) -> R {
let mut rnd_r = R0; //[0i32; 256];
rnd_r.0
rnd_r
.0
.iter_mut()
.for_each(|e| *e = rand::random::<i32>().rem_euclid(i32::try_from(max).unwrap()));
rnd_r
Expand All @@ -413,11 +414,6 @@ mod tests {
#[test]
#[allow(clippy::similar_names)]
fn test_sk_encode_decode_roundtrip1() {
// TODO: figure out how to best test this correctly
// - should the skDecode function return a result (probably)
// - double check the range of the input operands (most are +/- ETA, but last one is 2^d-1)
// - maybe need to rework one/two of the conversion functions in a similar fashion

// D=13 ETA=2 K=4 L=4 SK_LEN=2560
let (rho, k) = (rand::random::<[u8; 32]>(), rand::random::<[u8; 32]>());
let mut tr = [0u8; 64];
Expand Down Expand Up @@ -453,14 +449,9 @@ mod tests {
rand::thread_rng().fill_bytes(&mut c_tilde);
let z = [get_vec(2), get_vec(2), get_vec(2), get_vec(2)];
let h = [get_vec(1), get_vec(1), get_vec(1), get_vec(1)];
//let mut sigma = [0u8; 2420];
let sigma = sig_encode::<4, 4, { 128 / 4 }, 2420>(1 << 17, 80, &c_tilde.clone(), &z, &h);
// let mut c_test = [0u8; 2 * 128 / 8];
// let mut z_test = [[0i32; 256]; 4];
// let mut h_test = [[0i32; 256]; 4];
let (c_test, z_test, h_test) =
sig_decode::<4, 4, { 128 / 4 }, 2420>(1 << 17, 80, &sigma).unwrap();
// assert!(res.is_ok());
assert_eq!(c_tilde[0..8], c_test[0..8]);
assert_eq!(z, z_test);
assert_eq!(h, h_test.unwrap());
Expand Down
38 changes: 17 additions & 21 deletions src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,24 @@ pub(crate) const fn bit_length(a: i32) -> usize { a.ilog2() as usize + 1 }
/// element m′ ∈ Z in the range −α/2 < m′ ≤ α/2 such that m and m′ are congruent
/// modulo α. 'ready to optimize'
pub(crate) fn center_mod(m: i32) -> i32 {
let t = full_reduce32(m);
let t = partial_reduce32(m);
let over2 = (Q / 2) - t; // check if t is larger than Q/2
t - ((over2 >> 31) & Q) // sub Q if over2 is negative
}


/// Matrix by vector multiplication; See top of page 10, first row: `w_hat` = `A_hat` mul `u_hat`
#[must_use] // TODO: MONT?!?!???
#[must_use]
pub(crate) fn mat_vec_mul<const K: usize, const L: usize>(
a_hat: &[[T; L]; K], u_hat: &[T; L],
) -> [T; K] {
let mut w_hat = [T0; K];
let u_hat_mont = to_mont(u_hat);
for i in 0..K {
#[allow(clippy::needless_range_loop)] // clarity
for j in 0..L {
w_hat[i].0.iter_mut().enumerate().for_each(|(m, e)| {
*e = partial_reduce64(
i64::from(*e) + i64::from(a_hat[i][j].0[m]) * i64::from(u_hat[j].0[m]),
);
w_hat[i].0.iter_mut().enumerate().for_each(|(n, e)| {
*e += mont_reduce(i64::from(a_hat[i][j].0[n]) * i64::from(u_hat_mont[j].0[n]));
});
}
}
Expand All @@ -104,6 +103,17 @@ pub(crate) fn vec_add<const K: usize>(vec_a: &[R; K], vec_b: &[R; K]) -> [R; K]
}


/// Returns a copy of `vec_a` with every coefficient scaled by 2^32 and reduced.
///
/// Each coefficient is widened to `i64`, multiplied by `1 << 32`, and passed through
/// `partial_reduce64`; the input is left untouched and a new array is returned.
/// NOTE(review): the 2^32 factor is presumably the Montgomery constant that the `>> 32`
/// in `mont_reduce` later cancels (see the call in `mat_vec_mul`) - confirm.
#[allow(clippy::cast_possible_truncation)] // as i32
pub(crate) fn to_mont<const L: usize>(vec_a: &[T; L]) -> [T; L] {
    core::array::from_fn(|row| {
        T(core::array::from_fn(|idx| {
            // An i32 widened to i64 and multiplied by 2^32 always fits in an i64.
            partial_reduce64(i64::from(vec_a[row].0[idx]).wrapping_mul(1 << 32))
        }))
    })
}


pub(crate) fn infinity_norm<const ROW: usize>(w: &[R; ROW]) -> i32 {
let mut result = 0; // no early exit
for row in w {
Expand Down Expand Up @@ -137,25 +147,12 @@ const fn pow_mod_q(g: i32, e: u8) -> i32 {
}


// const fn gen_zeta_table() -> [i32; 256] {
// let mut result = [0i32; 256];
// let mut i = 0;
// while i < 256 {
// result[i] = pow_mod_q(ZETA, i.to_le_bytes()[0].reverse_bits());
// i += 1;
// }
// result
// }

// #[allow(dead_code)]
// pub(crate) static ZETA_TABLE: [i32; 256] = gen_zeta_table();

///////////////////////

#[allow(dead_code)]
const QINV: i64 = 58_728_449; // (Q * QINV) % 2**32 = 1

#[allow(dead_code, clippy::cast_possible_truncation)]
#[allow(clippy::cast_possible_truncation)]
pub(crate) const fn mont_reduce(a: i64) -> i32 {
let t = a.wrapping_mul(QINV) as i32;
let t = (a - (t as i64).wrapping_mul(Q as i64)) >> 32;
Expand All @@ -164,7 +161,6 @@ pub(crate) const fn mont_reduce(a: i64) -> i32 {
t as i32
}

#[allow(dead_code)]
pub(crate) static ZETA_TABLE_MONT: [i32; 256] = gen_zeta_table_mont();

#[allow(clippy::cast_possible_truncation)]
Expand Down
7 changes: 5 additions & 2 deletions src/high_low.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ pub(crate) fn power2round<const K: usize>(t: &[R; K]) -> ([R; K], [R; K]) {
// 1: r+ ← r mod q
// 2: r0 ← r+ mod±2^d
// 3: return ((r+ − r0)/2^d, r0)
let r_1: [R; K] = core::array::from_fn(|k| R(core::array::from_fn(|n| (t[k].0[n] + (1 << (D - 1)) - 1) >> D)));
let r_0: [R; K] = core::array::from_fn(|k| R(core::array::from_fn(|n| t[k].0[n] - (r_1[k].0[n] << D))));
let r_1: [R; K] = core::array::from_fn(|k| {
R(core::array::from_fn(|n| (t[k].0[n] + (1 << (D - 1)) - 1) >> D))
});
let r_0: [R; K] =
core::array::from_fn(|k| R(core::array::from_fn(|n| t[k].0[n] - (r_1[k].0[n] << D))));

(r_1, r_0)
}
Expand Down
11 changes: 6 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
// See <https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.204.ipd.pdf>

// TODO: Roadmap
// 1. Clean up; resolve (mont) math; investigate potential h[last] weakness good/bad sig
// 1. Clean up; resolve (mont) math
// 2. Closer CT inspection -> top level key_gen is vartime, the rest CT outside of rho (? TBC)
// 3. Intensive/extensive pass on documentation
// 4. Revisit/expand unit testing; consider whether to test debug statements: release-vs-test
Expand Down Expand Up @@ -104,7 +104,6 @@ macro_rules! functionality {

const LAMBDA_DIV4: usize = LAMBDA / 4;


// ----- 'EXTERNAL' DATA TYPES -----

/// Correctly sized private key specific to the target security parameter set. <br>
Expand Down Expand Up @@ -231,7 +230,6 @@ macro_rules! functionality {
&self, rng: &mut impl CryptoRngCore, message: &[u8],
) -> Result<Self::Signature, &'static str> {
let esk = ml_dsa::sign_start(ETA, &self.0)?;
//return Ok([0u8; SIG_LEN]);
let sig = ml_dsa::sign_finish::<K, L, LAMBDA_DIV4, SIG_LEN, SK_LEN>(
rng, BETA, GAMMA1, GAMMA2, OMEGA, TAU, &esk, message,
)?;
Expand Down Expand Up @@ -267,6 +265,7 @@ macro_rules! functionality {
}
}


impl Verifier for ExpandedPublicKey {
type Signature = [u8; SIG_LEN];

Expand All @@ -286,7 +285,8 @@ macro_rules! functionality {
type ByteArray = [u8; PK_LEN];

fn try_from_bytes(pk: Self::ByteArray) -> Result<Self, &'static str> {
let _unused = pk_decode::<K, PK_LEN>(&pk).map_err(|_e| "Public key deserialization failed");
let _unused =
pk_decode::<K, PK_LEN>(&pk).map_err(|_e| "Public key deserialization failed");
Ok(PublicKey { 0: pk })
}

Expand All @@ -298,7 +298,8 @@ macro_rules! functionality {
type ByteArray = [u8; SK_LEN];

fn try_from_bytes(sk: Self::ByteArray) -> Result<Self, &'static str> {
let _unused = sk_decode::<K, L, SK_LEN>(ETA, &sk).map_err(|_e| "Private key deserialization failed");
let _unused = sk_decode::<K, L, SK_LEN>(ETA, &sk)
.map_err(|_e| "Private key deserialization failed");
Ok(PrivateKey { 0: sk })
}

Expand Down
Loading

0 comments on commit 63a2c94

Please sign in to comment.