diff --git a/Primitive/Asymmetric/Signature/ML_DSA/ML_DSA.cry b/Primitive/Asymmetric/Signature/ML_DSA/ML_DSA.cry index 91dca073..0a13dcc7 100644 --- a/Primitive/Asymmetric/Signature/ML_DSA/ML_DSA.cry +++ b/Primitive/Asymmetric/Signature/ML_DSA/ML_DSA.cry @@ -36,7 +36,7 @@ module Primitive::Asymmetric::Signature::ML_DSA::ML_DSA where import interface Primitive::Asymmetric::Signature::ML_DSA::Parameters as P -import Primitive::Asymmetric::Signature::ML_DSA::Specification { interface P } +import Primitive::Asymmetric::Signature::ML_DSA::OptimizedSpecification { interface P } /** * Generate a public-private key pair. @@ -98,7 +98,7 @@ Verify pk M sigma ctx // We expose the internal functions for testing purposes only. submodule TestAPI where - import Primitive::Asymmetric::Signature::ML_DSA::Specification { interface P } as Spec + import Primitive::Asymmetric::Signature::ML_DSA::OptimizedSpecification { interface P } as Spec KeyGen_internal = Spec::KeyGen_internal Sign_internal = Spec::Sign_internal diff --git a/Primitive/Asymmetric/Signature/ML_DSA/OptimizedSpecification.cry b/Primitive/Asymmetric/Signature/ML_DSA/OptimizedSpecification.cry new file mode 100644 index 00000000..32f14fcf --- /dev/null +++ b/Primitive/Asymmetric/Signature/ML_DSA/OptimizedSpecification.cry @@ -0,0 +1,498 @@ +/** + * Optimized implementation of the ML-DSA (CRYSTALS-Dilithium) signature scheme. + * + * This implementation deviates from the specification in favor of performance. + * + * @copyright Galois Inc + * @author Marios Georgiou + */ +module Primitive::Asymmetric::Signature::ML_DSA::OptimizedSpecification where + +import interface Primitive::Asymmetric::Signature::ML_DSA::Parameters as P +import Primitive::Asymmetric::Signature::ML_DSA::Specification { interface P } as Spec + +type q = P::q +type ω = P::ω +type k = P::k +type ell = P::ell +type η = P::η +type λ = P::λ +type γ1 = P::γ1 +type γ2 = P::γ2 +type τ = P::τ + +type Byte = Spec::Byte + +type Tq = Spec::Tq +type R = Spec::R +type R2 = Spec::R2 +type Rq = Spec::Rq + +modPlusMinus : {α} (fin α) => Z q -> Integer +modPlusMinus = Spec::modPlusMinus`{α} + +infNormRq = Spec::infNormRq + +infNormR = Spec::infNormR + +castToRq = Spec::castToRq + +NTT_Vec = Spec::NTT_Vec + +NTTInv_Vec = Spec::NTTInv_Vec + +H = Spec::H + +HBits = Spec::HBits + +G = Spec::G + +ζ = Spec::ζ + +type d = Spec::d + +β = Spec::β + +type PublicKey = Spec::PublicKey + +type PrivateKey = Spec::PrivateKey + +type Signature = Spec::Signature + +/** + * This is almost identical to `Spec`; it's duplicated to pull in the optimized + * dependencies in this file but H`{64} is replaced by H to avoid compilation + * errors. + */ +KeyGen_internal : [32]Byte -> (PublicKey, PrivateKey) +KeyGen_internal ξ = (pk, sk) where + // Step 1. + (ρ # ρ' # K) = H (ξ # IntegerToBytes`{1} `k # IntegerToBytes`{1} `ell) + + // Step 3. + A_hat = ExpandA ρ + // Step 4. + (s1, s2) = ExpandS ρ' + + // Explicitly typecast vectors in `R` to `Rq`. + s1' = castToRq s1 + s2' = castToRq s2 + + // Step 5. + t = NTTInv_Vec (A_hat ∘∘ NTT_Vec s1') + s2' + // Step 6. + (t1, t0) = Power2Round t + + // Step 8. + pk = pkEncode ρ t1 + // Step 9. + tr = H pk + // Step 10. + sk = skEncode ρ K tr s1 s2 t0 + +/** + * ```repl + * :set tests=3 + * :check KeyGen_internalEquivalence + * ``` + */ +KeyGen_internalEquivalence : [32]Byte -> Bit +property KeyGen_internalEquivalence ξ = Spec::KeyGen_internal ξ == KeyGen_internal ξ + +Sign_internal = Spec::Sign_internal + +Verify_internal = Spec::Verify_internal + +IntegerToBits : {α} (fin α, α > 0) => Integer -> [α] +IntegerToBits x = reverse (fromInteger x) + +/** + * ```repl + * :check IntegerToBitsEquivalence`{44} + * ``` + */ +IntegerToBitsEquivalence : {α} (fin α, α > 0) => Integer -> Bit +property IntegerToBitsEquivalence x = Spec::IntegerToBits`{α} x == IntegerToBits`{α} x + +BitsToInteger y = toInteger (reverse y) + +/** + * ```repl + * :check BitsToIntegerEquivalence`{44} + * :exhaust BitsToIntegerEquivalence`{10} + * ``` + */ +BitsToIntegerEquivalence : {α} (fin α, α > 0) => [α] -> Bit +property BitsToIntegerEquivalence x = Spec::BitsToInteger x == BitsToInteger x + +IntegerToBytes : {α} (fin α, α > 0) => Integer -> [α]Byte +IntegerToBytes x = reverse (split (fromInteger x)) + +/** + * ```repl + * :check IntegerToBytesEquivalence`{44} + * ``` + */ +IntegerToBytesEquivalence : {α} (fin α, α > 0) => Integer -> Bit +property IntegerToBytesEquivalence x = Spec::IntegerToBytes`{α} x == IntegerToBytes`{α} x + +BitsToBytes : {α} (fin α) => [α]Bit -> [α /^ 8]Byte +BitsToBytes y = map reverse (split (y # zero)) + +/** + * ```repl + * :prove BitsToBytesEquivalence`{320 * 8} + * :prove BitsToBytesEquivalence`{32 * 44 * 8} + * ``` + */ +BitsToBytesEquivalence : {α} (fin α) => [α]Bit -> Bit +property BitsToBytesEquivalence x = Spec::BitsToBytes x == BitsToBytes x + +BytesToBits : {α} (fin α) => [α]Byte -> [8 * α]Bit +BytesToBits z = join (map reverse z) + +/** + * ```repl + * :prove BytesToBitsEquivalence`{320} + * :prove BytesToBitsEquivalence`{32 * 44} + * ``` + */ +BytesToBitsEquivalence : {α} (fin α) => [α]Byte -> Bit +property BytesToBitsEquivalence x = Spec::BytesToBits x == BytesToBits x + +B2B2BInverts = Spec::B2B2BInverts + +CoeffFromThreeBytes = Spec::CoeffFromThreeBytes + +CoeffFromHalfByte = Spec::CoeffFromHalfByte + +SimpleBitPack : {b} (fin b, width b > 0) => R -> [32 * width b]Byte +SimpleBitPack w = BitsToBytes (join (map IntegerToBits`{width b} w)) + +/** + * ```repl + * :check SimpleBitPackEquivalence`{10} + * ``` + */ +SimpleBitPackEquivalence : {b} (fin b, width b > 0) => R -> Bit +property SimpleBitPackEquivalence w = Spec::SimpleBitPack`{b} w == SimpleBitPack`{b} w + +BitPack : {a, b} (fin a, fin b, width (a + b) > 0) => + R -> [32 * width (a + b)]Byte +BitPack w = BitsToBytes (join (map (\x -> IntegerToBits`{width (a + b)} (`b - x)) w)) + +/** + * ```repl + * :check BitPackEquivalence`{10, 10} + * ``` + */ +BitPackEquivalence : {a, b} (fin a, fin b, width (a + b) > 0) => R -> Bit +property BitPackEquivalence w = Spec::BitPack`{a, b} w == BitPack`{a, b} w + +SimpleBitUnpack : {b} (fin b, width b > 0) => [32 * width b]Byte -> R +SimpleBitUnpack v = map BitsToInteger (split (BytesToBits v)) + +/** + * ```repl + * :check SimpleBitUnpackEquivalence`{10} + * ``` + */ +SimpleBitUnpackEquivalence : {b} (fin b, width b > 0) => [32 * width b]Byte -> Bit +property SimpleBitUnpackEquivalence v = Spec::SimpleBitUnpack`{b} v == SimpleBitUnpack`{b} v + +BitUnpack : {a, b} (fin a, fin b, width (a + b) > 0) => + [32 * width (a + b)]Byte -> R +BitUnpack v = map (\x -> `b - BitsToInteger x) (split (BytesToBits v)) + +/** + * ```repl + * :check BitUnpackEquivalence`{10, 10} + * ``` + */ +BitUnpackEquivalence : {a, b} (fin a, fin b, width (a + b) > 0) => [32 * width (a + b)]Byte -> Bit +property BitUnpackEquivalence v = Spec::BitUnpack`{a, b} v == BitUnpack`{a, b} v + +HintBitPack = Spec::HintBitPack + +HintBitUnpack = Spec::HintBitUnpack + +/** + * This is unchanged from `Spec`; it's duplicated to pull in the optimized + * dependencies in this file. + */ +pkEncode : [32]Byte -> [k]R -> PublicKey +pkEncode ρ t1 = pk where + pk = ρ # join [SimpleBitPack`{2 ^^ (width (q - 1) - d) - 1} (t1@i) | i <- [0..k-1]] + +/** + * ```repl + * :check pkEncodeEquivalence + * ``` + */ +pkEncodeEquivalence : [32]Byte -> [k]R -> Bit +property pkEncodeEquivalence ρ t1 = Spec::pkEncode ρ t1 == pkEncode ρ t1 + +/** + * This is unchanged from `Spec`; it's duplicated to pull in the optimized + * dependencies in this file. + */ +pkDecode : PublicKey -> ([32]Byte, [k]R) +pkDecode pk = (ρ, t1) where + // Step 1. We split off the single `ρ` byte, then separate the remaining + // bytes into the `k` components as described. + (ρ # zBytes) = pk + z = split zBytes + // Steps 2 - 4. + t1 = [SimpleBitUnpack`{2 ^^ (width (q - 1) - d) - 1} (z@i) | i <- [0..k-1]] + +/** + * ```repl + * :check pkDecodeEquivalence + * ``` + */ +pkDecodeEquivalence : PublicKey -> Bit +property pkDecodeEquivalence pk = Spec::pkDecode pk == pkDecode pk + +/** + * This is unchanged from `Spec`; it's duplicated to pull in the optimized + * dependencies in this file. + */ +skEncode : [32]Byte -> [32]Byte -> [64]Byte -> [ell]R -> [k]R -> [k]R + -> PrivateKey +skEncode ρ K tr s1 s2 t0 = sk9 where + // Note: `sk#` indicates the value of `sk` at Step `#`. + // Step 1. + sk1 = ρ # K # tr + // Steps 2 - 4. + sk3 = sk1 # join [BitPack`{η, η} (s1@i) | i <- [0..ell-1]] + // Steps 5 - 7. + sk6 = sk3 # join [BitPack`{η, η} (s2@i) | i <- [0..k-1]] + // Steps 8 - 10. + sk9 = sk6 # + join [BitPack`{2^^(d - 1) - 1, 2^^(d - 1)} (t0@i) | i <- [0..k-1]] + +/** + * ```repl + * :check skEncodeEquivalence + * ``` + */ +skEncodeEquivalence : [32]Byte -> [32]Byte -> [64]Byte -> [ell]R -> [k]R -> [k]R -> Bit +property skEncodeEquivalence ρ K tr s1 s2 t0 = Spec::skEncode ρ K tr s1 s2 t0 == skEncode ρ K tr s1 s2 t0 + +/** + * This is unchanged from `Spec`; it's duplicated to pull in the optimized + * dependencies in this file. + */ +skDecode : PrivateKey -> ([32]Byte, [32]Byte, [64]Byte, [ell]R, [k]R, [k]R) +skDecode sk = (ρ, K, tr, s1, s2, t0) where + // Step 1. We split off the six components, then further separate `y`, `z`, + // and `w` into their two dimensions. + (ρ # K # tr # yBytes # zBytes # wBytes) = sk + y = split`{ell} yBytes + z = split`{k} zBytes + w = split`{k} wBytes + + // Steps 2 - 4. + s1 = [BitUnpack`{η, η} (y@i) | i <- [0..ell-1]] + // Steps 5 - 7. + s2 = [BitUnpack`{η, η} (z@i) | i <- [0..k-1]] + // Steps 8 - 10. + t0 = [BitUnpack`{2^^(d - 1) - 1, 2^^(d - 1)} (w@i) | i <- [0..k-1]] + +/** + * ```repl + * :check skDecodeEquivalence + * ``` + */ +skDecodeEquivalence : PrivateKey -> Bit +property skDecodeEquivalence sk = Spec::skDecode sk == skDecode sk + +/** + * This is unchanged from `Spec`; it's duplicated to pull in the optimized + * dependencies in this file. + */ +sigEncode : [λ / 4]Byte -> [ell]R -> [k]R2 -> Signature +sigEncode c_til z h = σ where + // Note that `σ#` indicates the value of `σ` at Step `#`. + // Step 1. + σ1 = c_til + // Step 2 - 4. + σ3 = σ1 # join [BitPack`{γ1 - 1, γ1} (z@i) | i <- [0..ell-1]] + // Step 5. + σ = σ3 # HintBitPack h + +/** + * This is unchanged from `Spec`; it's duplicated to pull in the optimized + * dependencies in this file. + */ +sigDecode : Signature -> ([λ / 4]Byte, [ell]R, Option ([k]R2)) +sigDecode σ = (c_til, z, h) where + // Step 1. We separate into bytes, then further split `x` into its two + // dimensions. + (c_til # xBytes # y) = σ + x = split`{ell} xBytes + + // Step 2 - 4. + z = [BitUnpack`{γ1 - 1, γ1} (x@i) | i <- [0..ell-1]] + // Step 5. + h = HintBitUnpack y + +/** + * This is unchanged from `Spec`; it's duplicated to pull in the optimized + * dependencies in this file. + */ +w1Encode : [k]R -> [32 * k * width ((q - 1) / (2 * γ2) - 1)]Byte +w1Encode w1 = w1_til where + w1_til = join + [SimpleBitPack`{(q - 1) / (2 * γ2) - 1} (w1@i) | i <- [0..k-1]] + +/** + * ```repl + * :check w1EncodeEquivalence + * ``` + */ +w1EncodeEquivalence : [k]R -> Bit +property w1EncodeEquivalence w1 = Spec::w1Encode w1 == w1Encode w1 + +/** + * This is almost identical to `Spec`; it's duplicated to pull in the optimized + * dependencies in this file but H`{inf} is replaced by H to avoid compilation + * errors. + */ +SampleInBall : [λ / 4]Byte -> R +SampleInBall ρ = cFinal where + // Step 1. + c0 = zero + // Steps 2 - 3. + ctx_0 = H ρ + // Step 4. + ((s : [8]Byte) # ctx_1) = ctx_0 + // Step 5. + h = BytesToBits s + + // Steps 7 - 10. Uses recursion instead of a loop to sample bytes from the + // hash stream, returning the first one that's in the range `[0, i]`. + sample : [inf]Byte -> Byte -> (Byte, [inf]Byte) + sample ([j] # ctx) i = + if j > i then + sample ctx i + else (j, ctx) + + // Steps 6 - 13. Computes the value of `c` and the updated `ctx` at each + // iteration of the loop. + cAndCtx = [(c0, ctx_1)] # [(c'', ctx') where + // Steps 7 - 10. + (j, ctx') = sample ctx (fromInteger i) + // Step 11. + c' = update c i (c@j) + // Step 12. In Cryptol, we need to manually convert the exponent + // from a `Bit` to a numeric type. + hiτ = if (h @ (i + `τ - 256)) then 1 else 0 : Integer + c'' = update c' j ((-1)^^hiτ) + + | i <- [256 - τ..255] + | (c, ctx) <- cAndCtx] + + (cFinal, _) = cAndCtx ! 0 + +/** + * ```repl + * :set tests=3 + * :check SampleInBallEquivalence + * ``` + */ +SampleInBallEquivalence : [λ / 4]Byte -> Bit +property SampleInBallEquivalence ρ = Spec::SampleInBall ρ == SampleInBall ρ + +RejNTTPoly = Spec::RejNTTPoly + +RejBoundedPoly = Spec::RejBoundedPoly + +/** + * This is unchanged from `Spec`; it's duplicated to pull in the optimized + * dependencies in this file. + */ +ExpandA : [32]Byte -> [k][ell]Tq +ExpandA ρ = A_hat where + A_hat = [[RejNTTPoly ρ' where + ρ' = ρ # IntegerToBytes`{1} s # IntegerToBytes`{1} r + | s <- [0..ell - 1]] + | r <- [0..k - 1]] + +/** + * ```repl + * :set tests=3 + * :check ExpandAEquivalence + * ``` + */ +ExpandAEquivalence : [32]Byte -> Bit +property ExpandAEquivalence ρ = Spec::ExpandA ρ == ExpandA ρ + +/** + * This is unchanged from `Spec`; it's duplicated to pull in the optimized + * dependencies in this file. + */ +ExpandS : [64]Byte -> ([ell]R, [k]R) +ExpandS ρ = (s1, s2) where + s1 = [RejBoundedPoly (ρ # IntegerToBytes`{2} r) | r <- [0..ell-1]] + s2 = [RejBoundedPoly (ρ # IntegerToBytes`{2} (r + `ell)) | r <- [0..k-1]] + +/** + * ```repl + * :set tests=5 + * :check ExpandSEquivalence + * ``` + */ +ExpandSEquivalence : [64]Byte -> Bit +property ExpandSEquivalence ρ = Spec::ExpandS ρ == ExpandS ρ + +/** + * This is almost identical to `Spec`; it's duplicated to pull in the optimized + * dependencies in this file but H`{32 * c} is replaced by H to avoid compilation + * errors. Moreover, the type of `c` is not used so it's removed. + */ +ExpandMask : [64]Byte -> Integer -> [ell]R +ExpandMask ρ μ = y where + + y = [BitUnpack`{γ1 - 1, γ1} v where + ρ' = ρ # IntegerToBytes`{2} (μ + r) + v = H ρ' + | r <- [0..ell - 1]] + +/** + * ```repl + * :set tests=5 + * :check ExpandMaskEquivalence + * ``` + */ +ExpandMaskEquivalence : [64]Byte -> Integer -> Bit +property ExpandMaskEquivalence ρ μ = Spec::ExpandMask ρ μ == ExpandMask ρ μ + +Power2Round = Spec::Power2Round + +Decompose = Spec::Decompose + +HighBits = Spec::HighBits + +LowBits = Spec::LowBits + +MakeHint = Spec::MakeHint + +UseHint = Spec::UseHint + +NTT = Spec::NTT + +NTTInv = Spec::NTTInv + +AddNTT = Spec::AddNTT + +MultiplyNTT = Spec::MultiplyNTT + +AddVectorNTT = Spec::AddVectorNTT + +ScalarVectorNTT = Spec::ScalarVectorNTT + +MatrixVectorNTT = Spec::MatrixVectorNTT + +(∘∘) : [k][ell]Tq -> [ell]Tq -> [k]Tq +(∘∘) M v = [sum [Mij * vj | Mij <- Mi | vj <- v] | Mi <- M]