Skip to content

Commit

Permalink
Implement optimized versions of functions at the third level
Browse files Browse the repository at this point in the history
These definitions are copied and pasted directly from the Specification.cry file.

The edited functions are: `KeyGen_internal`, `pkEncode`, `pkDecode`, `skEncode`, `skDecode`, `sigEncode`, `sigDecode`, `w1Encode`, `SampleInBall`, `ExpandA`, `ExpandS` and `ExpandMask`
  • Loading branch information
mariosge committed Jan 29, 2025
1 parent 5685782 commit 5949fb4
Showing 1 changed file with 179 additions and 17 deletions.
196 changes: 179 additions & 17 deletions Primitive/Asymmetric/Signature/ML_DSA/OptimizedSpecification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type η = P::η
type λ = P::λ
type γ1 = P::γ1
type γ2 = P::γ2
type τ = P::τ

type Byte = Spec::Byte

Expand Down Expand Up @@ -58,7 +59,34 @@ type PrivateKey = Spec::PrivateKey

type Signature = Spec::Signature

KeyGen_internal = Spec::KeyGen_internal
KeyGen_internal : [32]Byte -> (PublicKey, PrivateKey)
KeyGen_internal ξ = (pk, sk) where
// Step 1.
(ρ # ρ' # K) = H (ξ # IntegerToBytes`{1} `k # IntegerToBytes`{1} `ell)

// Step 3.
A_hat = ExpandA ρ
// Step 4.
(s1, s2) = ExpandS ρ'

// Explicitly typecast vectors in `R` to `Rq`.
s1' = castToRq s1
s2' = castToRq s2

// Step 5.
t = NTTInv_Vec (A_hat ∘∘ NTT_Vec s1') + s2'
// Step 6.
(t1, t0) = Power2Round t

// Step 8.
pk = pkEncode ρ t1
// Step 9.
tr = H pk
// Step 10.
sk = skEncode ρ K tr s1 s2 t0

KeyGen_internalEquivalence : [32]Byte -> Bit
property KeyGen_internalEquivalence ξ = Spec::KeyGen_internal ξ == KeyGen_internal ξ

Sign_internal = Spec::Sign_internal

Expand Down Expand Up @@ -129,31 +157,162 @@ HintBitPack = Spec::HintBitPack

HintBitUnpack = Spec::HintBitUnpack

pkEncode = Spec::pkEncode

pkDecode = Spec::pkDecode

skEncode = Spec::skEncode
pkEncode : [32]Byte -> [k]R -> PublicKey
pkEncode ρ t1 = pk where
pk = ρ # join [SimpleBitPack`{2 ^^ (width (q - 1) - d) - 1} (t1@i) | i <- [0..k-1]]

pkEncodeEquivalence : [32]Byte -> [k]R -> Bit
property pkEncodeEquivalence ρ t1 = Spec::pkEncode ρ t1 == pkEncode ρ t1

pkDecode : PublicKey -> ([32]Byte, [k]R)
pkDecode pk = (ρ, t1) where
// Step 1. We split off the single `ρ` byte, then separate the remaining
// bytes into the `k` components as described.
(ρ # zBytes) = pk
z = split zBytes
// Steps 2 - 4.
t1 = [SimpleBitUnpack`{2 ^^ (width (q - 1) - d) - 1} (z@i) | i <- [0..k-1]]

pkDecodeEquivalence : PublicKey -> Bit
property pkDecodeEquivalence pk = Spec::pkDecode pk == pkDecode pk

skEncode : [32]Byte -> [32]Byte -> [64]Byte -> [ell]R -> [k]R -> [k]R
-> PrivateKey
skEncode ρ K tr s1 s2 t0 = sk9 where
// Note: `sk#` indicates the value of `sk` at Step `#`.
// Step 1.
sk1 = ρ # K # tr
// Steps 2 - 4.
sk3 = sk1 # join [BitPack`{η, η} (s1@i) | i <- [0..ell-1]]
// Steps 5 - 7.
sk6 = sk3 # join [BitPack`{η, η} (s2@i) | i <- [0..k-1]]
// Steps 8 - 10.
sk9 = sk6 #
join [BitPack`{2^^(d - 1) - 1, 2^^(d - 1)} (t0@i) | i <- [0..k-1]]

skEncodeEquivalence : [32]Byte -> [32]Byte -> [64]Byte -> [ell]R -> [k]R -> [k]R -> Bit
property skEncodeEquivalence ρ K tr s1 s2 t0 = Spec::skEncode ρ K tr s1 s2 t0 == skEncode ρ K tr s1 s2 t0

skDecode : PrivateKey -> ([32]Byte, [32]Byte, [64]Byte, [ell]R, [k]R, [k]R)
skDecode sk = (ρ, K, tr, s1, s2, t0) where
// Step 1. We split off the six components, then further separate `y`, `z`,
// and `w` into their two dimensions.
(ρ # K # tr # yBytes # zBytes # wBytes) = sk
y = split`{ell} yBytes
z = split`{k} zBytes
w = split`{k} wBytes

// Steps 2 - 4.
s1 = [BitUnpack`{η, η} (y@i) | i <- [0..ell-1]]
// Steps 5 - 7.
s2 = [BitUnpack`{η, η} (z@i) | i <- [0..k-1]]
// Steps 8 - 10.
t0 = [BitUnpack`{2^^(d - 1) - 1, 2^^(d - 1)} (w@i) | i <- [0..k-1]]

skDecodeEquivalence : PrivateKey -> Bit
property skDecodeEquivalence sk = Spec::skDecode sk == skDecode sk

sigEncode : [λ / 4]Byte -> [ell]R -> [k]R2 -> Signature
sigEncode c_til z h = σ where
// Note that `σ#` indicates the value of `σ` at Step `#`.
// Step 1.
σ1 = c_til
// Step 2 - 4.
σ3 = σ1 # join [Spec::BitPack`{γ1 - 1, γ1} (z@i) | i <- [0..ell-1]]
// Step 5.
σ = σ3 # HintBitPack h

sigDecode : Signature -> ([λ / 4]Byte, [ell]R, Option ([k]R2))
sigDecode σ = (c_til, z, h) where
// Step 1. We separate into bytes, then further split `x` into its two
// dimensions.
(c_til # xBytes # y) = σ
x = split`{ell} xBytes

// Step 2 - 4.
z = [BitUnpack`{γ1 - 1, γ1} (x@i) | i <- [0..ell-1]]
// Step 5.
h = HintBitUnpack y

w1Encode : [k]R -> [32 * k * width ((q - 1) / (2 * γ2) - 1)]Byte
w1Encode w1 = w1_til where
w1_til = join
[SimpleBitPack`{(q - 1) / (2 * γ2) - 1} (w1@i) | i <- [0..k-1]]

w1EncodeEquivalence : [k]R -> Bit
property w1EncodeEquivalence w1 = Spec::w1Encode w1 == w1Encode w1

SampleInBall : [λ / 4]Byte -> R
SampleInBall ρ = cFinal where
// Step 1.
c0 = zero
// Steps 2 - 3.
ctx_0 = H ρ
// Step 4.
((s : [8]Byte) # ctx_1) = ctx_0
// Step 5.
h = BytesToBits s

// Steps 7 - 10. Uses recursion instead of a loop to sample bytes from the
// hash stream, returning the first one that's in the range `[0, i]`.
sample : [inf]Byte -> Byte -> (Byte, [inf]Byte)
sample ([j] # ctx) i =
if j > i then
sample ctx i
else (j, ctx)

// Steps 6 - 13. Computes the value of `c` and the updated `ctx` at each
// iteration of the loop.
cAndCtx = [(c0, ctx_1)] # [(c'', ctx') where
// Steps 7 - 10.
(j, ctx') = sample ctx (fromInteger i)
// Step 11.
c' = update c i (c@j)
// Step 12. In Cryptol, we need to manually convert the exponent
// from a `Bit` to a numeric type.
hiτ = if (h @ (i + `τ - 256)) then 1 else 0 : Integer
c'' = update c' j ((-1)^^hiτ)

| i <- [256 - τ..255]
| (c, ctx) <- cAndCtx]

(cFinal, _) = cAndCtx ! 0

SampleInBallEquivalence : [λ / 4]Byte -> Bit
property SampleInBallEquivalence ρ = Spec::SampleInBall ρ == SampleInBall ρ

skDecode = Spec::skDecode

sigEncode = Spec::sigEncode
RejNTTPoly = Spec::RejNTTPoly

sigDecode = Spec::sigDecode
RejBoundedPoly = Spec::RejBoundedPoly

w1Encode = Spec::w1Encode
ExpandA : [32]Byte -> [k][ell]Tq
ExpandA ρ = A_hat where
A_hat = [[RejNTTPoly ρ' where
ρ' = ρ # IntegerToBytes`{1} s # IntegerToBytes`{1} r
| s <- [0..ell - 1]]
| r <- [0..k - 1]]

SampleInBall = Spec::SampleInBall
ExpandAEquivalence : [32]Byte -> Bit
property ExpandAEquivalence ρ = Spec::ExpandA ρ == ExpandA ρ

RejNTTPoly = Spec::RejNTTPoly
ExpandS : [64]Byte -> ([ell]R, [k]R)
ExpandS ρ = (s1, s2) where
s1 = [RejBoundedPoly (ρ # IntegerToBytes`{2} r) | r <- [0..ell-1]]
s2 = [RejBoundedPoly (ρ # IntegerToBytes`{2} (r + `ell)) | r <- [0..k-1]]

RejBoundedPoly = Spec::RejBoundedPoly
ExpandSEquivalence : [64]Byte -> Bit
property ExpandSEquivalence ρ = Spec::ExpandS ρ == ExpandS ρ

ExpandA = Spec::ExpandA
ExpandMask : [64]Byte -> Integer -> [ell]R
ExpandMask ρ μ = y where

ExpandS = Spec::ExpandS
y = [BitUnpack`{γ1 - 1, γ1} v where
ρ' = ρ # IntegerToBytes`{2} (μ + r)
v = H ρ'
| r <- [0..ell - 1]]

ExpandMask = Spec::ExpandMask
ExpandMaskEquivalence : [64]Byte -> Integer -> Bit
property ExpandMaskEquivalence ρ μ = Spec::ExpandMask ρ μ == ExpandMask ρ μ

Power2Round = Spec::Power2Round

Expand All @@ -180,3 +339,6 @@ AddVectorNTT = Spec::AddVectorNTT
ScalarVectorNTT = Spec::ScalarVectorNTT

MatrixVectorNTT = Spec::MatrixVectorNTT

(∘∘) : [k][ell]Tq -> [ell]Tq -> [k]Tq
(∘∘) M v = [sum [Mij * vj | Mij <- Mi | vj <- v] | Mi <- M]

0 comments on commit 5949fb4

Please sign in to comment.