Skip to content

Commit

Permalink
p384: add basic tests for Scalar arithmetic (#493)
Browse files Browse the repository at this point in the history
Currently not passing and flagged with `#[ignore]`.

Also moves `montgomery_reduce` into the `Scalar` type as a static method
  • Loading branch information
tarcieri authored Dec 19, 2021
1 parent 09ae780 commit 6fcf10f
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 82 deletions.
2 changes: 1 addition & 1 deletion p384/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ sha2 = { version = "0.9", optional = true, default-features = false }
hex-literal = "0.3"

[features]
default = ["arithmetic", "pkcs8", "std"]
default = ["pkcs8", "std"]
arithmetic = ["elliptic-curve/arithmetic"]
jwk = ["elliptic-curve/jwk"]
pem = ["elliptic-curve/pem", "pkcs8"]
Expand Down
209 changes: 128 additions & 81 deletions p384/src/arithmetic/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ impl Scalar {
let (r10, carry) = r10.mac(limbs[5], limbs[5], carry);
let (r11, _) = r11.adc(Limb::ZERO, carry);

mont_reduce(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11)
Self::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11)
}

/// Compute scalar inversion.
Expand Down Expand Up @@ -1243,6 +1243,84 @@ impl Scalar {
let sqrt = t513 * t318;
CtOption::new(sqrt, sqrt.square().ct_eq(self))
}

/// Montgomery reduction.
#[cfg(target_pointer_width = "64")]
#[allow(clippy::too_many_arguments)]
#[inline(always)]
fn montgomery_reduce(
r0: Limb,
r1: Limb,
r2: Limb,
r3: Limb,
r4: Limb,
r5: Limb,
r6: Limb,
r7: Limb,
r8: Limb,
r9: Limb,
r10: Limb,
r11: Limb,
) -> Self {
// NOTE: generated by `ff_derive`
let modulus = MODULUS.limbs();

let k = r0.wrapping_mul(Limb(INV));
let (_, carry) = r0.mac(k, modulus[0], Limb::ZERO);
let (r1, carry) = r1.mac(k, modulus[1], carry);
let (r2, carry) = r2.mac(k, modulus[2], carry);
let (r3, carry) = r3.mac(k, modulus[3], carry);
let (r4, carry) = r4.mac(k, modulus[4], carry);
let (r5, carry) = r5.mac(k, modulus[5], carry);
let (r6, carry2) = r6.adc(Limb::ZERO, carry);

let k = r1.wrapping_mul(Limb(INV));
let (_, carry) = r1.mac(k, modulus[0], Limb::ZERO);
let (r2, carry) = r2.mac(k, modulus[1], carry);
let (r3, carry) = r3.mac(k, modulus[2], carry);
let (r4, carry) = r4.mac(k, modulus[3], carry);
let (r5, carry) = r5.mac(k, modulus[4], carry);
let (r6, carry) = r6.mac(k, modulus[5], carry);
let (r7, carry2) = r7.adc(carry2, carry);

let k = r2.wrapping_mul(Limb(INV));
let (_, carry) = r2.mac(k, modulus[0], Limb::ZERO);
let (r3, carry) = r3.mac(k, modulus[1], carry);
let (r4, carry) = r4.mac(k, modulus[2], carry);
let (r5, carry) = r5.mac(k, modulus[3], carry);
let (r6, carry) = r6.mac(k, modulus[4], carry);
let (r7, carry) = r7.mac(k, modulus[5], carry);
let (r8, carry2) = r8.adc(carry2, carry);

let k = r3.wrapping_mul(Limb(INV));
let (_, carry) = r3.mac(k, modulus[0], Limb::ZERO);
let (r4, carry) = r4.mac(k, modulus[1], carry);
let (r5, carry) = r5.mac(k, modulus[2], carry);
let (r6, carry) = r6.mac(k, modulus[3], carry);
let (r7, carry) = r7.mac(k, modulus[4], carry);
let (r8, carry) = r8.mac(k, modulus[5], carry);
let (r9, carry2) = r9.adc(carry2, carry);

let k = r4.wrapping_mul(Limb(INV));
let (_, carry) = r4.mac(k, modulus[0], Limb::ZERO);
let (r5, carry) = r5.mac(k, modulus[1], carry);
let (r6, carry) = r6.mac(k, modulus[2], carry);
let (r7, carry) = r7.mac(k, modulus[3], carry);
let (r8, carry) = r8.mac(k, modulus[4], carry);
let (r9, carry) = r9.mac(k, modulus[5], carry);
let (r10, carry2) = r10.adc(carry2, carry);

let k = r5.wrapping_mul(Limb(INV));
let (_, carry) = r5.mac(k, modulus[0], Limb::ZERO);
let (r6, carry) = r6.mac(k, modulus[1], carry);
let (r7, carry) = r7.mac(k, modulus[2], carry);
let (r8, carry) = r8.mac(k, modulus[3], carry);
let (r9, carry) = r9.mac(k, modulus[4], carry);
let (r10, carry) = r10.mac(k, modulus[5], carry);
let (r11, _) = r11.adc(carry2, carry);

Self::from_uint_reduced(U384::new([r6, r7, r8, r9, r10, r11]))
}
}

impl From<u64> for Scalar {
Expand Down Expand Up @@ -1450,7 +1528,7 @@ impl MulAssign<&Scalar> for Scalar {
let (r10, carry) = r10.mac(a[5], b[5], carry);
let r11 = carry;

mont_reduce(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11);
Self::montgomery_reduce(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11);
}
}

Expand All @@ -1470,89 +1548,11 @@ impl Reduce<U384> for Scalar {
}
}

/// Montgomery reduction.
#[cfg(target_pointer_width = "64")]
#[allow(clippy::too_many_arguments)]
#[inline(always)]
fn mont_reduce(
r0: Limb,
r1: Limb,
r2: Limb,
r3: Limb,
r4: Limb,
r5: Limb,
r6: Limb,
r7: Limb,
r8: Limb,
r9: Limb,
r10: Limb,
r11: Limb,
) -> Scalar {
// NOTE: generated by `ff_derive`
let modulus = MODULUS.limbs();

let k = r0.wrapping_mul(Limb(INV));
let (_, carry) = r0.mac(k, modulus[0], Limb::ZERO);
let (r1, carry) = r1.mac(k, modulus[1], carry);
let (r2, carry) = r2.mac(k, modulus[2], carry);
let (r3, carry) = r3.mac(k, modulus[3], carry);
let (r4, carry) = r4.mac(k, modulus[4], carry);
let (r5, carry) = r5.mac(k, modulus[5], carry);
let (r6, carry2) = r6.adc(Limb::ZERO, carry);

let k = r1.wrapping_mul(Limb(INV));
let (_, carry) = r1.mac(k, modulus[0], Limb::ZERO);
let (r2, carry) = r2.mac(k, modulus[1], carry);
let (r3, carry) = r3.mac(k, modulus[2], carry);
let (r4, carry) = r4.mac(k, modulus[3], carry);
let (r5, carry) = r5.mac(k, modulus[4], carry);
let (r6, carry) = r6.mac(k, modulus[5], carry);
let (r7, carry2) = r7.adc(carry2, carry);

let k = r2.wrapping_mul(Limb(INV));
let (_, carry) = r2.mac(k, modulus[0], Limb::ZERO);
let (r3, carry) = r3.mac(k, modulus[1], carry);
let (r4, carry) = r4.mac(k, modulus[2], carry);
let (r5, carry) = r5.mac(k, modulus[3], carry);
let (r6, carry) = r6.mac(k, modulus[4], carry);
let (r7, carry) = r7.mac(k, modulus[5], carry);
let (r8, carry2) = r8.adc(carry2, carry);

let k = r3.wrapping_mul(Limb(INV));
let (_, carry) = r3.mac(k, modulus[0], Limb::ZERO);
let (r4, carry) = r4.mac(k, modulus[1], carry);
let (r5, carry) = r5.mac(k, modulus[2], carry);
let (r6, carry) = r6.mac(k, modulus[3], carry);
let (r7, carry) = r7.mac(k, modulus[4], carry);
let (r8, carry) = r8.mac(k, modulus[5], carry);
let (r9, carry2) = r9.adc(carry2, carry);

let k = r4.wrapping_mul(Limb(INV));
let (_, carry) = r4.mac(k, modulus[0], Limb::ZERO);
let (r5, carry) = r5.mac(k, modulus[1], carry);
let (r6, carry) = r6.mac(k, modulus[2], carry);
let (r7, carry) = r7.mac(k, modulus[3], carry);
let (r8, carry) = r8.mac(k, modulus[4], carry);
let (r9, carry) = r9.mac(k, modulus[5], carry);
let (r10, carry2) = r10.adc(carry2, carry);

let k = r5.wrapping_mul(Limb(INV));
let (_, carry) = r5.mac(k, modulus[0], Limb::ZERO);
let (r6, carry) = r6.mac(k, modulus[1], carry);
let (r7, carry) = r7.mac(k, modulus[2], carry);
let (r8, carry) = r8.mac(k, modulus[3], carry);
let (r9, carry) = r9.mac(k, modulus[4], carry);
let (r10, carry) = r10.mac(k, modulus[5], carry);
let (r11, _) = r11.adc(carry2, carry);

Scalar::from_uint_reduced(U384::new([r6, r7, r8, r9, r10, r11]))
}

#[cfg(test)]
mod tests {
use super::Scalar;
use crate::FieldBytes;
use elliptic_curve::ff::PrimeField;
use elliptic_curve::ff::{Field, PrimeField};

#[test]
fn from_to_bytes_roundtrip() {
Expand All @@ -1563,4 +1563,51 @@ mod tests {
let scalar = Scalar::from_repr(bytes).unwrap();
assert_eq!(bytes, scalar.to_bytes());
}

/// Basic tests that multiplication works.
#[test]
#[ignore]
fn multiply() {
let one = Scalar::one();
let two = one + &one;
let three = two + &one;
let six = three + &three;
assert_eq!(six, two * &three);

let minus_two = -two;
let minus_three = -three;
assert_eq!(two, -minus_two);

assert_eq!(minus_three * &minus_two, minus_two * &minus_three);
assert_eq!(six, minus_two * &minus_three);
}

/// Basic tests that scalar inversion works.
#[test]
#[ignore]
fn invert() {
let one = Scalar::one();
let three = one + &one + &one;
let inv_three = three.invert().unwrap();
// println!("1/3 = {:x?}", &inv_three);
assert_eq!(three * &inv_three, one);

let minus_three = -three;
// println!("-3 = {:x?}", &minus_three);
let inv_minus_three = minus_three.invert().unwrap();
assert_eq!(inv_minus_three, -inv_three);
// println!("-1/3 = {:x?}", &inv_minus_three);
assert_eq!(three * &inv_minus_three, -one);
}

/// Basic tests that sqrt works.
#[test]
#[ignore]
fn sqrt() {
for &n in &[1u64, 4, 9, 16, 25, 36, 49, 64] {
let scalar = Scalar::from(n);
let sqrt = scalar.sqrt().unwrap();
assert_eq!(sqrt.square(), scalar);
}
}
}

0 comments on commit 6fcf10f

Please sign in to comment.