diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml
index 4a0d1261..80681682 100644
--- a/.github/workflows/fuzz.yml
+++ b/.github/workflows/fuzz.yml
@@ -10,7 +10,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        fuzz_target: [decode_rnd, encode_decode, parse_hrp]
+        fuzz_target: [berlekamp_massey, correct_bech32, correct_codex32, decode_rnd, encode_decode, parse_hrp]
     steps:
       - name: Install test dependencies
        run: sudo apt-get update -y && sudo apt-get install -y binutils-dev libunwind8-dev libcurl4-openssl-dev libelf-dev libdw-dev cmake gcc libiberty-dev
diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml
index b3d21a63..640742f1 100644
--- a/fuzz/Cargo.toml
+++ b/fuzz/Cargo.toml
@@ -17,6 +17,18 @@ bech32 = { path = ".." }
 [workspace]
 members = ["."]
 
+[[bin]]
+name = "berlekamp_massey"
+path = "fuzz_targets/berlekamp_massey.rs"
+
+[[bin]]
+name = "correct_bech32"
+path = "fuzz_targets/correct_bech32.rs"
+
+[[bin]]
+name = "correct_codex32"
+path = "fuzz_targets/correct_codex32.rs"
+
 [[bin]]
 name = "decode_rnd"
 path = "fuzz_targets/decode_rnd.rs"
diff --git a/fuzz/fuzz_targets/berlekamp_massey.rs b/fuzz/fuzz_targets/berlekamp_massey.rs
new file mode 100644
index 00000000..1caf8e84
--- /dev/null
+++ b/fuzz/fuzz_targets/berlekamp_massey.rs
@@ -0,0 +1,58 @@
use bech32::primitives::LfsrIter;
use bech32::Fe32;
use honggfuzz::fuzz;

fn do_test(data: &[u8]) {
    for ch in data {
        if *ch >= 32 {
            return;
        }
    }
    if data.is_empty() || data.len() > 1_000 {
        return;
    }

    let mut iv = Vec::with_capacity(data.len());
    for ch in data {
        iv.push(Fe32::try_from(*ch).unwrap());
    }

    for (i, d) in LfsrIter::berlekamp_massey(&iv).take(data.len()).enumerate() {
        assert_eq!(data[i], d.to_u8());
    }
}

fn main() {
    loop {
        fuzz!(|data| {
            do_test(data);
        });
    }
}

#[cfg(test)]
mod tests {
    fn extend_vec_from_hex(hex: &str, out: &mut Vec<u8>) {
        let mut b = 0;
        for (idx, c) in hex.as_bytes().iter().filter(|&&c| c != b'\n').enumerate() {
            b <<= 4;
            match *c {
                b'A'..=b'F' => b |= c - b'A' + 10,
                b'a'..=b'f' => b |= c - b'a' + 10,
                b'0'..=b'9' => b |= c - b'0',
                _ => panic!("Bad hex"),
            }
            if (idx & 1) == 1 {
                out.push(b);
                b = 0;
            }
        }
    }

    #[test]
    fn duplicate_crash() {
        let mut a = Vec::new();
        extend_vec_from_hex("00", &mut a);
        super::do_test(&a);
    }
}
diff --git a/fuzz/fuzz_targets/correct_bech32.rs b/fuzz/fuzz_targets/correct_bech32.rs
new file mode 100644
index 00000000..ee823727
--- /dev/null
+++ b/fuzz/fuzz_targets/correct_bech32.rs
@@ -0,0 +1,112 @@
use std::collections::HashMap;

use bech32::primitives::correction::CorrectableError as _;
use bech32::primitives::decode::CheckedHrpstring;
use bech32::{Bech32, Fe32};
use honggfuzz::fuzz;

// coinbase output of block 862290
static CORRECT: &[u8; 62] = b"bc1qwzrryqr3ja8w7hnja2spmkgfdcgvqwp5swz4af4ngsjecfz0w0pqud7k38";

fn do_test(data: &[u8]) {
    if data.is_empty() || data.len() % 2 == 1 {
        return;
    }

    let mut any_actual_errors = false;
    let mut e2t = 0;
    let mut erasures = Vec::with_capacity(CORRECT.len());
    // Start with a correct string
    let mut hrpstring = *CORRECT;
    // ..then mangle it
    let mut errors = HashMap::with_capacity(data.len() / 2);
    for sl in data.chunks_exact(2) {
        let idx = usize::from(sl[0]) & 0x7f;
        if idx >= CORRECT.len() - 3 {
            return;
        }
        let offs = match Fe32::try_from(sl[1]) {
            Ok(fe) => fe,
            Err(_) => return,
        };

        hrpstring[idx + 3] =
            (Fe32::from_char(hrpstring[idx + 3].into()).unwrap() + offs).to_char() as u8;
        if errors.insert(CORRECT.len() - (idx + 3) - 1, offs).is_some() {
            return;
        }
        if sl[0] & 0x80 == 0x80 {
            // We might push "dummy" errors which are erasures that aren't actually wrong.
            // If we do this too many times, we'll exceed the singleton bound so correction
            // will fail, but as long as we're within the bound everything should "work",
            // in the sense that there will be no crashes and the error corrector will
            // just yield an error with value Q.
            erasures.push(CORRECT.len() - (idx + 3) - 1);
            e2t += 1;
            if offs != Fe32::Q {
                any_actual_errors = true;
            }
        } else if offs != Fe32::Q {
            any_actual_errors = true;
            e2t += 2;
        }
    }
    // We need _some_ errors.
    if !any_actual_errors {
        return;
    }

    let s = unsafe { core::str::from_utf8_unchecked(&hrpstring) };
    let mut correct_ctx = CheckedHrpstring::new::<Bech32>(s)
        .unwrap_err()
        .correction_context::<Bech32>()
        .unwrap();

    correct_ctx.add_erasures(&erasures);

    let iter = correct_ctx.bch_errors();
    if e2t <= 3 {
        for (idx, fe) in iter.unwrap() {
            assert_eq!(errors.remove(&idx), Some(fe));
        }
        for val in errors.values() {
            assert_eq!(*val, Fe32::Q);
        }
    }
}

fn main() {
    loop {
        fuzz!(|data| {
            do_test(data);
        });
    }
}

#[cfg(test)]
mod tests {
    fn extend_vec_from_hex(hex: &str, out: &mut Vec<u8>) {
        let mut b = 0;
        for (idx, c) in hex.as_bytes().iter().filter(|&&c| c != b'\n').enumerate() {
            b <<= 4;
            match *c {
                b'A'..=b'F' => b |= c - b'A' + 10,
                b'a'..=b'f' => b |= c - b'a' + 10,
                b'0'..=b'9' => b |= c - b'0',
                _ => panic!("Bad hex"),
            }
            if (idx & 1) == 1 {
                out.push(b);
                b = 0;
            }
        }
    }

    #[test]
    fn duplicate_crash() {
        let mut a = Vec::new();
        extend_vec_from_hex("04010008", &mut a);
        super::do_test(&a);
    }
}
diff --git a/fuzz/fuzz_targets/correct_codex32.rs b/fuzz/fuzz_targets/correct_codex32.rs
new file mode 100644
index 00000000..f726a0f2
--- /dev/null
+++ b/fuzz/fuzz_targets/correct_codex32.rs
@@ -0,0 +1,137 @@
use std::collections::HashMap;

use bech32::primitives::correction::CorrectableError as _;
use bech32::primitives::decode::CheckedHrpstring;
use bech32::{Checksum, Fe1024, Fe32};
use honggfuzz::fuzz;

/// The codex32 checksum algorithm, defined in BIP-93.
///
/// Used in this fuzz test because it can correct up to 4 errors, vs bech32 which
/// can correct only 1. Should exhibit more interesting behavior.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Codex32 {}

impl Checksum for Codex32 {
    type MidstateRepr = u128;
    type CorrectionField = Fe1024;
    const ROOT_GENERATOR: Self::CorrectionField = Fe1024::new([Fe32::_9, Fe32::_9]);
    const ROOT_EXPONENTS: core::ops::RangeInclusive<usize> = 9..=16;

    const CHECKSUM_LENGTH: usize = 13;
    const CODE_LENGTH: usize = 93;
    // Copied from BIP-93
    const GENERATOR_SH: [u128; 5] = [
        0x19dc500ce73fde210,
        0x1bfae00def77fe529,
        0x1fbd920fffe7bee52,
        0x1739640bdeee3fdad,
        0x07729a039cfc75f5a,
    ];
    const TARGET_RESIDUE: u128 = 0x10ce0795c2fd1e62a;
}

static CORRECT: &[u8; 48] = b"ms10testsxxxxxxxxxxxxxxxxxxxxxxxxxx4nzvca9cmczlw";

fn do_test(data: &[u8]) {
    if data.is_empty() || data.len() % 2 == 1 {
        return;
    }

    let mut any_actual_errors = false;
    let mut e2t = 0;
    let mut erasures = Vec::with_capacity(CORRECT.len());
    // Start with a correct string
    let mut hrpstring = *CORRECT;
    // ..then mangle it
    let mut errors = HashMap::with_capacity(data.len() / 2);
    for sl in data.chunks_exact(2) {
        let idx = usize::from(sl[0]) & 0x7f;
        if idx >= CORRECT.len() - 3 {
            return;
        }
        let offs = match Fe32::try_from(sl[1]) {
            Ok(fe) => fe,
            Err(_) => return,
        };

        hrpstring[idx + 3] =
            (Fe32::from_char(hrpstring[idx + 3].into()).unwrap() + offs).to_char() as u8;

        if errors.insert(CORRECT.len() - (idx + 3) - 1, offs).is_some() {
            return;
        }
        if sl[0] & 0x80 == 0x80 {
            // We might push "dummy" errors which are erasures that aren't actually wrong.
            // If we do this too many times, we'll exceed the singleton bound so correction
            // will fail, but as long as we're within the bound everything should "work",
            // in the sense that there will be no crashes and the error corrector will
            // just yield an error with value Q.
            erasures.push(CORRECT.len() - (idx + 3) - 1);
            e2t += 1;
            if offs != Fe32::Q {
                any_actual_errors = true;
            }
        } else if offs != Fe32::Q {
            any_actual_errors = true;
            e2t += 2;
        }
    }
    // We need _some_ errors.
    if !any_actual_errors {
        return;
    }

    let s = unsafe { core::str::from_utf8_unchecked(&hrpstring) };
    let mut correct_ctx = CheckedHrpstring::new::<Codex32>(s)
        .unwrap_err()
        .correction_context::<Codex32>()
        .unwrap();

    correct_ctx.add_erasures(&erasures);

    let iter = correct_ctx.bch_errors();
    if e2t <= 8 {
        for (idx, fe) in iter.unwrap() {
            assert_eq!(errors.remove(&idx), Some(fe));
        }
        for val in errors.values() {
            assert_eq!(*val, Fe32::Q);
        }
    }
}

fn main() {
    loop {
        fuzz!(|data| {
            do_test(data);
        });
    }
}

#[cfg(test)]
mod tests {
    fn extend_vec_from_hex(hex: &str, out: &mut Vec<u8>) {
        let mut b = 0;
        for (idx, c) in hex.as_bytes().iter().filter(|&&c| c != b'\n').enumerate() {
            b <<= 4;
            match *c {
                b'A'..=b'F' => b |= c - b'A' + 10,
                b'a'..=b'f' => b |= c - b'a' + 10,
                b'0'..=b'9' => b |= c - b'0',
                _ => panic!("Bad hex"),
            }
            if (idx & 1) == 1 {
                out.push(b);
                b = 0;
            }
        }
    }

    #[test]
    fn duplicate_crash() {
        let mut a = Vec::new();
        extend_vec_from_hex("8c00a10091039e0185008000831f8e0f", &mut a);
        super::do_test(&a);
    }
}
diff --git a/src/lib.rs b/src/lib.rs
index 64969b4d..4def6ece 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -105,7 +105,7 @@
 //!     type MidstateRepr = u128;
 //!     type CorrectionField = bech32::primitives::gf32_ext::Fe32Ext<2>;
 //!     const ROOT_GENERATOR: Self::CorrectionField = Fe1024::new([Fe32::_9, Fe32::_9]);
-//!     const ROOT_EXPONENTS: core::ops::RangeInclusive<usize> = 77..=84;
+//!     const ROOT_EXPONENTS: core::ops::RangeInclusive<usize> = 9..=16;
 //!
 //!     const CHECKSUM_LENGTH: usize = 13;
 //!     const CODE_LENGTH: usize = 93;
diff --git a/src/primitives/correction.rs b/src/primitives/correction.rs
index 7918ac8a..e965f8ec 100644
--- a/src/primitives/correction.rs
+++ b/src/primitives/correction.rs
@@ -6,11 +6,16 @@
 //! equation to identify the error values, in a BCH-encoded string.
 //!
 
+use core::convert::TryInto;
+use core::marker::PhantomData;
+
 use crate::primitives::decode::{
     CheckedHrpstringError, ChecksumError, InvalidResidueError, SegwitHrpstringError,
 };
+use crate::primitives::{Field as _, FieldVec, LfsrIter, Polynomial};
 #[cfg(feature = "alloc")]
 use crate::DecodeError;
+use crate::{Checksum, Fe32};
 
 /// **One more than** the maximum length (in characters) of a checksum which
 /// can be error-corrected without an allocator.
@@ -57,6 +62,26 @@ pub trait CorrectableError {
     ///
     /// This is the function that implementors should implement.
     fn residue_error(&self) -> Option<&InvalidResidueError>;
+
+    /// Wrapper around [`Self::residue_error`] that outputs a correction context.
+    ///
+    /// Will return None if the error is not a correctable one, or if the **alloc**
+    /// feature is disabled and the checksum is too large. See the documentation
+    /// for [`NO_ALLOC_MAX_LENGTH`] for more information.
+    ///
+    /// This is the function that users should call.
+    fn correction_context<Ck: Checksum>(&self) -> Option<Corrector<Ck>> {
+        #[cfg(not(feature = "alloc"))]
+        if Ck::CHECKSUM_LENGTH >= NO_ALLOC_MAX_LENGTH {
+            return None;
+        }
+
+        self.residue_error().map(|e| Corrector {
+            erasures: FieldVec::new(),
+            residue: e.residue(),
+            phantom: PhantomData,
+        })
+    }
 }
 
 impl CorrectableError for InvalidResidueError {
@@ -104,3 +129,273 @@ impl CorrectableError for DecodeError {
         }
     }
 }
+
+/// An error-correction context.
+pub struct Corrector<Ck: Checksum> {
+    erasures: FieldVec<usize>,
+    residue: Polynomial<Ck::CorrectionField>,
+    phantom: PhantomData<Ck>,
+}
+
+impl<Ck: Checksum> Corrector<Ck> {
+    /// A bound on the number of errors and erasures (errors with known location)
+    /// that can be corrected by this corrector.
+    ///
+    /// Returns N such that, given E errors and X erasures, correction is possible
+    /// iff 2E + X <= N.
+    pub fn singleton_bound(&self) -> usize {
+        // d - 1, where d = [number of consecutive roots] + 2
+        Ck::ROOT_EXPONENTS.end() - Ck::ROOT_EXPONENTS.start() + 1
+    }
+
+    /// Adds the locations of known errors (erasures) to the correction context.
+    pub fn add_erasures(&mut self, locs: &[usize]) {
+        for loc in locs {
+            // If the user tries to add too many erasures, just ignore them. In
+            // this case error correction is guaranteed to fail anyway, because
+            // they will have exceeded the singleton bound. (Otherwise, the
+            // singleton bound, which is always <= the checksum length, must be
+            // greater than NO_ALLOC_MAX_LENGTH. So the checksum length must be
+            // greater than NO_ALLOC_MAX_LENGTH. Then correction will still fail.)
+            #[cfg(not(feature = "alloc"))]
+            if self.erasures.len() == NO_ALLOC_MAX_LENGTH {
+                break;
+            }
+            self.erasures.push(*loc);
+        }
+    }
+
+    /// Returns an iterator over the errors in the string.
+    ///
+    /// Returns `None` if it can be determined that there are too many errors to be
+    /// corrected. However, returning an iterator from this function does **not**
+    /// imply that the intended string can be determined. It only implies that there
+    /// is a unique closest correct string to the erroneous string, and gives
+    /// instructions for finding it.
+    ///
+    /// If the input string has sufficiently many errors, this unique closest correct
+    /// string may not actually be the intended string.
+    pub fn bch_errors(&self) -> Option<ErrorIterator<Ck>> {
+        // 1. Compute all syndromes by evaluating the residue at each power of the generator.
+        let syndromes: Polynomial<_> = Ck::ROOT_GENERATOR
+            .powers_range(Ck::ROOT_EXPONENTS)
+            .map(|rt| self.residue.evaluate(&rt))
+            .collect();
+
+        // 1a. Compute the "Forney syndrome polynomial" which is the product of the syndrome
+        //     polynomial and the erasure locator. This "erases the erasures" so that B-M
+        //     can find only the errors.
+        let mut erasure_locator = Polynomial::with_monic_leading_term(&[]); // 1
+        for loc in &self.erasures {
+            let factor: Polynomial<_> =
+                [Ck::CorrectionField::ONE, -Ck::ROOT_GENERATOR.powi(*loc as i64)]
+                    .iter()
+                    .cloned()
+                    .collect(); // alpha^-ix - 1
+            erasure_locator = erasure_locator.mul_mod_x_d(&factor, usize::MAX);
+        }
+        let forney_syndromes = erasure_locator.convolution(&syndromes);
+
+        // 2. Use the Berlekamp-Massey algorithm to find the connection polynomial of the
+        //    LFSR that generates these syndromes. For magical reasons this will be equal
+        //    to the error locator polynomial for the syndrome.
+        let lfsr = LfsrIter::berlekamp_massey(&forney_syndromes.as_inner()[..]);
+        let conn = lfsr.coefficient_polynomial();
+
+        // 3. The connection polynomial is the error locator polynomial. Use this to get
+        //    the errors.
+        if erasure_locator.degree() + 2 * conn.degree() <= self.singleton_bound() {
+            // 3a. Compute the "errata locator" which is the product of the error locator
+            //     and the erasure locator. Note that while we used the Forney syndromes
+            //     when calling the BM algorithm, in all other cases we use the ordinary
+            //     unmodified syndromes.
+            let errata_locator = conn.mul_mod_x_d(&erasure_locator, usize::MAX);
+            Some(ErrorIterator {
+                evaluator: errata_locator.mul_mod_x_d(&syndromes, self.singleton_bound()),
+                locator_derivative: errata_locator.formal_derivative(),
+                erasures: &self.erasures[..],
+                errors: conn.find_nonzero_distinct_roots(Ck::ROOT_GENERATOR),
+                a: Ck::ROOT_GENERATOR,
+                c: *Ck::ROOT_EXPONENTS.start(),
+            })
+        } else {
+            None
+        }
+    }
+}
+
+/// An iterator over the errors in a string.
+///
+/// The errors will be yielded as `(usize, Fe32)` tuples.
+///
+/// The first component is a **negative index** into the string. So 0 represents
+/// the last element, 1 the second-to-last, and so on.
+///
+/// The second component is an element to **add to** the element at the given
+/// location in the string.
+///
+/// The maximum index is one less than [`Checksum::CODE_LENGTH`], regardless of the
+/// actual length of the string. Therefore it is not safe to simply subtract the
+/// length of the string from the returned index; you must first check that the
+/// index makes sense. If the index exceeds the length of the string or implies that
+/// an error occurred in the HRP, the string should simply be rejected as uncorrectable.
+///
+/// Out-of-bound error locations will not occur "naturally", in the sense that they
+/// will happen with extremely low probability for a string with a valid HRP and a
+/// uniform error pattern. (The probability is 32^-n, where n is the size of the
+/// range [`Checksum::ROOT_EXPONENTS`], so it is not negligible but is very small for
+/// most checksums.) However, it is easy to construct adversarial inputs that will
+/// exhibit this behavior, so you must take it into account.
+///
+/// Out-of-bound error locations may occur naturally in the case of a string with a
+/// corrupted HRP, because for checksumming purposes the HRP is treated as twice as
+/// many field elements as characters, plus one. If the correct HRP is known, the
+/// caller should fix this before attempting error correction. If it is unknown,
+/// the caller cannot assume anything about the intended checksum, and should not
+/// attempt error correction.
+pub struct ErrorIterator<'c, Ck: Checksum> {
+    evaluator: Polynomial<Ck::CorrectionField>,
+    locator_derivative: Polynomial<Ck::CorrectionField>,
+    erasures: &'c [usize],
+    errors: super::polynomial::RootIter<Ck::CorrectionField>,
+    a: Ck::CorrectionField,
+    c: usize,
+}
+
+impl<'c, Ck: Checksum> Iterator for ErrorIterator<'c, Ck> {
+    type Item = (usize, Fe32);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // Compute -i, which is the location we will return to the user.
+        let neg_i = if self.erasures.is_empty() {
+            match self.errors.next() {
+                None => return None,
+                Some(0) => 0,
+                Some(x) => Ck::ROOT_GENERATOR.multiplicative_order() - x,
+            }
+        } else {
+            let pop = self.erasures[0];
+            self.erasures = &self.erasures[1..];
+            pop
+        };
+
+        // Forney's equation, as described in https://en.wikipedia.org/wiki/BCH_code#Forney_algorithm
+        //
+        // It is rendered as
+        //
+        //                        evaluator(a^-i)
+        // e_k = - ---------------------------------------------
+        //          (a^i)^(c - 1) locator_derivative(a^-i)
+        //
+        // where here a is `Ck::ROOT_GENERATOR`, c is the first element of the range
+        // `Ck::ROOT_EXPONENTS`, and both evaluator and locator_derivative are polynomials
+        // which are computed when constructing the ErrorIterator.
+
+        let a_i = self.a.powi(neg_i as i64);
+        let a_neg_i = a_i.clone().multiplicative_inverse();
+
+        let num = self.evaluator.evaluate(&a_neg_i);
+        let den = a_i.powi(self.c as i64 - 1) * self.locator_derivative.evaluate(&a_neg_i);
+        let ret = -num / den;
+        match ret.try_into() {
+            Ok(ret) => Some((neg_i, ret)),
+            Err(_) => unreachable!("error guaranteed to lie in base field"),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::primitives::decode::SegwitHrpstring;
+    use crate::Bech32;
+
+    #[test]
+    fn bech32() {
+        // Last x should be q
+        let s = "bc1qar0srrr7xfkvy5l643lydnw9re59gtzzwf5mdx";
+        match SegwitHrpstring::new(s) {
+            Ok(_) => panic!("{} successfully, and wrongly, parsed", s),
+            Err(e) => {
+                let mut ctx = e.correction_context::<Bech32>().unwrap();
+                let mut iter = ctx.bch_errors().unwrap();
+                assert_eq!(iter.next(), Some((0, Fe32::X)));
+                assert_eq!(iter.next(), None);
+
+                ctx.add_erasures(&[0]);
+                let mut iter = ctx.bch_errors().unwrap();
+                assert_eq!(iter.next(), Some((0, Fe32::X)));
+                assert_eq!(iter.next(), None);
+            }
+        }
+
+        // f should be z, 6 chars from the back.
+        let s = "bc1qar0srrr7xfkvy5l643lydnw9re59gtzfwf5mdq";
+        match SegwitHrpstring::new(s) {
+            Ok(_) => panic!("{} successfully, and wrongly, parsed", s),
+            Err(e) => {
+                let mut ctx = e.correction_context::<Bech32>().unwrap();
+                let mut iter = ctx.bch_errors().unwrap();
+                assert_eq!(iter.next(), Some((6, Fe32::T)));
+                assert_eq!(iter.next(), None);
+
+                ctx.add_erasures(&[6]);
+                let mut iter = ctx.bch_errors().unwrap();
+                assert_eq!(iter.next(), Some((6, Fe32::T)));
+                assert_eq!(iter.next(), None);
+            }
+        }
+
+        // 20 characters from the end there is a q which should be 3
+        let s = "bc1qar0srrr7xfkvy5l64qlydnw9re59gtzzwf5mdq";
+        match SegwitHrpstring::new(s) {
+            Ok(_) => panic!("{} successfully, and wrongly, parsed", s),
+            Err(e) => {
+                let ctx = e.correction_context::<Bech32>().unwrap();
+                let mut iter = ctx.bch_errors().unwrap();
+
+                assert_eq!(iter.next(), Some((20, Fe32::_3)));
+                assert_eq!(iter.next(), None);
+            }
+        }
+
+        // Two errors; cannot correct.
+        let s = "bc1qar0srrr7xfkvy5l64qlydnw9re59gtzzwf5mdx";
+        match SegwitHrpstring::new(s) {
+            Ok(_) => panic!("{} successfully, and wrongly, parsed", s),
+            Err(e) => {
+                let mut ctx = e.correction_context::<Bech32>().unwrap();
+                assert!(ctx.bch_errors().is_none());
+
+                // But we can correct it if we inform where an error is.
+                ctx.add_erasures(&[0]);
+                let mut iter = ctx.bch_errors().unwrap();
+                assert_eq!(iter.next(), Some((0, Fe32::X)));
+                assert_eq!(iter.next(), Some((20, Fe32::_3)));
+                assert_eq!(iter.next(), None);
+
+                ctx.add_erasures(&[20]);
+                let mut iter = ctx.bch_errors().unwrap();
+                assert_eq!(iter.next(), Some((0, Fe32::X)));
+                assert_eq!(iter.next(), Some((20, Fe32::_3)));
+                assert_eq!(iter.next(), None);
+            }
+        }
+
+        // In fact, if we know the locations, we can correct up to 3 errors.
+        let s = "bc1q9r0srrr7xfkvy5l64qlydnw9re59gtzzwf5mdx";
+        match SegwitHrpstring::new(s) {
+            Ok(_) => panic!("{} successfully, and wrongly, parsed", s),
+            Err(e) => {
+                let mut ctx = e.correction_context::<Bech32>().unwrap();
+                ctx.add_erasures(&[37, 0, 20]);
+                let mut iter = ctx.bch_errors().unwrap();
+
+                assert_eq!(iter.next(), Some((37, Fe32::C)));
+                assert_eq!(iter.next(), Some((0, Fe32::X)));
+                assert_eq!(iter.next(), Some((20, Fe32::_3)));
+                assert_eq!(iter.next(), None);
+            }
+        }
+    }
+}
diff --git a/src/primitives/decode.rs b/src/primitives/decode.rs
index 0cb78b4a..b4ca5e38 100644
--- a/src/primitives/decode.rs
+++ b/src/primitives/decode.rs
@@ -1021,6 +1021,17 @@ impl InvalidResidueError {
     pub fn matches_bech32_checksum(&self) -> bool {
         self.actual == Polynomial::from_residue(Bech32::TARGET_RESIDUE)
     }
+
+    /// Accessor for the invalid residue, less the target residue.
+    ///
+    /// Note that because the error type is not parameterized by a checksum (it
+    /// holds the target residue but this doesn't help), the caller will need
+    /// to obtain the checksum from somewhere else in order to make use of this.
+    ///
+    /// Not public because [`Polynomial`] is a private type, and because the
+    /// subtraction will panic if this is called without checking `has_data`
+    /// on the `FieldVec`s.
+    pub(super) fn residue(&self) -> Polynomial<Fe32> { self.actual.clone() - &self.target }
 }
 
 #[cfg(feature = "std")]
diff --git a/src/primitives/field.rs b/src/primitives/field.rs
index 9ac5ef7e..76be36cc 100644
--- a/src/primitives/field.rs
+++ b/src/primitives/field.rs
@@ -2,6 +2,7 @@
 
 //! Generic Field Traits
 
+use core::convert::TryInto;
 use core::iter::{Skip, Take};
 use core::{fmt, hash, iter, ops};
 
@@ -44,6 +45,11 @@ pub trait Field:
     /// A primitive element, i.e. a generator of the multiplicative group of the field.
     const GENERATOR: Self;
 
+    /// The smallest integer n such that 1 + ... + 1, n times, equals 0.
+    ///
+    /// If this is 0, this indicates that no such integer exists.
+    const CHARACTERISTIC: usize;
+
     /// The order of the multiplicative group of the field.
     const MULTIPLICATIVE_ORDER: usize;
 
@@ -56,6 +62,47 @@ pub trait Field:
     /// Computes the multiplicative inverse of an element.
     fn multiplicative_inverse(self) -> Self;
 
+    /// Takes the element times some integer.
+    fn muli(&self, mut n: i64) -> Self {
+        let base = if n >= 0 {
+            self.clone()
+        } else {
+            n *= -1;
+            self.clone().multiplicative_inverse()
+        };
+
+        let mut ret = Self::ZERO;
+        // Special case some particular characteristics
+        match Self::CHARACTERISTIC {
+            1 => unreachable!("no field has characteristic 1"),
+            2 => {
+                // Special-case 2 because it's easy and also the only characteristic used
+                // within the library. The compiler should prune away the other code.
+                if n % 2 == 0 {
+                    Self::ZERO
+                } else {
+                    self.clone()
+                }
+            }
+            x => {
+                // This is identical to powi below, but with * replaced by +.
+                if x > 0 {
+                    n %= x as i64;
+                }
+
+                let mut mask = x.next_power_of_two() as i64;
+                while mask > 0 {
+                    ret += ret.clone();
+                    if n & mask != 0 {
+                        ret += &base;
+                    }
+                    mask >>= 1;
+                }
+                ret
+            }
+        }
+    }
+
     /// Takes the element to the power of some integer.
     fn powi(&self, mut n: i64) -> Self {
         let base = if n >= 0 {
@@ -71,7 +118,7 @@ pub trait Field:
         while mask > 0 {
             ret *= ret.clone();
             if n & mask != 0 {
-                ret *= base.clone();
+                ret *= &base;
             }
             mask >>= 1;
         }
@@ -107,7 +154,7 @@ pub trait Field:
 
 /// Trait describing a simple extension field (field obtained from another by
 /// adjoining one element).
-pub trait ExtensionField: Field + From<Self::BaseField> {
+pub trait ExtensionField: Field + From<Self::BaseField> + TryInto<Self::BaseField> {
     /// The type of the base field.
     type BaseField: Field;
 
diff --git a/src/primitives/fieldvec.rs b/src/primitives/fieldvec.rs
index 566f6afa..73da10bf 100644
--- a/src/primitives/fieldvec.rs
+++ b/src/primitives/fieldvec.rs
@@ -107,6 +107,23 @@ impl<F> FieldVec<F> {
     #[inline]
     pub fn is_empty(&self) -> bool { self.len == 0 }
 
+    /// Reverses the contents of the vector in-place.
+    pub fn reverse(&mut self) {
+        self.assert_has_data();
+
+        #[cfg(not(feature = "alloc"))]
+        {
+            self.inner_a[..self.len].reverse();
+        }
+
+        #[cfg(feature = "alloc")]
+        if self.len > NO_ALLOC_MAX_LENGTH {
+            self.inner_v.reverse();
+        } else {
+            self.inner_a[..self.len].reverse();
+        }
+    }
+
     /// Returns an immutable iterator over the elements in the vector.
     ///
     /// # Panics
@@ -186,7 +203,48 @@
     }
 }
 
+impl<F: Default> Default for FieldVec<F> {
+    fn default() -> Self { Self::new() }
+}
+
 impl<F: Default> FieldVec<F> {
+    /// Constructs a new empty field vector.
+    pub fn new() -> Self {
+        FieldVec {
+            inner_a: Default::default(),
+            len: 0,
+            #[cfg(feature = "alloc")]
+            inner_v: Vec::new(),
+        }
+    }
+
+    /// Constructs a new field vector with the given capacity.
+    pub fn with_capacity(cap: usize) -> Self {
+        #[cfg(not(feature = "alloc"))]
+        {
+            let mut ret = Self::new();
+            ret.len = cap;
+            ret.assert_has_data();
+            ret.len = 0;
+            ret
+        }
+
+        #[cfg(feature = "alloc")]
+        if cap > NO_ALLOC_MAX_LENGTH {
+            let mut ret = Self::new();
+            ret.inner_v = Vec::with_capacity(cap);
+            ret
+        } else {
+            Self::new()
+        }
+    }
+
+    /// Pushes an item onto the end of the vector.
+    ///
+    /// Synonym for [`Self::push`] used to simplify code where a
+    /// [`FieldVec`] is used in place of a `VecDeque`.
+    pub fn push_back(&mut self, item: F) { self.push(item) }
+
     /// Pushes an item onto the end of the vector.
     ///
     /// # Panics
@@ -213,6 +271,38 @@
         }
     }
 
+    /// Pops an item off the front of the vector.
+    ///
+    /// This operation is always O(n).
+    pub fn pop_front(&mut self) -> Option<F> {
+        self.assert_has_data();
+        if self.len == 0 {
+            return None;
+        }
+
+        #[cfg(not(feature = "alloc"))]
+        {
+            // Not the most efficient algorithm, but it is safe code,
+            // easily seen to be correct, and is only used with very
+            // small vectors.
+            self.reverse();
+            let ret = self.pop();
+            self.reverse();
+            ret
+        }
+
+        #[cfg(feature = "alloc")]
+        if self.len > NO_ALLOC_MAX_LENGTH + 1 {
+            self.len -= 1;
+            Some(self.inner_v.remove(0))
+        } else {
+            self.reverse();
+            let ret = self.pop();
+            self.reverse();
+            ret
+        }
+    }
+
     /// Pops an item off the end of the vector.
     ///
     /// # Panics
diff --git a/src/primitives/gf32.rs b/src/primitives/gf32.rs
index 4076e87e..8b15b10f 100644
--- a/src/primitives/gf32.rs
+++ b/src/primitives/gf32.rs
@@ -300,6 +300,7 @@ impl Field for Fe32 {
     const ZERO: Self = Fe32::Q;
     const ONE: Self = Fe32::P;
     const GENERATOR: Self = Fe32::Z;
+    const CHARACTERISTIC: usize = 2;
     const MULTIPLICATIVE_ORDER: usize = 31;
     const MULTIPLICATIVE_ORDER_FACTORS: &'static [usize] = &[1, 31];
 
diff --git a/src/primitives/gf32_ext.rs b/src/primitives/gf32_ext.rs
index a074305a..f5e5bee5 100644
--- a/src/primitives/gf32_ext.rs
+++ b/src/primitives/gf32_ext.rs
@@ -37,6 +37,19 @@ impl<const N: usize> From<Fe32> for Fe32Ext<N> {
     }
 }
 
+impl<const N: usize> core::convert::TryFrom<Fe32Ext<N>> for Fe32 {
+    type Error = ();
+
+    fn try_from(ext: Fe32Ext<N>) -> Result<Self, Self::Error> {
+        for elem in &ext.inner[1..] {
+            if *elem != Fe32::Q {
+                return Err(());
+            }
+        }
+        Ok(ext.inner[0])
+    }
+}
+
 impl<const N: usize> fmt::Debug for Fe32Ext<N> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Display::fmt(self, f) }
 }
 
@@ -150,6 +163,8 @@ impl Field for Fe1024 {
     /// A generator of the field.
     const GENERATOR: Self = Self::new([Fe32::P, Fe32::H]);
 
+    const CHARACTERISTIC: usize = 2;
+
     /// The order of the multiplicative group of the field.
     ///
     /// This constant also serves as a compile-time check that we can count
@@ -235,6 +250,8 @@ impl Field for Fe32768 {
     /// The one element of the field.
     const ONE: Self = Self::new([Fe32::P, Fe32::Q, Fe32::Q]);
 
+    const CHARACTERISTIC: usize = 2;
+
     // Chosen somewhat arbitrarily, by just guessing values until one came
     // out with the correct order.
     /// A generator of the field.
diff --git a/src/primitives/lfsr.rs b/src/primitives/lfsr.rs
new file mode 100644
index 00000000..8bb8b977
--- /dev/null
+++ b/src/primitives/lfsr.rs
@@ -0,0 +1,252 @@
+// SPDX-License-Identifier: MIT
+
+//! Linear-Feedback Shift Registers
+//!
+//! A core part of our error-correction algorithm is the Berlekamp-Massey algorithm
+//! for finding shift registers. A shift register is a collection of values along
+//! with a rule (a particular linear combination) used to generate the next value.
+//! When the next value is generated, it is added to the end and everything shifted
+//! to the left, with the first value removed from the register and returned.
+//!
+//! For example, any linear recurrence relation, such as that for the Fibonacci
+//! numbers, can be described as a shift register (`a_n = a_{n-1} + a_{n-2}`).
+//!
+//! This module contains the general Berlekamp-Massey algorithm, from Massey's
+//! 1969 paper, implemented over a generic field.
+
+#[cfg(feature = "alloc")]
+use alloc::collections::VecDeque;
+
+use super::{Field, FieldVec, Polynomial};
+
+/// An iterator which returns the output of a linear-feedback-shift register
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub struct LfsrIter<F: Field> {
+    #[cfg(feature = "alloc")]
+    contents: VecDeque<F>,
+    #[cfg(not(feature = "alloc"))]
+    contents: FieldVec<F>,
+    /// The coefficients are internally represented as a polynomial so that
+    /// they can be returned as such for use in error correction.
+    ///
+    /// However, they really aren't a polynomial but rather a list of
+    /// coefficients of a linear transformation. Within the algorithm
+    /// they are always treated as a FieldVec, by calling `self.coeffs.as_inner`.
+    coeffs: Polynomial<F>,
+}
+
+impl<F: Field> LfsrIter<F> {
+    /// Accessor for the coefficients used to compute the next element.
+    pub fn coefficients(&self) -> &[F] { &self.coeffs.as_inner()[1..] }
+
+    /// Accessor for the coefficients used to compute the next element.
+    pub(super) fn coefficient_polynomial(&self) -> &Polynomial<F> { &self.coeffs }
+
+    /// Create a minimal LFSR iterator that generates a set of initial
+    /// contents, using Berlekamp's algorithm.
+    ///
+    /// # Panics
+    ///
+    /// Panics if given an empty list of initial contents.
+    pub fn berlekamp_massey(initial_contents: &[F]) -> LfsrIter<F> {
+        assert_ne!(initial_contents.len(), 0, "cannot create a LFSR with no initial contents");
+
+        // Step numbers taken from Massey 1969 "Shift-register synthesis and BCH decoding"
+        // PDF: https://crypto.stanford.edu/~mironov/cs359/massey.pdf
+        //
+        // The notation in that paper is super confusing. It uses polynomials in
+        // `D`, uses `x` as an integer (the difference between the length of the
+        // connection polynomial and the length of the previous connection
+        // polynomial), uses `n` as a constant bound and `N` as a counter up to `n`.
+        //
+        // It also manually accounts for various values which are implicitly
+        // always equal to the lengths of polynomials.
+
+        // Step 1 (init)
+        // `conn` and `last_conn` are `C(D)` and `B(D)` respectively, in BE order.
+        let mut conn = FieldVec::<F>::with_capacity(1 + initial_contents.len());
+        let mut old_conn = FieldVec::<F>::with_capacity(1 + initial_contents.len());
+        let mut old_d = F::ONE; // `b` in the paper
+        let mut x = 1;
+
+        conn.push(F::ONE);
+        old_conn.push(F::ONE);
+
+        // Step 2-6 (loop)
+        for n in 0..initial_contents.len() {
+            assert_eq!(conn[0], F::ONE, "we always maintain a monic polynomial");
+            // Step 2
+            // Compute d = s_n + sum C_i s_{n - i}, which is the difference between
+            // what our current LFSR computes and the actual next initial value.
+            // Since we always have C_0 = 1 we can compute this as sum C_i s_{n-i}
+            // for all i ranging from 0 to the length of C.
+            let d = conn
+                .iter()
+                .cloned()
+                .zip(initial_contents.iter().take(1 + n).rev())
+                .map(|(a, b)| a * b)
+                .sum::<F>();
+
+            if d == F::ZERO {
+                // Step 3: if d == 0, i.e. we correctly computed the next value,
+                // just increase our shift and iterate.
+                x += 1;
+            } else {
+                let db_inv = d.clone() / &old_d;
+                assert_eq!(db_inv.clone() * &old_d, d, "tried to compute {:?}/{:?}", d, old_d);
+                // If d != 0, we need to adjust our connection polynomial, which we do
+                // by subtracting a shifted multiplied version of our "old" connection
+                // polynomial.
+                //
+                // Here the "old" polynomial is the one we had before the last length
+                // change.
+                // The algorithm in the paper determines whether a length change
+                // is needed via the auxiliary variable L, which is initially set to 0
+                // and then set to L <- n + 1 - L each time we increase the length.
+                //
+                // By an annoying recursive argument it can be shown that L, thus set,
+                // is always equal to `conn.len()`. This assignment corresponds to a
+                // length increase exactly when `L < n + 1 - L` or `2L <= n`, so the
+                // algorithm determines when a length increase is needed by comparing
+                // 2L to n.
+                //
+                // This is all very clever but entirely pointless and doesn't even show
+                // up in the proof of the algorithm (which instead has the English text
+                // "if a change in length is needed"). Instead we can use x and a little
+                // bit of arithmetic to directly compute the change in length and decide
+                // whether it is > 0.
+                let poly_add_length = old_conn.len() + x;
+                if poly_add_length <= conn.len() {
+                    // Step 4
+                    for i in 0..old_conn.len() {
+                        conn[i + x] -= db_inv.clone() * &old_conn[i];
+                    }
+                    x += 1;
+                } else {
+                    // Step 5
+                    let tmp = conn.clone();
+                    for _ in conn.len()..poly_add_length {
+                        conn.push(F::ZERO);
+                    }
+                    for i in 0..old_conn.len() {
+                        conn[i + x] -= db_inv.clone() * &old_conn[i];
+                    }
+                    old_conn = tmp;
+                    old_d = d;
+                    x = 1;
+                }
+            }
+        }
+        // The connection polynomial has an initial monic term. For use as a LFSR we
+        // need to be a bit careful about this, since it is implicit in the formula
+        // for generating output from the shift register. So e.g. when generating
+        // our initial contents we use `conn.len() - 1` to get "the number of nontrivial
+        // coefficients", and in self.coefficients() we skip the monic term.
+        //
+        // In fact, if the purpose of this type were just to be a LFSR-based iterator,
+        // we could drop the monic term entirely. But since for error correction we
+        // instead want to extract the connection polynomial and treat it as an actual
+        // polynomial, we need to keep it.
+
+        // Copy conn.len() (less the monic term) initial elements into the LFSR.
+        let contents = initial_contents.iter().take(conn.len() - 1).cloned().collect();
+        LfsrIter { contents, coeffs: conn.into() }
+    }
+}
+
+impl<F: Field> Iterator for LfsrIter<F> {
+    type Item = F;
+    fn next(&mut self) -> Option<F> {
+        debug_assert_eq!(self.contents.len(), self.coefficients().len());
+
+        let next = self
+            .coefficients()
+            .iter()
+            .zip(self.contents.iter().rev())
+            .map(|(a, b)| a.clone() * b)
+            .sum();
+
+        let ret = self.contents.pop_front();
+        self.contents.push_back(next);
+        ret // will always be Some
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::Fe32;
+
+    #[test]
+    fn berlekamp_massey_constant() {
+        for elem in LfsrIter::berlekamp_massey(&[Fe32::ONE, Fe32::ONE]).take(10) {
+            assert_eq!(elem, Fe32::ONE);
+        }
+
+        for elem in LfsrIter::berlekamp_massey(&[Fe32::J, Fe32::J]).take(10) {
+            assert_eq!(elem, Fe32::J);
+        }
+
+        // If we just give B-M a *single* element, it'll use that as the connection
+        // polynomial and return a series of increasing powers of that element.
+        let mut expect = Fe32::J;
+        for elem in LfsrIter::berlekamp_massey(&[Fe32::J]).take(10) {
+            assert_eq!(elem, expect);
+            expect *= Fe32::J;
+        }
+    }
+
+    #[test]
+    fn berlekamp_massey_fibonacci() {
+        for elem in LfsrIter::berlekamp_massey(&[Fe32::P, Fe32::P]).take(10) {
+            assert_eq!(elem, Fe32::P);
+        }
+
+        // In a characteristic-2 field we can only really generate the parity of
+        // the Fibonacci sequence, but that in itself is kinda interesting.
+        let parities: Vec<_> =
+            LfsrIter::berlekamp_massey(&[Fe32::P, Fe32::P, Fe32::Q]).take(10).collect();
+        assert_eq!(
+            parities,
+            [
+                Fe32::P,
+                Fe32::P,
+                Fe32::Q,
+                Fe32::P,
+                Fe32::P,
+                Fe32::Q,
+                Fe32::P,
+                Fe32::P,
+                Fe32::Q,
+                Fe32::P
+            ],
+        );
+    }
+
+    #[test]
+    fn berlekamp_massey() {
+        // A few test vectors that I was able to trigger interesting coverage
+        // with using the fuzzer.
+
+        // Does a length change of more than 1 in a single iteration.
+        LfsrIter::berlekamp_massey(&[Fe32::Q, Fe32::P]).take(10).count();
+
+        // Causes old_conn.len + x to be less than conn.len (so naively subtracting
+        // these to check for a length increase will trigger an overflow). Hits
+        // the "2L <= N" path with x != L.
+        LfsrIter::berlekamp_massey(&[Fe32::Q, Fe32::Y, Fe32::H]).take(10).count();
+
+        // Hits the "2L <= N" path with x != L, without overflowing subtraction
+        // as in the above vector.
+        LfsrIter::berlekamp_massey(&[Fe32::Y, Fe32::H, Fe32::Q, Fe32::Q]).take(10).count();
+
+        // Triggers a length change with x != n + 1 - ell. The reason you might expect
+        // this is that ell is initially set to 0, then re-set to (n + 1 - ell) on each
+        // length change, i.e. it is a "count of how much n+1 increased since the last
+        // length change".
+        //
+        // Meanwhile, x is incremented on each iteration but reset to 1 on each length
+        // change. These assignment patterns sound very similar, but they are not the
+        // same, because the initial values and +1s are not the same.
+        LfsrIter::berlekamp_massey(&[Fe32::P, Fe32::P, Fe32::Y, Fe32::Q, Fe32::Q]).take(10).count();
+    }
+}
diff --git a/src/primitives/mod.rs b/src/primitives/mod.rs
index a4d803fa..06b16df4 100644
--- a/src/primitives/mod.rs
+++ b/src/primitives/mod.rs
@@ -12,6 +12,7 @@ pub mod gf32;
 pub mod gf32_ext;
 pub mod hrp;
 pub mod iter;
+mod lfsr;
 mod polynomial;
 pub mod segwit;
 
@@ -19,6 +20,7 @@ use checksum::{Checksum, PackedNull};
 use field::impl_ops_for_fe;
 pub use field::{Bech32Field, ExtensionField, Field};
 use fieldvec::FieldVec;
+pub use lfsr::LfsrIter;
 use polynomial::Polynomial;
 
 use crate::{Fe1024, Fe32};
@@ -60,7 +62,7 @@ impl Checksum for Bech32 {
     type CorrectionField = Fe1024;
     const ROOT_GENERATOR: Self::CorrectionField = Fe1024::new([Fe32::P, Fe32::X]);
-    const ROOT_EXPONENTS: core::ops::RangeInclusive<usize> = 997..=999;
+    const ROOT_EXPONENTS: core::ops::RangeInclusive<usize> = 24..=26;
 
     const CODE_LENGTH: usize = 1023;
     const CHECKSUM_LENGTH: usize = 6;
@@ -73,7 +75,7 @@ impl Checksum for Bech32m {
     type CorrectionField = Fe1024;
     const ROOT_GENERATOR: Self::CorrectionField = Fe1024::new([Fe32::P, Fe32::X]);
-    const ROOT_EXPONENTS: core::ops::RangeInclusive<usize> = 997..=999;
+    const ROOT_EXPONENTS: core::ops::RangeInclusive<usize> = 24..=26;
 
     const CODE_LENGTH: usize = 1023;
     const CHECKSUM_LENGTH: usize = 6;
diff --git a/src/primitives/polynomial.rs b/src/primitives/polynomial.rs
index 04860020..c202838f 100644
--- a/src/primitives/polynomial.rs
+++ b/src/primitives/polynomial.rs
@@ -2,7 +2,7 @@
 //! Polynomials over Finite Fields
 
-use core::{fmt, iter, ops, slice};
+use core::{cmp, fmt, iter, ops, slice};
 
 use super::checksum::PackedFe32;
 use super::{ExtensionField, Field, FieldVec};
@@ -17,16 +17,14 @@ pub struct Polynomial<F> {
 }
 
 impl<F: Field> PartialEq for Polynomial<F> {
-    fn eq(&self, other: &Self) -> bool {
-        self.inner[..self.degree()] == other.inner[..other.degree()]
-    }
+    fn eq(&self, other: &Self) -> bool { self.coefficients() == other.coefficients() }
 }
 
 impl<F: Field> Eq for Polynomial<F> {}
 
 impl Polynomial<Fe32> {
     pub fn from_residue<R: PackedFe32>(residue: R) -> Self {
-        (0..R::WIDTH).rev().map(|i| Fe32(residue.unpack(i))).collect()
+        (0..R::WIDTH).map(|i| Fe32(residue.unpack(i))).collect()
     }
 }
 impl<F: Field> Polynomial<F> {
@@ -58,9 +56,16 @@
         debug_assert_ne!(self.inner.len(), 0, "polynomials never have no terms");
         let degree_without_leading_zeros = self.inner.len() - 1;
         let leading_zeros = self.inner.iter().rev().take_while(|el| **el == F::ZERO).count();
-        degree_without_leading_zeros - leading_zeros
+        degree_without_leading_zeros.saturating_sub(leading_zeros)
     }
 
+    /// Accessor for the coefficients of the polynomial, in "little endian" order.
+    ///
+    /// # Panics
+    ///
+    /// Panics if [`Self::has_data`] is false.
+    pub fn coefficients(&self) -> &[F] { &self.inner[..self.degree() + 1] }
+
     /// An iterator over the coefficients of the polynomial.
     ///
     /// Yields values in "little endian" order; that is, the constant term is returned first.
     ///
     /// # Panics
     ///
     /// Panics if [`Self::has_data`] is false.
     pub fn iter(&self) -> slice::Iter<F> {
         self.assert_has_data();
-        self.inner.iter()
+        self.coefficients().iter()
     }
 
     /// The leading term of the polynomial.
@@ -89,6 +94,11 @@
     /// factor of the polynomial.
     pub fn zero_is_root(&self) -> bool { self.inner.is_empty() || self.leading_term() == F::ZERO }
 
+    /// Computes the formal derivative of the polynomial
+    pub fn formal_derivative(&self) -> Self {
+        self.iter().enumerate().map(|(n, fe)| fe.muli(n as i64)).skip(1).collect()
+    }
+
     /// Helper function to add leading 0 terms until the polynomial has a specified
     /// length.
     fn zero_pad_up_to(&mut self, len: usize) {
@@ -128,6 +138,56 @@
         }
     }
 
+    /// Evaluate the polynomial at a given element.
+    pub fn evaluate<E: Field + From<F>>(&self, elem: &E) -> E {
+        let mut res = E::ZERO;
+        for fe in self.iter().rev() {
+            res *= elem;
+            res += E::from(fe.clone());
+        }
+        res
+    }
+
+    /// Computes the convolution of `self` with the given syndrome polynomial,
+    /// keeping only those terms of the full product in which every coefficient
+    /// of `self` participates. Used to compute the Forney syndromes.
+    pub fn convolution(&self, syndromes: &Self) -> Self {
+        let mut ret = FieldVec::new();
+        let terms = (1 + syndromes.inner.len()).saturating_sub(1 + self.degree());
+        if terms == 0 {
+            ret.push(F::ZERO);
+            return Self::from(ret);
+        }
+
+        let n = 1 + self.degree();
+        for idx in 0..terms {
+            ret.push(
+                (0..n).map(|i| self.inner[n - i - 1].clone() * &syndromes.inner[idx + i]).sum(),
+            );
+        }
+        Self::from(ret)
+    }
+
+    /// Multiplies two polynomials modulo x^d, for some given `d`.
+    ///
+    /// Can be used to simply multiply two polynomials, by passing `usize::MAX` or
+    /// some other suitably large number as `d`.
+    pub fn mul_mod_x_d(&self, other: &Self, d: usize) -> Self {
+        if d == 0 {
+            return Self { inner: FieldVec::new() };
+        }
+
+        let sdeg = self.degree();
+        let odeg = other.degree();
+
+        let convolution_product = |exp: usize| {
+            let sidx = exp.saturating_sub(sdeg);
+            let eidx = cmp::min(exp, odeg);
+            (sidx..=eidx).map(|i| self.inner[exp - i].clone() * &other.inner[i]).sum()
+        };
+
+        let max_n = cmp::min(sdeg + odeg + 1, d - 1);
+        (0..=max_n).map(convolution_product).collect()
+    }
+
     /// Given a BCH generator polynomial, find an element alpha that maximizes the
     /// consecutive range i..j such that `alpha^i` through `alpha^j` are all roots
     /// of the polynomial.
@@ -456,4 +516,40 @@
             panic!("Unexpected generator {}", elem);
         }
     }
+
+    #[test]
+    fn mul_mod() {
+        let x_minus_1: Polynomial<_> = [Fe32::P, Fe32::P].iter().copied().collect();
+        assert_eq!(
+            x_minus_1.mul_mod_x_d(&x_minus_1, 3),
+            [Fe32::P, Fe32::Q, Fe32::P].iter().copied().collect(),
+        );
+        assert_eq!(x_minus_1.mul_mod_x_d(&x_minus_1, 2), [Fe32::P].iter().copied().collect(),);
+    }
+
+    #[test]
+    #[cfg(feature = "alloc")] // needed since `mul_mod_x_d` produces extra 0 coefficients
+    fn factor_then_mul() {
+        let bech32_poly: Polynomial<Fe32> = {
+            use Fe32 as F;
+            [F::J, F::A, F::_4, F::_5, F::K, F::A, F::P]
+        }
+        .iter()
+        .copied()
+        .collect();
+
+        let bech32_poly_lift = Polynomial { inner: bech32_poly.inner.lift() };
+
+        let factors = bech32_poly
+            .find_nonzero_distinct_roots(Fe1024::GENERATOR)
+            .map(|idx| Fe1024::GENERATOR.powi(idx as i64))
+            .map(|root| [root, Fe1024::ONE].iter().copied().collect::<Polynomial<_>>())
+            .collect::<Vec<_>>();
+
+        let product = factors.iter().fold(
+            Polynomial::with_monic_leading_term(&[]),
+            |acc: Polynomial<_>, factor: &Polynomial<_>| acc.mul_mod_x_d(factor, 100),
+        );
+
+        assert_eq!(bech32_poly_lift, product);
+    }
 }
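Taken together, the pieces in this patch compose into a decode-then-correct flow. The sketch below is illustrative only, not part of the patch: it is distilled from the `correct_bech32` fuzz target and the unit tests in `src/primitives/correction.rs`. The helper name `try_correct` and its index arithmetic are assumptions of this sketch, and a real caller must additionally reject error locations that fall inside the HRP, as the `ErrorIterator` documentation above warns.

use bech32::primitives::correction::CorrectableError as _;
use bech32::primitives::decode::SegwitHrpstring;
use bech32::{Bech32, Fe32};

// Hypothetical helper: attempt to repair a bech32 segwit address, assuming
// every reported error location lands inside the string.
fn try_correct(s: &str) -> Option<String> {
    let err = match SegwitHrpstring::new(s) {
        Ok(_) => return Some(s.to_owned()), // already valid
        Err(e) => e,
    };
    let ctx = err.correction_context::<Bech32>()?;
    let mut chars: Vec<char> = s.chars().collect();
    for (neg_i, offset) in ctx.bch_errors()? {
        // Error locations count backward from the end of the string (0 is the
        // last character); the yielded field element is *added* to that character.
        let idx = chars.len().checked_sub(1 + neg_i)?;
        chars[idx] = (Fe32::from_char(chars[idx]).ok()? + offset).to_char();
    }
    Some(chars.into_iter().collect())
}

fn main() {
    // The "last x should be q" vector from the unit tests above.
    let fixed = try_correct("bc1qar0srrr7xfkvy5l643lydnw9re59gtzzwf5mdx");
    assert_eq!(fixed.as_deref(), Some("bc1qar0srrr7xfkvy5l643lydnw9re59gtzzwf5mdq"));
}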
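The same flow drives the stronger codex32 checksum. The sketch below is again an illustration, not library code: it assumes the `Codex32` checksum definition from `fuzz/fuzz_targets/correct_codex32.rs` is in scope (the library itself does not export one) and mangles the final character of that target's known-good string.

use bech32::primitives::correction::CorrectableError as _;
use bech32::primitives::decode::CheckedHrpstring;

// Assumes the `Codex32` impl of `Checksum` from correct_codex32.rs above.
fn main() {
    // The fuzz target's CORRECT string with its final 'w' changed to 'q'.
    let bad = "ms10testsxxxxxxxxxxxxxxxxxxxxxxxxxx4nzvca9cmczlq";
    let err = CheckedHrpstring::new::<Codex32>(bad).unwrap_err();
    let ctx = err.correction_context::<Codex32>().unwrap();
    // A single error is well within codex32's correction capacity of four.
    for (neg_i, offset) in ctx.bch_errors().unwrap() {
        println!("add {:?} to the character {} places from the end", offset, neg_i);
    }
}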
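Finally, the `LfsrIter` export can be exercised on its own. This minimal sketch mirrors the `berlekamp_massey_fibonacci` unit test: since `Fe32::P` is one and `Fe32::Q` is zero, the seed below is the parity of the Fibonacci sequence 1, 1, 2, 3, 5, 8, ..., and Berlekamp-Massey recovers the shortest shift register that continues it.

use bech32::primitives::LfsrIter;
use bech32::Fe32;

fn main() {
    let seed = [Fe32::P, Fe32::P, Fe32::Q];
    // The synthesized LFSR first replays the seed, then extends it; for this
    // seed the output is periodic with period three.
    for (i, elem) in LfsrIter::berlekamp_massey(&seed).take(9).enumerate() {
        assert_eq!(elem, seed[i % 3]);
    }
}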