diff --git a/sha1/src/compress/aarch64.rs b/sha1/src/compress/aarch64.rs index 5952d1f6..ad836b6a 100644 --- a/sha1/src/compress/aarch64.rs +++ b/sha1/src/compress/aarch64.rs @@ -7,11 +7,182 @@ // > Enable SHA1 and SHA256 support. cpufeatures::new!(sha1_hwcap, "sha2"); +const K: [u32; 4] = [0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xCA62C1D6]; + +#[target_feature(enable = "sha2,neon")] +unsafe fn compress_sha1_neon(state: &mut [u32; 5], blocks: &[[u8; 64]]) { + use core::arch::aarch64::*; + + let mut abcd = vld1q_u32(state.as_ptr()); + let mut e0 = state[4]; + let mut e1; + for block in blocks { + let abcd_cpy = abcd; + let e0_cpy = e0; + + let block_ptr: *const u32 = block.as_ptr().cast(); + let mut msg0 = vld1q_u32(block_ptr); + let mut msg1 = vld1q_u32(block_ptr.add(4)); + let mut msg2 = vld1q_u32(block_ptr.add(8)); + let mut msg3 = vld1q_u32(block_ptr.add(12)); + + // Reverse byte order + msg0 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(msg0))); + msg1 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(msg1))); + msg2 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(msg2))); + msg3 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(msg3))); + + let mut tmp0 = vaddq_u32(msg0, vdupq_n_u32(K[0])); + let mut tmp1 = vaddq_u32(msg1, vdupq_n_u32(K[0])); + + // Rounds 0-3 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1cq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg2, vdupq_n_u32(K[0])); + msg0 = vsha1su0q_u32(msg0, msg1, msg2); + + // Rounds 4-7 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1cq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg3, vdupq_n_u32(K[0])); + msg0 = vsha1su1q_u32(msg0, msg3); + msg1 = vsha1su0q_u32(msg1, msg2, msg3); + + // Rounds 8-11 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1cq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg0, vdupq_n_u32(K[0])); + msg1 = vsha1su1q_u32(msg1, msg0); + msg2 = vsha1su0q_u32(msg2, msg3, msg0); + + // Rounds 12-15 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1cq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg1, vdupq_n_u32(K[1])); + msg2 = vsha1su1q_u32(msg2, msg1); + msg3 = vsha1su0q_u32(msg3, msg0, msg1); + + // Rounds 16-19 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1cq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg2, vdupq_n_u32(K[1])); + msg3 = vsha1su1q_u32(msg3, msg2); + msg0 = vsha1su0q_u32(msg0, msg1, msg2); + + // Rounds 20-23 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg3, vdupq_n_u32(K[1])); + msg0 = vsha1su1q_u32(msg0, msg3); + msg1 = vsha1su0q_u32(msg1, msg2, msg3); + + // Rounds 24-27 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg0, vdupq_n_u32(K[1])); + msg1 = vsha1su1q_u32(msg1, msg0); + msg2 = vsha1su0q_u32(msg2, msg3, msg0); + + // Rounds 28-31 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg1, vdupq_n_u32(K[1])); + msg2 = vsha1su1q_u32(msg2, msg1); + msg3 = vsha1su0q_u32(msg3, msg0, msg1); + + // Rounds 32-35 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg2, vdupq_n_u32(K[2])); + msg3 = vsha1su1q_u32(msg3, msg2); + msg0 = vsha1su0q_u32(msg0, msg1, msg2); + + // Rounds 36-39 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg3, vdupq_n_u32(K[2])); + msg0 = vsha1su1q_u32(msg0, msg3); + msg1 = vsha1su0q_u32(msg1, msg2, msg3); + + // Rounds 40-43 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1mq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg0, vdupq_n_u32(K[2])); + msg1 = vsha1su1q_u32(msg1, msg0); + msg2 = vsha1su0q_u32(msg2, msg3, msg0); + + // Rounds 44-47 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1mq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg1, vdupq_n_u32(K[2])); + msg2 = vsha1su1q_u32(msg2, msg1); + msg3 = vsha1su0q_u32(msg3, msg0, msg1); + + // Rounds 48-51 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1mq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg2, vdupq_n_u32(K[2])); + msg3 = vsha1su1q_u32(msg3, msg2); + msg0 = vsha1su0q_u32(msg0, msg1, msg2); + + // Rounds 52-55 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1mq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg3, vdupq_n_u32(K[3])); + msg0 = vsha1su1q_u32(msg0, msg3); + msg1 = vsha1su0q_u32(msg1, msg2, msg3); + + // Rounds 56-59 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1mq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg0, vdupq_n_u32(K[3])); + msg1 = vsha1su1q_u32(msg1, msg0); + msg2 = vsha1su0q_u32(msg2, msg3, msg0); + + // Rounds 60-63 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg1, vdupq_n_u32(K[3])); + msg2 = vsha1su1q_u32(msg2, msg1); + msg3 = vsha1su0q_u32(msg3, msg0, msg1); + + // Rounds 64-67 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg2, vdupq_n_u32(K[3])); + msg3 = vsha1su1q_u32(msg3, msg2); + msg0 = vsha1su0q_u32(msg0, msg1, msg2); + + // Rounds 68-71 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg3, vdupq_n_u32(K[3])); + msg0 = vsha1su1q_u32(msg0, msg3); + let _ = msg0; + + // Rounds 72-75 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e0, tmp0); + + // Rounds 76-79 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e1, tmp1); + + e0 += e0.wrapping_add(e0_cpy); + abcd = vaddq_u32(abcd_cpy, abcd); + } + + // Save state + vst1q_u32(state.as_mut_ptr(), abcd); + state[4] = e0; +} + pub fn compress(state: &mut [u32; 5], blocks: &[[u8; 64]]) { - // TODO: Replace with https://github.com/rust-lang/rfcs/pull/2725 - // after stabilization + // TODO: Replace with https://github.com/rust-lang/rfcs/pull/2725 after stabilization if sha1_hwcap::get() { - sha1_asm::compress(state, blocks); + unsafe { + compress_sha1_neon(state, blocks); + } } else { super::soft::compress(state, blocks); }