diff --git a/src/generation.rs b/src/generation.rs index d7f5730..9b2e6f2 100644 --- a/src/generation.rs +++ b/src/generation.rs @@ -12,7 +12,7 @@ use crate::intermediate::{ RustRecord, RustStructCBORLen, RustStructType, RustType, ToWasmBoundaryOperations, VariantIdent, ROOT_SCOPE, }; -use crate::utils::convert_to_snake_case; +use crate::utils::{cbor_type_code_str, convert_to_snake_case}; #[derive(Debug, Clone)] struct SerializeConfig<'a> { @@ -907,14 +907,8 @@ impl GenerationScope { if cli.preserve_encodings { content.push_import("super::cbor_encodings", "*", None); } - if cli.preserve_encodings && cli.canonical_form { - content.push_import("cbor_event", "self", None); - } else { - content.push_import("cbor_event", "self", None).push_import( - "cbor_event::se", - "Serialize", - None, - ); + if !(cli.preserve_encodings && cli.canonical_form) { + content.push_import("cbor_event::se", "Serialize", None); } if *scope != *ROOT_SCOPE { content.push_import( @@ -5863,6 +5857,50 @@ fn make_enum_variant_return_if_deserialized( } } +fn make_inline_deser_code( + gen_scope: &mut GenerationScope, + types: &IntermediateTypes, + name: &RustIdent, + tag: Option, + record: &RustRecord, + enum_gen_info: &EnumVariantInRust, + cli: &Cli, +) -> DeserializationCode { + let mut variant_deser_code = generate_array_struct_deserialization( + gen_scope, types, name, record, tag, false, false, cli, + ); + add_deserialize_final_len_check( + &mut variant_deser_code.deser_code.content, + Some(record.rep), + record.cbor_len_info(types), + cli, + ); + // generate_constructor zips the expressions with the names in the enum_gen_info + // so just make sure we're in the same order as returned above + assert_eq!( + enum_gen_info.names.len(), + variant_deser_code.deser_ctor_fields.len() + + variant_deser_code.encoding_struct_ctor_fields.len() + ); + let ctor_exprs = variant_deser_code + .deser_ctor_fields + .into_iter() + .chain(variant_deser_code.encoding_struct_ctor_fields.into_iter()) + .zip(enum_gen_info.names.iter()) + .map(|((var, expr), name)| { + assert_eq!(var, *name); + expr + }) + .collect(); + enum_gen_info.generate_constructor( + &mut variant_deser_code.deser_code.content, + "Ok(", + ")", + Some(&ctor_exprs), + ); + variant_deser_code.deser_code +} + // Generates a general enum e.g. Foo { A(A), B(B), C(C) } for types A, B, C // if generate_deserialize_directly, don't generate deserialize_as_embedded_group() and just inline it within deserialize() // This is useful for type choicecs which don't have any enclosing array/map tags, and thus don't benefit from exposing a @@ -5945,7 +5983,31 @@ fn generate_enum( ); deser_impl }; - deser_body.line("let initial_position = raw.as_mut_ref().seek(SeekFrom::Current(0)).unwrap();"); + // We avoid checking ALL variants if we can figure it out by instead checking the type. + // This only works when the variants don't have first types in common. + let mut non_overlapping_types_match = { + // uses to_byte() instead of directly since Ord not implemented for cbor_event::Type + let mut first_types = BTreeSet::new(); + let mut duplicates = false; + for variant in variants.iter() { + for first_type in variant.cbor_types(types) { + if !first_types.insert(first_type.to_byte(0)) { + duplicates = true; + } + } + } + if duplicates { + None + } else { + let deser_covers_all_types = first_types.len() == 8; + Some((Block::new("match raw.cbor_type()?"), deser_covers_all_types)) + } + }; + if non_overlapping_types_match.is_none() { + deser_body + .line("let initial_position = raw.as_mut_ref().seek(SeekFrom::Current(0)).unwrap();") + .line("let mut errs = Vec::new();"); + } for variant in variants.iter() { let enum_gen_info = EnumVariantInRust::new(types, variant, rep, cli); let variant_var_name = variant.name_as_var(); @@ -6161,85 +6223,149 @@ fn generate_enum( ser_array_match_block.push_block(case_block); } // deserialize - // TODO: don't backtrack if variants begin with non-overlapping cbor types - // issue: https://github.com/dcSpark/cddl-codegen/issues/145 // TODO: how to detect when a greedy match won't work? (ie choice with choices in a choice possibly) - let mut return_if_deserialized = match &variant.data { - EnumVariantData::RustType(_) => { - let mut return_if_deserialized = make_enum_variant_return_if_deserialized( - gen_scope, - types, - variant, - enum_gen_info.types.is_empty(), - deser_body, - cli, - ); - let names_without_outer = enum_gen_info.names_without_outer(); - if names_without_outer.is_empty() { - return_if_deserialized - .line(format!("Ok(()) => return Ok({}::{}),", name, variant.name)); - } else { - enum_gen_info.generate_constructor( - &mut return_if_deserialized, - &if names_without_outer.len() > 1 { - format!("Ok(({})) => return Ok(", names_without_outer.join(", ")) + match non_overlapping_types_match.as_mut() { + Some((deser_type_match, _deser_covers_all_types)) => { + let variant_deser_code = match &variant.data { + EnumVariantData::RustType(ty) => { + let var_names_str = if cli.preserve_encodings { + encoding_var_names_str(types, &variant.name_as_var(), ty, cli) } else { - format!("Ok({}) => return Ok(", names_without_outer.join(", ")) - }, - "),", - None, - ); + variant.name_as_var() + }; + let mut variant_deser_code = gen_scope.generate_deserialize( + types, + (variant.rust_type()).into(), + DeserializeBeforeAfter::new( + &format!("let {var_names_str} = "), + ";", + false, + ), + DeserializeConfig::new(&variant.name_as_var()), + cli, + ); + let names_without_outer = enum_gen_info.names_without_outer(); + // we can avoid this ugly block and directly do it as a line possibly + if variant_deser_code.content.as_single_line().is_some() + && names_without_outer.len() == 1 + { + variant_deser_code = gen_scope.generate_deserialize( + types, + (variant.rust_type()).into(), + DeserializeBeforeAfter::new( + &format!("Ok({}::{}(", name, variant.name), + "))", + false, + ), + DeserializeConfig::new(&variant.name_as_var()), + cli, + ); + } else { + if names_without_outer.is_empty() { + variant_deser_code + .content + .line(&format!("Ok({}::{})", name, variant.name)); + } else { + enum_gen_info.generate_constructor( + &mut variant_deser_code.content, + "Ok(", + ")", + None, + ); + } + } + variant_deser_code + } + EnumVariantData::Inlined(record) => make_inline_deser_code( + gen_scope, + types, + name, + tag, + record, + &enum_gen_info, + cli, + ), + }; + let cbor_types_str = variant + .cbor_types(types) + .into_iter() + .map(cbor_type_code_str) + .collect::>() + .join("|"); + match variant_deser_code.content.as_single_line() { + Some(single_line) => { + deser_type_match.line(format!("{cbor_types_str} => {single_line},")); + } + None => { + let mut match_arm = Block::new(format!("{cbor_types_str} =>")); + variant_deser_code.add_to(&mut match_arm); + deser_type_match.push_block(match_arm); + } } - return_if_deserialized } - EnumVariantData::Inlined(record) => { - let mut variant_deser_code = generate_array_struct_deserialization( - gen_scope, types, name, record, tag, false, false, cli, - ); - add_deserialize_final_len_check( - &mut variant_deser_code.deser_code.content, - Some(record.rep), - record.cbor_len_info(types), - cli, - ); - // generate_constructor zips the expressions with the names in the enum_gen_info - // so just make sure we're in the same order as returned above - assert_eq!( - enum_gen_info.names.len(), - variant_deser_code.deser_ctor_fields.len() - + variant_deser_code.encoding_struct_ctor_fields.len() - ); - let ctor_exprs = variant_deser_code - .deser_ctor_fields - .into_iter() - .chain(variant_deser_code.encoding_struct_ctor_fields.into_iter()) - .zip(enum_gen_info.names.iter()) - .map(|((var, expr), name)| { - assert_eq!(var, *name); - expr - }) - .collect(); - enum_gen_info.generate_constructor( - &mut variant_deser_code.deser_code.content, - "Ok(", - ")", - Some(&ctor_exprs), - ); - let mut variant_deser = - Block::new("match (|raw: &mut Deserializer<_>| -> Result<_, DeserializeError>"); - variant_deser.after(")(raw)"); - variant_deser.push_all(variant_deser_code.deser_code.content); - deser_body.push_block(variant_deser); - // can't chain blocks so we just put them one after the other - let mut return_if_deserialized = Block::new(""); - return_if_deserialized.line("Ok(variant) => return Ok(variant),"); - return_if_deserialized + None => { + let mut return_if_deserialized = match &variant.data { + EnumVariantData::RustType(_) => { + let mut return_if_deserialized = make_enum_variant_return_if_deserialized( + gen_scope, + types, + variant, + enum_gen_info.types.is_empty(), + deser_body, + cli, + ); + let names_without_outer = enum_gen_info.names_without_outer(); + if names_without_outer.is_empty() { + return_if_deserialized + .line(format!("Ok(()) => return Ok({}::{}),", name, variant.name)); + } else { + enum_gen_info.generate_constructor( + &mut return_if_deserialized, + &if names_without_outer.len() > 1 { + format!( + "Ok(({})) => return Ok(", + names_without_outer.join(", ") + ) + } else { + format!("Ok({}) => return Ok(", names_without_outer.join(", ")) + }, + "),", + None, + ); + } + return_if_deserialized + } + EnumVariantData::Inlined(record) => { + let variant_deser_code = make_inline_deser_code( + gen_scope, + types, + name, + tag, + record, + &enum_gen_info, + cli, + ); + let mut variant_deser = Block::new( + "match (|raw: &mut Deserializer<_>| -> Result<_, DeserializeError>", + ); + variant_deser.after(")(raw)"); + variant_deser.push_all(variant_deser_code.content); + deser_body.push_block(variant_deser); + // can't chain blocks so we just put them one after the other + let mut return_if_deserialized = Block::new(""); + return_if_deserialized.line("Ok(variant) => return Ok(variant),"); + return_if_deserialized + } + }; + let mut variant_deser_failed_block = Block::new("Err(e) =>"); + variant_deser_failed_block + .line(format!("errs.push(e.annotate(\"{}\"));", variant.name)) + .line("raw.as_mut_ref().seek(SeekFrom::Start(initial_position)).unwrap();"); + return_if_deserialized.push_block(variant_deser_failed_block); + return_if_deserialized.after(";"); + deser_body.push_block(return_if_deserialized); } - }; - return_if_deserialized - .line("Err(_) => raw.as_mut_ref().seek(SeekFrom::Start(initial_position)).unwrap(),"); - return_if_deserialized.after(";"); - deser_body.push_block(return_if_deserialized); + } } ser_func.push_block(ser_array_match_block); ser_impl.push_fn(ser_func); @@ -6253,9 +6379,21 @@ fn generate_enum( // This can cause issues when there are overlapping (CBOR field-wise) variants inlined here. // Issue: https://github.com/dcSpark/cddl-codegen/issues/175 add_deserialize_final_len_check(deser_body, rep, RustStructCBORLen::Fixed(0), cli); - deser_body.line(&format!( - "Err(DeserializeError::new(\"{name}\", DeserializeFailure::NoVariantMatched))" - )); + match non_overlapping_types_match { + Some((mut deser_type_match, deser_covers_all_types)) => { + if !deser_covers_all_types { + deser_type_match.line(format!( + "_ => Err(DeserializeError::new(\"{name}\", DeserializeFailure::NoVariantMatched))," + )); + } + deser_body.push_block(deser_type_match); + } + None => { + deser_body.line(&format!( + "Err(DeserializeError::new(\"{name}\", DeserializeFailure::NoVariantMatchedWithCauses(errs)))" + )); + } + } if cli.annotate_fields { deser_func.push_block(error_annotator); } diff --git a/src/intermediate.rs b/src/intermediate.rs index 75aa9fa..14a9eaa 100644 --- a/src/intermediate.rs +++ b/src/intermediate.rs @@ -1941,6 +1941,16 @@ impl EnumVariant { } } + pub fn cbor_types(&self, types: &IntermediateTypes) -> Vec { + match &self.data { + EnumVariantData::RustType(ty) => ty.cbor_types(types), + EnumVariantData::Inlined(record) => match record.rep { + Representation::Array => vec![CBORType::Array], + Representation::Map => vec![CBORType::Map], + }, + } + } + // Can only be used on RustType variants, panics otherwise. // So don't call this when we're embedding the variant types pub fn rust_type(&self) -> &RustType { diff --git a/src/utils.rs b/src/utils.rs index 1707e6e..53c6515 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,16 +1,15 @@ -use cbor_event::Type as CBORType; use std::collections::BTreeMap; -pub fn _cbor_type_code_str(cbor_type: CBORType) -> &'static str { +pub fn cbor_type_code_str(cbor_type: cbor_event::Type) -> &'static str { match cbor_type { - CBORType::UnsignedInteger => "CBORType::UnsignedInteger", - CBORType::NegativeInteger => "CBORType::NegativeInteger", - CBORType::Bytes => "CBORType::Bytes", - CBORType::Text => "CBORType::Text", - CBORType::Array => "CBORType::Array", - CBORType::Map => "CBORType::Map", - CBORType::Tag => "CBORType::Tag", - CBORType::Special => "CBORType::Special", + cbor_event::Type::UnsignedInteger => "cbor_event::Type::UnsignedInteger", + cbor_event::Type::NegativeInteger => "cbor_event::Type::NegativeInteger", + cbor_event::Type::Bytes => "cbor_event::Type::Bytes", + cbor_event::Type::Text => "cbor_event::Type::Text", + cbor_event::Type::Array => "cbor_event::Type::Array", + cbor_event::Type::Map => "cbor_event::Type::Map", + cbor_event::Type::Tag => "cbor_event::Type::Tag", + cbor_event::Type::Special => "cbor_event::Type::Special", } } diff --git a/static/error.rs b/static/error.rs index 74d64c4..082d809 100644 --- a/static/error.rs +++ b/static/error.rs @@ -34,6 +34,7 @@ pub enum DeserializeFailure { InvalidStructure(Box), MandatoryFieldMissing(Key), NoVariantMatched, + NoVariantMatchedWithCauses(Vec), RangeCheck{ found: usize, min: Option, @@ -68,12 +69,12 @@ impl DeserializeError { None => Self::new(location, self.failure), } } -} - -impl std::error::Error for DeserializeError {} -impl std::fmt::Display for DeserializeError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt_indent(&self, f: &mut std::fmt::Formatter<'_>, indent: u32) -> std::fmt::Result { + use std::fmt::Display; + for _ in 0..indent { + write!(f, "\t")?; + } match &self.location { Some(loc) => write!(f, "Deserialization failed in {} because: ", loc), None => write!(f, "Deserialization: "), @@ -97,6 +98,14 @@ impl std::fmt::Display for DeserializeError { } DeserializeFailure::MandatoryFieldMissing(key) => write!(f, "Mandatory field {} not found", key), DeserializeFailure::NoVariantMatched => write!(f, "No variant matched"), + DeserializeFailure::NoVariantMatchedWithCauses(errs) => { + write!(f, "No variant matched. Failures:\n")?; + for e in errs { + e.fmt_indent(f, indent + 1)?; + write!(f, "\n")?; + } + Ok(()) + }, DeserializeFailure::RangeCheck{ found, min, max } => match (min, max) { (Some(min), Some(max)) => write!(f, "{} not in range {} - {}", found, min, max), (Some(min), None) => write!(f, "{} not at least {}", found, min), @@ -110,6 +119,14 @@ impl std::fmt::Display for DeserializeError { } } +impl std::error::Error for DeserializeError {} + +impl std::fmt::Display for DeserializeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.fmt_indent(f, 0) + } +} + impl From for DeserializeError { fn from(failure: DeserializeFailure) -> DeserializeError { DeserializeError { diff --git a/tests/core/input.cddl b/tests/core/input.cddl index 879d462..a946f8e 100644 --- a/tests/core/input.cddl +++ b/tests/core/input.cddl @@ -36,6 +36,10 @@ c_enum = 3 / 1 / 4 type_choice = 0 / "hello world" / uint / text / bytes / #6.64([*uint]) +non_overlapping_type_choice_all = uint / nint / text / bytes / #6.30("hello world") / [* uint] / { *text => uint } + +non_overlapping_type_choice_some = uint / nint / text + enums = [ c_enum, type_choice, diff --git a/tests/core/tests.rs b/tests/core/tests.rs index 03007ce..15b2f40 100644 --- a/tests/core/tests.rs +++ b/tests/core/tests.rs @@ -219,4 +219,24 @@ mod tests { let overlap2 = OverlappingInlined::new_overlapping_inlined2(5, "overlapping".into()); //deser_test(&overlap2); } + + #[test] + fn overlapping_type_choice_all() { + deser_test(&NonOverlappingTypeChoiceAll::U64(100)); + deser_test(&NonOverlappingTypeChoiceAll::N64(10000)); + deser_test(&NonOverlappingTypeChoiceAll::Text("Hello, World!".into())); + deser_test(&NonOverlappingTypeChoiceAll::Bytes(vec![0xBA, 0xAD, 0xF0, 0x0D])); + deser_test(&NonOverlappingTypeChoiceAll::Helloworld); + deser_test(&NonOverlappingTypeChoiceAll::ArrU64(vec![0, u64::MAX])); + deser_test(&NonOverlappingTypeChoiceAll::MapTextToU64( + BTreeMap::from([("two".into(), 2), ("four".into(), 4)])) + ); + } + + #[test] + fn overlapping_type_choice_some() { + deser_test(&NonOverlappingTypeChoiceSome::U64(100)); + deser_test(&NonOverlappingTypeChoiceSome::N64(10000)); + deser_test(&NonOverlappingTypeChoiceSome::Text("Hello, World!".into())); + } } diff --git a/tests/preserve-encodings/input.cddl b/tests/preserve-encodings/input.cddl index 3938962..b4cbb20 100644 --- a/tests/preserve-encodings/input.cddl +++ b/tests/preserve-encodings/input.cddl @@ -30,6 +30,10 @@ string_16_32 = #6.7(text .size (16..32)) type_choice = 0 / "hello world" / uint / text / #6.16([*uint]) +non_overlapping_type_choice_all = uint / nint / text / bytes / #6.13("hello world") / [* uint] / { *text => uint } + +non_overlapping_type_choice_some = uint / nint / text + c_enum = 3 / 1 / 4 enums = [ diff --git a/tests/preserve-encodings/tests.rs b/tests/preserve-encodings/tests.rs index 504b2ef..7f5016d 100644 --- a/tests/preserve-encodings/tests.rs +++ b/tests/preserve-encodings/tests.rs @@ -366,6 +366,79 @@ mod tests { } } + #[test] + fn non_overlapping_type_choice_some() { + let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight]; + let str_11_encodings = vec![ + StringLenSz::Len(Sz::One), + StringLenSz::Len(Sz::Inline), + StringLenSz::Indefinite(vec![(5, Sz::Two), (6, Sz::One)]), + StringLenSz::Indefinite(vec![(2, Sz::Inline), (0, Sz::Inline), (9, Sz::Four)]), + ]; + for str_enc in &str_11_encodings { + for def_enc in &def_encodings { + let irregular_bytes_uint = cbor_int(0, *def_enc); + let irregular_bytes_nint = cbor_int(-9, *def_enc); + let irregular_bytes_text = cbor_str_sz("abcdefghijk", str_enc.clone()); + let irregular_uint = NonOverlappingTypeChoiceSome::from_cbor_bytes(&irregular_bytes_uint).unwrap(); + assert_eq!(irregular_bytes_uint, irregular_uint.to_cbor_bytes()); + let irregular_nint = NonOverlappingTypeChoiceSome::from_cbor_bytes(&irregular_bytes_nint).unwrap(); + assert_eq!(irregular_bytes_nint, irregular_nint.to_cbor_bytes()); + let irregular_text = NonOverlappingTypeChoiceSome::from_cbor_bytes(&irregular_bytes_text).unwrap(); + assert_eq!(irregular_bytes_text, irregular_text.to_cbor_bytes()); + } + } + } + + #[test] + fn non_overlapping_type_choice_all() { + let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight]; + let str_11_encodings = vec![ + StringLenSz::Len(Sz::One), + StringLenSz::Len(Sz::Inline), + StringLenSz::Indefinite(vec![(5, Sz::Two), (6, Sz::One)]), + StringLenSz::Indefinite(vec![(2, Sz::Inline), (0, Sz::Inline), (9, Sz::Four)]), + ]; + for str_enc in &str_11_encodings { + for def_enc in &def_encodings { + let irregular_bytes_uint = cbor_int(0, *def_enc); + let irregular_bytes_nint = cbor_int(-9, *def_enc); + let irregular_bytes_text = cbor_str_sz("abcdefghijk", str_enc.clone()); + let irregular_bytes_bytes = cbor_bytes_sz(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], str_enc.clone()); + let irregular_bytes_hello_world = vec![ + cbor_tag_sz(13, *def_enc), + cbor_str_sz("hello world", str_enc.clone()) + ].into_iter().flatten().clone().collect::>(); + let irregular_bytes_arr = vec![ + arr_sz(2, *def_enc), + cbor_int(1, *def_enc), + cbor_int(3, *def_enc), + ].into_iter().flatten().clone().collect::>(); + let irregular_bytes_map = vec![ + map_sz(2, *def_enc), + cbor_str_sz("11111111111", str_enc.clone()), + cbor_int(1, *def_enc), + cbor_str_sz("33333333333", str_enc.clone()), + cbor_int(3, *def_enc), + ].into_iter().flatten().clone().collect::>(); + let irregular_uint = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_uint).unwrap(); + assert_eq!(irregular_bytes_uint, irregular_uint.to_cbor_bytes()); + let irregular_nint = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_nint).unwrap(); + assert_eq!(irregular_bytes_nint, irregular_nint.to_cbor_bytes()); + let irregular_text = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_text).unwrap(); + assert_eq!(irregular_bytes_text, irregular_text.to_cbor_bytes()); + let irregular_bytes = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_bytes).unwrap(); + assert_eq!(irregular_bytes_bytes, irregular_bytes.to_cbor_bytes()); + let irregular_hello_world = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_hello_world).unwrap(); + assert_eq!(irregular_bytes_hello_world, irregular_hello_world.to_cbor_bytes()); + let irregular_arr = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_arr).unwrap(); + assert_eq!(irregular_bytes_arr, irregular_arr.to_cbor_bytes()); + let irregular_map = NonOverlappingTypeChoiceAll::from_cbor_bytes(&irregular_bytes_map).unwrap(); + assert_eq!(irregular_bytes_map, irregular_map.to_cbor_bytes()); + } + } + } + #[test] fn enums() { let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight];