-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLeanSha.lean
142 lines (124 loc) · 4.73 KB
/
LeanSha.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
namespace Sha

/-- Bits per octet. Typed `UInt64` because it scales the 64-bit message-length trailer. -/
private def BitsInByte : UInt64 := 8

-- SHA-1 initial chaining values H0..H4 (FIPS 180-4 §5.3.1).
private def h₀α : UInt32 := 0x67452301
private def h₁α : UInt32 := 0xEFCDAB89
private def h₂α : UInt32 := 0x98BADCFE
private def h₃α : UInt32 := 0x10325476
private def h₄α : UInt32 := 0xC3D2E1F0

/-- Size of one SHA-1 message block, in bits. -/
private def BlockSize : UInt32 := 512

/-- Size of one SHA-1 message block, in octets (derived from `BlockSize`). -/
private def OctetsInBlock : UInt32 := BlockSize / BitsInByte.toUInt32

/-- Width of the big-endian bit-length trailer appended by padding, in octets. -/
private def OctetsInUInt64 : Nat := 8

/-- A message after padding (the result of `prepare`). -/
private structure Msg where
  data : ByteArray
  deriving Inhabited

/-- One slice of a padded message, produced by `blocks`. -/
private structure Block where
  data : ByteArray
  deriving Inhabited
/-- Apply SHA-1 message padding to `msg`:
1. append the octet `0x80` (a single `1` bit followed by seven `0` bits),
2. append zero octets,
3. append the bit length of the original `msg` as a 64-bit big-endian integer.

The total length is then a multiple of `OctetsInBlock`.

Asserts (via `assert!`) that the bit length of `msg` is strictly below 2^64, so it
fits in the 64-bit length trailer. -/
private def prepare (msg : ByteArray) : Msg := Id.run do
  -- Strict `<`: a message of exactly 2^61 octets has bit length 2^64. That value
  -- would wrap to 0 in the `UInt64` computed by `lengthBigEndian`.
  assert! msg.size < 2^64 / BitsInByte.toNat
  let mut msg' := msg
  msg' := msg'.push 0b10000000
  msg' := padToBlockSizeMinusOne msg'
  msg' := msg'.append lengthBigEndian
  ⟨msg'⟩
where
  -- Append the zero octets needed so that `arr`, once the 8-octet length
  -- trailer is added, becomes a multiple of `OctetsInBlock`. `arr` already
  -- contains the `0x80` marker, so its own size is the right starting point.
  padToBlockSizeMinusOne (arr : ByteArray) : ByteArray :=
    let blockLen := OctetsInBlock.toNat
    let zeros := (blockLen - (arr.size + OctetsInUInt64) % blockLen) % blockLen
    ByteArray.append arr ⟨⟨List.replicate zeros 0x00⟩⟩
  -- Bit length of the original `msg`, encoded as 8 octets, most significant first.
  lengthBigEndian : ByteArray :=
    let sz : UInt64 := UInt64.ofNat msg.size * BitsInByte
    ⟨#[
      (sz >>> (7 * BitsInByte)),
      (sz >>> (6 * BitsInByte)),
      (sz >>> (5 * BitsInByte)),
      (sz >>> (4 * BitsInByte)),
      (sz >>> (3 * BitsInByte)),
      (sz >>> (2 * BitsInByte)),
      (sz >>> (1 * BitsInByte)),
      (sz >>> (0 * BitsInByte))
    ].map (·.toUInt8)⟩
/-- Cut a padded message into consecutive `OctetsInBlock`-octet blocks, in order.
A trailing partial chunk is discarded; `prepare` never produces one. -/
private def blocks (msg : Msg) : Array Block := Id.run do
  let blockLen := OctetsInBlock.toNat
  let mut out : Array Block := #[]
  for idx in [0 : msg.data.size / blockLen] do
    out := out.push ⟨msg.data.extract (idx * blockLen) ((idx + 1) * blockLen)⟩
  return out
/-- Number of octets packed into one 32-bit word. -/
private def OctetsInUInt32 : Nat := 4

/-- Decode `block` into 32-bit words, reading four octets per word with the
first octet as the most significant (big-endian). Reads use `[·]!`. -/
private def UInt32sOfBlock (block : Block) : Array UInt32 := Id.run do
  let octets : Array UInt8 := block.data.data
  let mut words : Array UInt32 := #[]
  for base in [: block.data.size : OctetsInUInt32] do
    -- Shift each octet in from the right, so the first one ends up in bits 31..24.
    let word := [0, 1, 2, 3].foldl
      (fun (acc : UInt32) (j : Nat) => (acc <<< 8) ||| (octets[base + j]!).toUInt32) 0
    words := words.push word
  return words
/-- Encode a 32-bit word as four octets, most significant first (big-endian). -/
private def ByteArrayOfUInt32 (ui32 : UInt32) : ByteArray :=
  let shiftAmounts : Array UInt32 := #[24, 16, 8, 0]
  ByteArray.mk (shiftAmounts.map fun s => (ui32 >>> s).toUInt8)
/-- The five 32-bit chaining registers carried from one block to the next. -/
private structure ShaST where
  h₀ : UInt32
  h₁ : UInt32
  h₂ : UInt32
  h₃ : UInt32
  h₄ : UInt32

/-- Chaining state before any block is processed: the initial values `h₀α`..`h₄α`. -/
private def ShaST.mkInit : ShaST := ⟨h₀α, h₁α, h₂α, h₃α, h₄α⟩
/-- Index of the first message-schedule word derived from earlier words. -/
private def NumIntegerPadBegin : Nat := 16

/-- Index of the last message-schedule word (inclusive). -/
private def NumIntegerPadEnd : Nat := 79

/-- Rotate the 32-bit word `n` left by `k` bits (default 1). Asserts `k ≤ 32`. -/
private def rotl (n : UInt32) (k : Nat := 1) : UInt32 :=
  assert! k ≤ 32
  let amount := k.toUInt32
  -- NOTE: for k = 0 the right shift is by 32. This relies on Lean reducing
  -- UInt32 shift amounts modulo 32 (verify against the UInt32 docs).
  (n <<< amount) ||| (n >>> (32 - amount))

/-- Bitwise complement of a 32-bit word. -/
private def bneg (n : UInt32) : UInt32 := ~~~n
/-- Compute the SHA-1 digest of `msg`.

The message is padded with `prepare` and split with `blocks`. Each block is
expanded to an 80-word schedule and run through 80 compression rounds, and the
result is added into the chaining state. Returns the five final chaining words
concatenated big-endian (20 octets). -/
def sha1 (msg : ByteArray) : ByteArray := Id.run do
  let mut st : ShaST := ShaST.mkInit
  for block in blocks (prepare msg) do
    -- Message schedule: 16 words decoded from the block, extended to 80.
    let mut w : Array UInt32 := UInt32sOfBlock block
    for t in [NumIntegerPadBegin : NumIntegerPadEnd + 1] do
      w := w.push (rotl (w[t - 3]! ^^^ w[t - 8]! ^^^ w[t - 14]! ^^^ w[t - 16]!))
    let mut a := st.h₀
    let mut b := st.h₁
    let mut c := st.h₂
    let mut d := st.h₃
    let mut e := st.h₄
    for t in [0 : NumIntegerPadEnd + 1] do
      -- Round function: choose (0-19), parity (20-39, 60-79), majority (40-59).
      let f : UInt32 :=
        if t < 20 then (b &&& c) ||| (bneg b &&& d)
        else if t < 40 ∨ t ≥ 60 then b ^^^ c ^^^ d
        else (b &&& c) ||| (b &&& d) ||| (c &&& d)
      -- Round constant for each group of 20 rounds.
      let k : UInt32 :=
        if t < 20 then 0x5A827999
        else if t < 40 then 0x6ED9EBA1
        else if t < 60 then 0x8F1BBCDC
        else 0xCA62C1D6
      let temp := rotl a 5 + f + e + k + w[t]!
      e := d
      d := c
      c := rotl b 30
      b := a
      a := temp
    st := ⟨st.h₀ + a, st.h₁ + b, st.h₂ + c, st.h₃ + d, st.h₄ + e⟩
  return ByteArrayOfUInt32 st.h₀ ++ ByteArrayOfUInt32 st.h₁ ++
    ByteArrayOfUInt32 st.h₂ ++ ByteArrayOfUInt32 st.h₃ ++
    ByteArrayOfUInt32 st.h₄
end Sha