diff --git a/baby-bear/Cargo.toml b/baby-bear/Cargo.toml index b9b665fa..658b4889 100644 --- a/baby-bear/Cargo.toml +++ b/baby-bear/Cargo.toml @@ -13,6 +13,7 @@ serde = { version = "1.0", default-features = false, features = ["derive"] } p3-field-testing = { path = "../field-testing" } criterion = "0.5.1" rand_chacha = "0.3.1" +serde_json = "1.0.113" [[bench]] name = "inverse" diff --git a/baby-bear/src/baby_bear.rs b/baby-bear/src/baby_bear.rs index 9b9e71ed..0e73288f 100644 --- a/baby-bear/src/baby_bear.rs +++ b/baby-bear/src/baby_bear.rs @@ -8,7 +8,7 @@ use p3_field::{ }; use rand::distributions::{Distribution, Standard}; use rand::Rng; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; /// The Baby Bear prime const P: u32 = 0x78000001; @@ -34,7 +34,7 @@ const MONTY_MU: u32 = if cfg!(all(target_arch = "aarch64", target_feature = "neo const MONTY_MASK: u32 = ((1u64 << MONTY_BITS) - 1) as u32; /// The prime field `2^31 - 2^27 + 1`, a.k.a. the Baby Bear field. -#[derive(Copy, Clone, Default, Eq, Hash, PartialEq, Serialize, Deserialize)] +#[derive(Copy, Clone, Default, Eq, Hash, PartialEq)] #[repr(transparent)] // `PackedBabyBearNeon` relies on this! pub struct BabyBear { // This is `pub(crate)` just for tests. If you're accessing `value` outside of those, you're @@ -89,6 +89,19 @@ impl Distribution for Standard { } } +impl Serialize for BabyBear { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_u32(self.as_canonical_u32()) + } +} + +impl<'de> Deserialize<'de> for BabyBear { + fn deserialize>(d: D) -> Result { + let val = u32::deserialize(d)?; + Ok(BabyBear::from_canonical_u32(val)) + } +} + const MONTY_ZERO: u32 = to_monty(0); const MONTY_ONE: u32 = to_monty(1); const MONTY_TWO: u32 = to_monty(2); @@ -493,6 +506,37 @@ mod tests { assert_eq!(m1.exp_u64(1725656503).exp_const_u64::<7>(), m1); assert_eq!(m2.exp_u64(1725656503).exp_const_u64::<7>(), m2); assert_eq!(f_2.exp_u64(1725656503).exp_const_u64::<7>(), f_2); + + let f_serialized = serde_json::to_string(&f).unwrap(); + let f_deserialized: F = serde_json::from_str(&f_serialized).unwrap(); + assert_eq!(f, f_deserialized); + + let f_1_serialized = serde_json::to_string(&f_1).unwrap(); + let f_1_deserialized: F = serde_json::from_str(&f_1_serialized).unwrap(); + let f_1_serialized_again = serde_json::to_string(&f_1_deserialized).unwrap(); + let f_1_deserialized_again: F = serde_json::from_str(&f_1_serialized_again).unwrap(); + assert_eq!(f_1, f_1_deserialized); + assert_eq!(f_1, f_1_deserialized_again); + + let f_2_serialized = serde_json::to_string(&f_2).unwrap(); + let f_2_deserialized: F = serde_json::from_str(&f_2_serialized).unwrap(); + assert_eq!(f_2, f_2_deserialized); + + let f_p_minus_1_serialized = serde_json::to_string(&f_p_minus_1).unwrap(); + let f_p_minus_1_deserialized: F = serde_json::from_str(&f_p_minus_1_serialized).unwrap(); + assert_eq!(f_p_minus_1, f_p_minus_1_deserialized); + + let f_p_minus_2_serialized = serde_json::to_string(&f_p_minus_2).unwrap(); + let f_p_minus_2_deserialized: F = serde_json::from_str(&f_p_minus_2_serialized).unwrap(); + assert_eq!(f_p_minus_2, f_p_minus_2_deserialized); + + let m1_serialized = serde_json::to_string(&m1).unwrap(); + let m1_deserialized: F = serde_json::from_str(&m1_serialized).unwrap(); + assert_eq!(m1, m1_deserialized); + + let m2_serialized = serde_json::to_string(&m2).unwrap(); + let m2_deserialized: F = serde_json::from_str(&m2_serialized).unwrap(); + assert_eq!(m2, m2_deserialized); } test_field!(crate::BabyBear);