Skip to content

Commit

Permalink
clippy: nist vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
eschorn1 committed Sep 22, 2024
1 parent 00e7745 commit a57b986
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 178 deletions.
7 changes: 4 additions & 3 deletions rustfmt.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
max_width = 100
max_width = 120
hard_tabs = false
tab_spaces = 4
newline_style = "Auto"
indent_style = "Block"
use_small_heuristics = "Default"
fn_call_width = 80
#use_small_heuristics = "Default"
fn_call_width = 100
attr_fn_like_width = 100
struct_lit_width = 60
struct_variant_width = 60
Expand Down Expand Up @@ -49,6 +49,7 @@ match_arm_blocks = true
match_arm_leading_pipes = "Never"
force_multiline_blocks = false
fn_params_layout = "Compressed"
fn_args_layout = "Compressed"
brace_style = "SameLineWhere"
control_brace_style = "AlwaysSameLine"
trailing_semicolon = true
Expand Down
12 changes: 3 additions & 9 deletions src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ pub(crate) fn coeff_from_three_bytes<const CTEST: bool>(bbb: [u8; 3]) -> Result<
/// # Errors
/// Returns an error `⊥` on when eta = 4 and b > 8 for rejection sampling. (panics on b > 15)
#[allow(clippy::cast_possible_truncation)] // rem as u8
pub(crate) fn coeff_from_half_byte<const CTEST: bool>(
eta: i32, b: u8,
) -> Result<i32, &'static str> {
pub(crate) fn coeff_from_half_byte<const CTEST: bool>(eta: i32, b: u8) -> Result<i32, &'static str> {
const M5: u32 = ((1u32 << 24) / 5) + 1;
debug_assert!((eta == 2) || (eta == 4), "Alg 9: incorrect eta");
debug_assert!(b < 16, "Alg 9: b out of range"); // Note other cases involving b/eta will fall through to Err()
Expand Down Expand Up @@ -240,9 +238,7 @@ pub(crate) fn bit_unpack(v: &[u8], a: i32, b: i32) -> Result<R, &'static str> {
/// **Input**: A polynomial vector `h ∈ R^k_2` such that at most `ω` of the coefficients in `h` are equal to `1`.
/// Security parameters `ω` (omega) and k must sum to be less than 256. <br>
/// **Output**: A byte string `y` of length `ω + k`.
pub(crate) fn hint_bit_pack<const CTEST: bool, const K: usize>(
omega: i32, h: &[R; K], y_bytes: &mut [u8],
) {
pub(crate) fn hint_bit_pack<const CTEST: bool, const K: usize>(omega: i32, h: &[R; K], y_bytes: &mut [u8]) {
let omega_u = usize::try_from(omega).expect("cannot fail");
debug_assert!((1..256).contains(&(omega_u + K)), "Alg 14: omega+K out of range");
debug_assert_eq!(y_bytes.len(), omega_u + K, "Alg 14: bad output size");
Expand Down Expand Up @@ -301,9 +297,7 @@ pub(crate) fn hint_bit_pack<const CTEST: bool, const K: usize>(
///
/// # Errors
/// Returns an error on invalid input.
pub(crate) fn hint_bit_unpack<const K: usize>(
omega: i32, y_bytes: &[u8],
) -> Result<[R; K], &'static str> {
pub(crate) fn hint_bit_unpack<const K: usize>(omega: i32, y_bytes: &[u8]) -> Result<[R; K], &'static str> {
let omega_u = usize::try_from(omega).expect("Alg 15: omega try_into fail");
debug_assert!((1..256).contains(&(omega_u + K)), "Alg 15: omega+K too large");
debug_assert_eq!(y_bytes.len(), omega_u + K, "Alg 15: bad output size");
Expand Down
35 changes: 8 additions & 27 deletions src/encodings.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// This file implements functionality from FIPS 204 section 8.2 Encodings of ML-DSA Keys and Signatures

use crate::conversion::{
bit_pack, bit_unpack, hint_bit_pack, hint_bit_unpack, simple_bit_pack, simple_bit_unpack,
};
use crate::conversion::{bit_pack, bit_unpack, hint_bit_pack, hint_bit_unpack, simple_bit_pack, simple_bit_unpack};
use crate::helpers::{bit_length, is_in_range};
use crate::types::{R, R0};
use crate::{D, Q};
Expand All @@ -14,9 +12,7 @@ use crate::{D, Q};
///
/// **Input**: `ρ ∈ {0,1}^256`, `t1 ∈ R^k` with coefficients in `[0, 2^{bitlen(q−1)−d}-1]`. <br>
/// **Output**: Public key `pk ∈ B^{32+32·k·(bitlen(q−1)−d)}`.
pub(crate) fn pk_encode<const K: usize, const PK_LEN: usize>(
rho: &[u8; 32], t1: &[R; K],
) -> [u8; PK_LEN] {
pub(crate) fn pk_encode<const K: usize, const PK_LEN: usize>(rho: &[u8; 32], t1: &[R; K]) -> [u8; PK_LEN] {
let blqd = bit_length(Q - 1) - D as usize;
debug_assert!(t1.iter().all(|t| is_in_range(t, 0, (1 << blqd) - 1)), "Alg 16: t1 out of range");
debug_assert_eq!(PK_LEN, 32 + 32 * K * blqd, "Alg 17: bad pk/config size");
Expand All @@ -29,11 +25,7 @@ pub(crate) fn pk_encode<const K: usize, const PK_LEN: usize>(
for i in 0..K {
//
// 3: pk ← pk || SimpleBitPack(t1[i], 2^{bitlen(q−1)−d}-1)
simple_bit_pack(
&t1[i],
(1 << blqd) - 1,
&mut pk[32 + 32 * i * blqd..32 + 32 * (i + 1) * blqd],
);
simple_bit_pack(&t1[i], (1 << blqd) - 1, &mut pk[32 + 32 * i * blqd..32 + 32 * (i + 1) * blqd]);

// 4: end for
}
Expand Down Expand Up @@ -70,8 +62,7 @@ pub(crate) fn pk_decode<const K: usize, const PK_LEN: usize>(
for i in 0..K {
//
// 4: t1[i] ← SimpleBitUnpack(zi, 2^{bitlen(q−1)−d} − 1)) ▷ This is always in the correct range
t1[i] =
simple_bit_unpack(&pk[32 + 32 * i * blqd..32 + 32 * (i + 1) * blqd], (1 << blqd) - 1)?;
t1[i] = simple_bit_unpack(&pk[32 + 32 * i * blqd..32 + 32 * (i + 1) * blqd], (1 << blqd) - 1)?;
//
// 5: end for
}
Expand Down Expand Up @@ -288,12 +279,7 @@ pub(crate) fn sig_encode<
/// # Errors
/// Returns an error when decoded coefficients fall out of range.
#[allow(clippy::type_complexity)]
pub(crate) fn sig_decode<
const K: usize,
const L: usize,
const LAMBDA_DIV4: usize,
const SIG_LEN: usize,
>(
pub(crate) fn sig_decode<const K: usize, const L: usize, const LAMBDA_DIV4: usize, const SIG_LEN: usize>(
gamma1: i32, omega: i32, sigma: &[u8; SIG_LEN],
) -> Result<([u8; LAMBDA_DIV4], [R; L], Option<[R; K]>), &'static str> {
debug_assert_eq!(
Expand Down Expand Up @@ -406,10 +392,7 @@ mod tests {

fn get_vec(max: u32) -> R {
let mut rnd_r = R0; //[0i32; 256];
rnd_r
.0
.iter_mut()
.for_each(|e| *e = rand::random::<i32>().rem_euclid(i32::try_from(max).unwrap()));
rnd_r.0.iter_mut().for_each(|e| *e = rand::random::<i32>().rem_euclid(i32::try_from(max).unwrap()));
rnd_r
}

Expand Down Expand Up @@ -451,10 +434,8 @@ mod tests {
rand::thread_rng().fill_bytes(&mut c_tilde);
let z = [get_vec(2), get_vec(2), get_vec(2), get_vec(2)];
let h = [get_vec(1), get_vec(1), get_vec(1), get_vec(1)];
let sigma =
sig_encode::<false, 4, 4, { 128 / 4 }, 2420>(1 << 17, 80, &c_tilde.clone(), &z, &h);
let (c_test, z_test, h_test) =
sig_decode::<4, 4, { 128 / 4 }, 2420>(1 << 17, 80, &sigma).unwrap();
let sigma = sig_encode::<false, 4, 4, { 128 / 4 }, 2420>(1 << 17, 80, &c_tilde.clone(), &z, &h);
let (c_test, z_test, h_test) = sig_decode::<4, 4, { 128 / 4 }, 2420>(1 << 17, 80, &sigma).unwrap();
assert_eq!(c_tilde[0..8], c_test[0..8]);
assert!(z.iter().zip(z_test.iter()).all(|(a, b)| a.0 == b.0));
assert!(h.iter().zip(h_test.unwrap().iter()).all(|(a, b)| a.0 == b.0));
Expand Down
15 changes: 5 additions & 10 deletions src/hashing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,18 +250,15 @@ pub(crate) fn rej_bounded_poly<const CTEST: bool>(eta: i32, rhos: &[&[u8]]) -> R
/// **Input**: `ρ ∈ {0,1}^256`. <br>
/// **Output**: Matrix `cap_a_hat`
#[allow(clippy::cast_possible_truncation)] // s and r as u8
pub(crate) fn expand_a<const CTEST: bool, const K: usize, const L: usize>(
rho: &[u8; 32],
) -> [[T; L]; K] {
pub(crate) fn expand_a<const CTEST: bool, const K: usize, const L: usize>(rho: &[u8; 32]) -> [[T; L]; K] {
// 1: for r from 0 to k − 1 do
// 2: for s from 0 to ℓ − 1 do
// 3: A_hat[r, s] ← RejNTTPoly(ρ||IntegerToBits(s, 8) || IntegerToBits(r, 8))
// 4: end for
// 5: end for

let cap_a_hat: [[T; L]; K] = core::array::from_fn(|r| {
core::array::from_fn(|s| rej_ntt_poly::<CTEST>(&[&rho[..], &[s as u8], &[r as u8]]))
});
let cap_a_hat: [[T; L]; K] =
core::array::from_fn(|r| core::array::from_fn(|s| rej_ntt_poly::<CTEST>(&[&rho[..], &[s as u8], &[r as u8]])));
cap_a_hat
}

Expand All @@ -285,14 +282,12 @@ pub(crate) fn expand_s<const CTEST: bool, const K: usize, const L: usize>(
// 1: for r from 0 to ℓ − 1 do
// 2: s1[r] ← RejBoundedPoly(ρ || IntegerToBits(r, 16))
// 3: end for
let s1: [R; L] =
core::array::from_fn(|r| rej_bounded_poly::<CTEST>(eta, &[rho, &[r as u8], &[0]]));
let s1: [R; L] = core::array::from_fn(|r| rej_bounded_poly::<CTEST>(eta, &[rho, &[r as u8], &[0]]));

// 4: for r from 0 to k − 1 do
// 5: s2[r] ← RejBoundedPoly(ρ || IntegerToBits(r + ℓ, 16))
// 6: end for
let s2: [R; K] =
core::array::from_fn(|r| rej_bounded_poly::<CTEST>(eta, &[rho, &[(r + L) as u8], &[0]]));
let s2: [R; K] = core::array::from_fn(|r| rej_bounded_poly::<CTEST>(eta, &[rho, &[(r + L) as u8], &[0]]));

// 7: return (s_1 , s_2)
debug_assert!(s1.iter().all(|r| is_in_range(r, eta, eta)), "Alg 27: s1 out of range");
Expand Down
8 changes: 2 additions & 6 deletions src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@ pub(crate) fn center_mod(m: i32) -> i32 {

/// Matrix by vector multiplication; e.g., fips 203 top of page 10, first row: `w_hat` = `A_hat` mul `u_hat`
#[must_use]
pub(crate) fn mat_vec_mul<const K: usize, const L: usize>(
a_hat: &[[T; L]; K], u_hat: &[T; L],
) -> [T; K] {
pub(crate) fn mat_vec_mul<const K: usize, const L: usize>(a_hat: &[[T; L]; K], u_hat: &[T; L]) -> [T; K] {
let mut w_hat = [T0; K];
let u_hat_mont = to_mont(u_hat);
for i in 0..K {
Expand All @@ -124,9 +122,7 @@ pub(crate) fn vec_add<const K: usize>(vec_a: &[R; K], vec_b: &[R; K]) -> [R; K]

#[allow(clippy::cast_possible_truncation)] // as i32
pub(crate) fn to_mont<const L: usize>(vec_a: &[T; L]) -> [T; L] {
core::array::from_fn(|l| {
T(core::array::from_fn(|n| partial_reduce64(i64::from(vec_a[l].0[n]) << 32)))
})
core::array::from_fn(|l| T(core::array::from_fn(|n| partial_reduce64(i64::from(vec_a[l].0[n]) << 32))))
}


Expand Down
7 changes: 2 additions & 5 deletions src/high_low.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@ pub(crate) fn power2round<const K: usize>(r: &[R; K]) -> ([R; K], [R; K]) {
r.iter().flat_map(|row| row.0).all(|element| (0..Q).contains(&element)),
"power2round input"
);
let r_1: [R; K] = core::array::from_fn(|k| {
R(core::array::from_fn(|n| (r[k].0[n] + (1 << (D - 1)) - 1) >> D))
});
let r_0: [R; K] =
core::array::from_fn(|k| R(core::array::from_fn(|n| r[k].0[n] - (r_1[k].0[n] << D))));
let r_1: [R; K] = core::array::from_fn(|k| R(core::array::from_fn(|n| (r[k].0[n] + (1 << (D - 1)) - 1) >> D)));
let r_0: [R; K] = core::array::from_fn(|k| R(core::array::from_fn(|n| r[k].0[n] - (r_1[k].0[n] << D))));

debug_assert!(
{
Expand Down
Loading

0 comments on commit a57b986

Please sign in to comment.