Skip to content

Commit

Permalink
added SerDes to expanded public/private keys
Browse files Browse the repository at this point in the history
  • Loading branch information
eschorn1 committed Oct 15, 2024
1 parent ec25a7a commit 377ae34
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 6 deletions.
75 changes: 71 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,18 @@ macro_rules! functionality {

// ----- SERIALIZATION AND DESERIALIZATION ---

impl SerDes for PrivateKey {
type ByteArray = [u8; SK_LEN];

fn try_from_bytes(sk: Self::ByteArray) -> Result<Self, &'static str> {
let _unused = sk_decode::<K, L, SK_LEN>(ETA, &sk).map_err(|_e| "Private key deserialization failed");
Ok(PrivateKey { 0: sk })
}

fn into_bytes(self) -> Self::ByteArray { self.0 }
}


impl SerDes for PublicKey {
type ByteArray = [u8; PK_LEN];

Expand All @@ -409,18 +421,67 @@ macro_rules! functionality {
}


impl SerDes for PrivateKey {
impl SerDes for ExpandedPrivateKey {
type ByteArray = [u8; SK_LEN];

fn try_from_bytes(sk: Self::ByteArray) -> Result<Self, &'static str> {
let _unused = sk_decode::<K, L, SK_LEN>(ETA, &sk).map_err(|_e| "Private key deserialization failed");
Ok(PrivateKey { 0: sk })
let sk = PrivateKey { 0: sk };
let esk = ml_dsa::sign_start::<CTEST, K, L, SK_LEN>(ETA, &sk.0)?;
Ok(esk)
}

fn into_bytes(self) -> Self::ByteArray { self.0 }
#[allow(clippy::cast_lossless)]
fn into_bytes(self) -> Self::ByteArray {
use crate::types::{R, T};
use crate::helpers::mont_reduce;
use crate::helpers::full_reduce32;
use crate::ntt::inv_ntt;

// TODO: polish needed
let ExpandedPrivateKey {rho, cap_k, tr, s_hat_1_mont, s_hat_2_mont, t_hat_0_mont, ..} = &self;

let s_1: [R; L] = inv_ntt(&core::array::from_fn(|l| T(core::array::from_fn(|n| full_reduce32(mont_reduce(s_hat_1_mont[l].0[n] as i64))))));
let s_1: [R; L] = core::array::from_fn(|l| R(core::array::from_fn(|n| if s_1[l].0[n] > (Q >> 2) {s_1[l].0[n] - Q} else {s_1[l].0[n]})));

let s_2: [R; K] = inv_ntt(&core::array::from_fn(|k| T(core::array::from_fn(|n| full_reduce32(mont_reduce(s_hat_2_mont[k].0[n] as i64))))));
let s_2: [R; K] = core::array::from_fn(|k| R(core::array::from_fn(|n| if s_2[k].0[n] > (Q >> 2) {s_2[k].0[n] - Q} else {s_2[k].0[n]})));

let t_0: [R; K] = inv_ntt(&core::array::from_fn(|k| T(core::array::from_fn(|n| full_reduce32(mont_reduce(t_hat_0_mont[k].0[n] as i64))))));
let t_0: [R; K] = core::array::from_fn(|k| R(core::array::from_fn(|n| if t_0[k].0[n] > (Q / 2) {t_0[k].0[n] - Q} else {t_0[k].0[n]})));

let sk = crate::encodings::sk_encode::<K, L, SK_LEN>(ETA, rho, cap_k, tr, &s_1, &s_2, &t_0);
sk
}
}


impl SerDes for ExpandedPublicKey {
type ByteArray = [u8; PK_LEN];

fn try_from_bytes(pk: Self::ByteArray) -> Result<Self, &'static str> {
let epk = ml_dsa::verify_start(&pk)?;
Ok(epk)

}

#[allow(clippy::cast_lossless)]
fn into_bytes(self) -> Self::ByteArray {
use crate::types::{R, T};
use crate::helpers::mont_reduce;
use crate::helpers::full_reduce32;
use crate::ntt::inv_ntt;
use crate::D;

let ExpandedPublicKey {rho, cap_a_hat, tr, t1_d2_hat_mont} = &self;
let (_, _, _, _) = (rho, cap_a_hat, tr, t1_d2_hat_mont);
let t1_d2: [R; K] = inv_ntt(&core::array::from_fn(|k| T(core::array::from_fn(|n| full_reduce32(mont_reduce(t1_d2_hat_mont[k].0[n] as i64))))));
let t1: [R; K] = core::array::from_fn(|k| R(core::array::from_fn(|n| t1_d2[k].0[n] >> D)));
let pk = crate::encodings::pk_encode(rho, &t1);
pk
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -443,7 +504,13 @@ macro_rules! functionality {
let v = pk.hash_verify(&message1, &sig, &[], &ph);
assert!(v);
}
assert_eq!(pk.into_bytes(), sk.get_public_key().into_bytes());
assert_eq!(pk.clone().into_bytes(), sk.get_public_key().into_bytes());

let esk = KG::gen_expanded_private(&sk);
assert_eq!(sk.into_bytes(), esk.unwrap().into_bytes());

let epk = KG::gen_expanded_public(&pk);
assert_eq!(pk.into_bytes(), epk.unwrap().into_bytes());
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/ml_dsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub(crate) fn sign_start<const CTEST: bool, const K: usize, const L: usize, cons
let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(rho);

Ok(ExpandedPrivateKey {
rho: *rho,
cap_k: *cap_k,
tr: *tr,
s_hat_1_mont,
Expand Down Expand Up @@ -121,6 +122,7 @@ pub(crate) fn sign_finish<

// Extract from sign_start()
let ExpandedPrivateKey {
rho: _,
cap_k,
tr,
s_hat_1_mont,
Expand Down Expand Up @@ -323,7 +325,7 @@ pub(crate) fn verify_start<const K: usize, const L: usize, const PK_LEN: usize>(
T(core::array::from_fn(|n| mont_reduce(i64::from(t1_hat_mont[k].0[n]) << D)))
}));

Ok(ExpandedPublicKey { cap_a_hat, tr, t1_d2_hat_mont })
Ok(ExpandedPublicKey { rho: *rho, cap_a_hat, tr, t1_d2_hat_mont })
}


Expand All @@ -341,7 +343,7 @@ pub(crate) fn verify_finish<
m: &[u8], sig: &[u8; SIG_LEN], ctx: &[u8], oid: &[u8], phm: &[u8], nist: bool,
) -> Result<bool, &'static str> {
//
let ExpandedPublicKey { cap_a_hat, tr, t1_d2_hat_mont } = epk;
let ExpandedPublicKey { rho: _, cap_a_hat, tr, t1_d2_hat_mont } = epk;

// 1: (ρ, t_1) ← pkDecode(pk)
// --> calculated in verify_start()
Expand Down
2 changes: 2 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub struct PrivateKey<const SK_LEN: usize>(pub(crate) [u8; SK_LEN]);
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
#[repr(align(8))]
pub struct ExpandedPrivateKey<const K: usize, const L: usize> {
pub(crate) rho: [u8; 32],
pub(crate) cap_k: [u8; 32],
pub(crate) tr: [u8; 64],
pub(crate) s_hat_1_mont: [T; L],
Expand All @@ -47,6 +48,7 @@ pub struct PublicKey<const PK_LEN: usize>(pub(crate) [u8; PK_LEN]);
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
#[repr(align(8))]
pub struct ExpandedPublicKey<const K: usize, const L: usize> {
pub(crate) rho: [u8; 32],
pub(crate) cap_a_hat: [[T; L]; K],
pub(crate) tr: [u8; 64],
pub(crate) t1_d2_hat_mont: [T; K],
Expand Down

0 comments on commit 377ae34

Please sign in to comment.