Skip to content

Commit

Permalink
Better enum deserialization (#199)
Browse files Browse the repository at this point in the history
Together these fix #194 and #145

* NoVariantMatchedWithCauses deser error variant

Addresses #194

This helps avoid any type/group choice from eating any errors on
variants as instead of receiving a NoVariantMatched you will receive one
with errors as to why each variant failed.

This drastically helps debugging and also works for nested choices as
well.

* Avoid try-all on enums with non-overlapping CBOR

Fixes #145

When all variants have non-overlapping first CBOR type we can avoid
brute-force trying all possible variants for type/group choices and
instead branch on raw.cbor_type() to only try the variant that makes
sense.
  • Loading branch information
rooooooooob authored Jul 19, 2023
1 parent e25dbe7 commit 4eb8f4a
Show file tree
Hide file tree
Showing 8 changed files with 367 additions and 102 deletions.
312 changes: 225 additions & 87 deletions src/generation.rs

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions src/intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1941,6 +1941,16 @@ impl EnumVariant {
}
}

pub fn cbor_types(&self, types: &IntermediateTypes) -> Vec<CBORType> {
match &self.data {
EnumVariantData::RustType(ty) => ty.cbor_types(types),
EnumVariantData::Inlined(record) => match record.rep {
Representation::Array => vec![CBORType::Array],
Representation::Map => vec![CBORType::Map],
},
}
}

// Can only be used on RustType variants, panics otherwise.
// So don't call this when we're embedding the variant types
pub fn rust_type(&self) -> &RustType {
Expand Down
19 changes: 9 additions & 10 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use cbor_event::Type as CBORType;
use std::collections::BTreeMap;

pub fn _cbor_type_code_str(cbor_type: CBORType) -> &'static str {
pub fn cbor_type_code_str(cbor_type: cbor_event::Type) -> &'static str {
match cbor_type {
CBORType::UnsignedInteger => "CBORType::UnsignedInteger",
CBORType::NegativeInteger => "CBORType::NegativeInteger",
CBORType::Bytes => "CBORType::Bytes",
CBORType::Text => "CBORType::Text",
CBORType::Array => "CBORType::Array",
CBORType::Map => "CBORType::Map",
CBORType::Tag => "CBORType::Tag",
CBORType::Special => "CBORType::Special",
cbor_event::Type::UnsignedInteger => "cbor_event::Type::UnsignedInteger",
cbor_event::Type::NegativeInteger => "cbor_event::Type::NegativeInteger",
cbor_event::Type::Bytes => "cbor_event::Type::Bytes",
cbor_event::Type::Text => "cbor_event::Type::Text",
cbor_event::Type::Array => "cbor_event::Type::Array",
cbor_event::Type::Map => "cbor_event::Type::Map",
cbor_event::Type::Tag => "cbor_event::Type::Tag",
cbor_event::Type::Special => "cbor_event::Type::Special",
}
}

Expand Down
27 changes: 22 additions & 5 deletions static/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub enum DeserializeFailure {
InvalidStructure(Box<dyn std::error::Error>),
MandatoryFieldMissing(Key),
NoVariantMatched,
NoVariantMatchedWithCauses(Vec<DeserializeError>),
RangeCheck{
found: usize,
min: Option<isize>,
Expand Down Expand Up @@ -68,12 +69,12 @@ impl DeserializeError {
None => Self::new(location, self.failure),
}
}
}

impl std::error::Error for DeserializeError {}

impl std::fmt::Display for DeserializeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fn fmt_indent(&self, f: &mut std::fmt::Formatter<'_>, indent: u32) -> std::fmt::Result {
use std::fmt::Display;
for _ in 0..indent {
write!(f, "\t")?;
}
match &self.location {
Some(loc) => write!(f, "Deserialization failed in {} because: ", loc),
None => write!(f, "Deserialization: "),
Expand All @@ -97,6 +98,14 @@ impl std::fmt::Display for DeserializeError {
}
DeserializeFailure::MandatoryFieldMissing(key) => write!(f, "Mandatory field {} not found", key),
DeserializeFailure::NoVariantMatched => write!(f, "No variant matched"),
DeserializeFailure::NoVariantMatchedWithCauses(errs) => {
write!(f, "No variant matched. Failures:\n")?;
for e in errs {
e.fmt_indent(f, indent + 1)?;
write!(f, "\n")?;
}
Ok(())
},
DeserializeFailure::RangeCheck{ found, min, max } => match (min, max) {
(Some(min), Some(max)) => write!(f, "{} not in range {} - {}", found, min, max),
(Some(min), None) => write!(f, "{} not at least {}", found, min),
Expand All @@ -110,6 +119,14 @@ impl std::fmt::Display for DeserializeError {
}
}

impl std::error::Error for DeserializeError {}

impl std::fmt::Display for DeserializeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.fmt_indent(f, 0)
}
}

impl From<DeserializeFailure> for DeserializeError {
fn from(failure: DeserializeFailure) -> DeserializeError {
DeserializeError {
Expand Down
4 changes: 4 additions & 0 deletions tests/core/input.cddl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ c_enum = 3 / 1 / 4

type_choice = 0 / "hello world" / uint / text / bytes / #6.64([*uint])

non_overlapping_type_choice_all = uint / nint / text / bytes / #6.30("hello world") / [* uint] / { *text => uint }

non_overlapping_type_choice_some = uint / nint / text

enums = [
c_enum,
type_choice,
Expand Down
20 changes: 20 additions & 0 deletions tests/core/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,24 @@ mod tests {
let overlap2 = OverlappingInlined::new_overlapping_inlined2(5, "overlapping".into());
//deser_test(&overlap2);
}

#[test]
fn overlapping_type_choice_all() {
deser_test(&NonOverlappingTypeChoiceAll::U64(100));
deser_test(&NonOverlappingTypeChoiceAll::N64(10000));
deser_test(&NonOverlappingTypeChoiceAll::Text("Hello, World!".into()));
deser_test(&NonOverlappingTypeChoiceAll::Bytes(vec![0xBA, 0xAD, 0xF0, 0x0D]));
deser_test(&NonOverlappingTypeChoiceAll::Helloworld);
deser_test(&NonOverlappingTypeChoiceAll::ArrU64(vec![0, u64::MAX]));
deser_test(&NonOverlappingTypeChoiceAll::MapTextToU64(
BTreeMap::from([("two".into(), 2), ("four".into(), 4)]))
);
}

#[test]
fn overlapping_type_choice_some() {
deser_test(&NonOverlappingTypeChoiceSome::U64(100));
deser_test(&NonOverlappingTypeChoiceSome::N64(10000));
deser_test(&NonOverlappingTypeChoiceSome::Text("Hello, World!".into()));
}
}
4 changes: 4 additions & 0 deletions tests/preserve-encodings/input.cddl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ string_16_32 = #6.7(text .size (16..32))

type_choice = 0 / "hello world" / uint / text / #6.16([*uint])

non_overlapping_type_choice_all = uint / nint / text / bytes / #6.13("hello world") / [* uint] / { *text => uint }

non_overlapping_type_choice_some = uint / nint / text

c_enum = 3 / 1 / 4

enums = [
Expand Down
73 changes: 73 additions & 0 deletions tests/preserve-encodings/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,79 @@ mod tests {
}
}

#[test]
fn non_overlapping_type_choice_some() {
let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight];
let str_11_encodings = vec![
StringLenSz::Len(Sz::One),
StringLenSz::Len(Sz::Inline),
StringLenSz::Indefinite(vec![(5, Sz::Two), (6, Sz::One)]),
StringLenSz::Indefinite(vec![(2, Sz::Inline), (0, Sz::Inline), (9, Sz::Four)]),
];
for str_enc in &str_11_encodings {
for def_enc in &def_encodings {
let irregular_bytes_uint = cbor_int(0, *def_enc);
let irregular_bytes_nint = cbor_int(-9, *def_enc);
let irregular_bytes_text = cbor_str_sz("abcdefghijk", str_enc.clone());
let irregular_uint = NonOverlappingTypeChoiceSome::from_cbor_bytes(&irregular_bytes_uint).unwrap();
assert_eq!(irregular_bytes_uint, irregular_uint.to_cbor_bytes());
let irregular_nint = NonOverlappingTypeChoiceSome::from_cbor_bytes(&irregular_bytes_nint).unwrap();
assert_eq!(irregular_bytes_nint, irregular_nint.to_cbor_bytes());
let irregular_text = NonOverlappingTypeChoiceSome::from_cbor_bytes(&irregular_bytes_text).unwrap();
assert_eq!(irregular_bytes_text, irregular_text.to_cbor_bytes());
}
}
}

#[test]
fn non_overlapping_type_choice_all() {
let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight];
let str_11_encodings = vec![
StringLenSz::Len(Sz::One),
StringLenSz::Len(Sz::Inline),
StringLenSz::Indefinite(vec![(5, Sz::Two), (6, Sz::One)]),
StringLenSz::Indefinite(vec![(2, Sz::Inline), (0, Sz::Inline), (9, Sz::Four)]),
];
for str_enc in &str_11_encodings {
for def_enc in &def_encodings {
let irregular_bytes_uint = cbor_int(0, *def_enc);
let irregular_bytes_nint = cbor_int(-9, *def_enc);
let irregular_bytes_text = cbor_str_sz("abcdefghijk", str_enc.clone());
let irregular_bytes_bytes = cbor_bytes_sz(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], str_enc.clone());
let irregular_bytes_hello_world = vec![
cbor_tag_sz(13, *def_enc),
cbor_str_sz("hello world", str_enc.clone())
].into_iter().flatten().clone().collect::<Vec<u8>>();
let irregular_bytes_arr = vec![
arr_sz(2, *def_enc),
cbor_int(1, *def_enc),
cbor_int(3, *def_enc),
].into_iter().flatten().clone().collect::<Vec<u8>>();
let irregular_bytes_map = vec![
map_sz(2, *def_enc),
cbor_str_sz("11111111111", str_enc.clone()),
cbor_int(1, *def_enc),
cbor_str_sz("33333333333", str_enc.clone()),
cbor_int(3, *def_enc),
].into_iter().flatten().clone().collect::<Vec<u8>>();
let irregular_uint = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_uint).unwrap();
assert_eq!(irregular_bytes_uint, irregular_uint.to_cbor_bytes());
let irregular_nint = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_nint).unwrap();
assert_eq!(irregular_bytes_nint, irregular_nint.to_cbor_bytes());
let irregular_text = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_text).unwrap();
assert_eq!(irregular_bytes_text, irregular_text.to_cbor_bytes());
let irregular_bytes = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_bytes).unwrap();
assert_eq!(irregular_bytes_bytes, irregular_bytes.to_cbor_bytes());
let irregular_hello_world = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_hello_world).unwrap();
assert_eq!(irregular_bytes_hello_world, irregular_hello_world.to_cbor_bytes());
let irregular_arr = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_arr).unwrap();
assert_eq!(irregular_bytes_arr, irregular_arr.to_cbor_bytes());
let irregular_map = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_map).unwrap();
assert_eq!(irregular_bytes_map, irregular_map.to_cbor_bytes());
}
}
}

#[test]
fn enums() {
let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight];
Expand Down

0 comments on commit 4eb8f4a

Please sign in to comment.