Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with overlapping inlined basic embedded enums #228

Merged
merged 4 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 79 additions & 44 deletions src/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6552,22 +6552,30 @@ fn generate_enum(
// We avoid checking ALL variants if we can figure it out by instead checking the type.
// This only works when the variants don't have first types in common.
let mut non_overlapping_types_match = {
let mut first_types = BTreeSet::new();
let mut duplicates = false;
let mut all_first_types = BTreeSet::new();
let mut duplicates_or_unknown = false;
for variant in variants.iter() {
for first_type in variant.cbor_types(types) {
// to_byte(0) is used since cbor_event::Type doesn't implement
// Ord or Hash so we can't put it in a set. Since we fix the lenth
// to always 0 this still remains a 1-to-1 mapping to Type.
if !first_types.insert(first_type.to_byte(0)) {
duplicates = true;
match variant.cbor_types_inner(types, rep) {
Some(first_types) => {
for first_type in first_types.iter() {
// to_byte(0) is used since cbor_event::Type doesn't implement
// Ord or Hash so we can't put it in a set. Since we fix the lenth
// to always 0 this still remains a 1-to-1 mapping to Type.
if !all_first_types.insert(first_type.to_byte(0)) {
duplicates_or_unknown = true;
}
}
}
None => {
duplicates_or_unknown = true;
break;
}
}
}
if duplicates {
if duplicates_or_unknown {
None
} else {
let deser_covers_all_types = first_types.len() == 8;
let deser_covers_all_types = all_first_types.len() == 8;
Some((Block::new("match raw.cbor_type()?"), deser_covers_all_types))
}
};
Expand Down Expand Up @@ -6631,7 +6639,7 @@ fn generate_enum(
.iter()
.any(|field| field.rust_type.config.bounds.is_some());
// bounds checking should be handled by the called constructor here
let mut ctor = format!("{}::new(", variant.name);
let mut ctor = format!("{}::new(", ty.conceptual_type.for_variant());
for field in ctor_fields {
if output_comma {
ctor.push_str(", ");
Expand Down Expand Up @@ -6869,46 +6877,72 @@ fn generate_enum(
} else {
variant.name_as_var()
};
let (before, after) =
if cli.preserve_encodings || !variant.rust_type().is_fixed_value() {
(Cow::from(format!("let {var_names_str} = ")), ";")
} else {
(Cow::from(""), "")
};
let (before, after) = if cli.preserve_encodings
|| !variant.rust_type().is_fixed_value()
|| rep.is_some()
{
(Cow::from(format!("let {var_names_str} = ")), ";")
} else {
(Cow::from(""), "")
};
let mut variant_deser_code = gen_scope.generate_deserialize(
types,
(variant.rust_type()).into(),
DeserializeBeforeAfter::new(&before, after, false),
DeserializeConfig::new(&variant.name_as_var()),
cli,
);
let names_without_outer = enum_gen_info.names_without_outer();
// we can avoid this ugly block and directly do it as a line possibly
if variant_deser_code.content.as_single_line().is_some()
&& names_without_outer.len() == 1
{
variant_deser_code = gen_scope.generate_deserialize(
types,
(variant.rust_type()).into(),
DeserializeBeforeAfter::new(
&format!("Ok({}::{}(", name, variant.name),
"))",
false,
),
DeserializeConfig::new(&variant.name_as_var()),
cli,
);
} else if names_without_outer.is_empty() {
variant_deser_code
.content
.line(&format!("Ok({}::{})", name, variant.name));
if let Some(r) = rep {
let len_info = match ty.conceptual_type.resolve_alias_shallow() {
ConceptualRustType::Rust(ident) if types.is_plain_group(ident) => {
types.rust_struct(ident).unwrap().cbor_len_info(types)
}
_ => RustStructCBORLen::Fixed(1),
};
// this will never be 1 line so don't bother with the below cases
variant_deser_code =
surround_in_len_checks(variant_deser_code, len_info, r, cli);
if enum_gen_info.outer_vars == 0 {
variant_deser_code.content.line(&format!(
"Ok({}::{}({}))",
name, variant.name, var_names_str
));
} else {
enum_gen_info.generate_constructor(
&mut variant_deser_code.content,
"Ok(",
")",
None,
);
}
} else {
enum_gen_info.generate_constructor(
&mut variant_deser_code.content,
"Ok(",
")",
None,
);
// we can avoid this ugly block and directly do it as a line possibly
if variant_deser_code.content.as_single_line().is_some()
&& enum_gen_info.names.len() == 1
{
variant_deser_code = gen_scope.generate_deserialize(
types,
(variant.rust_type()).into(),
DeserializeBeforeAfter::new(
&format!("Ok({}::{}(", name, variant.name),
"))",
false,
),
DeserializeConfig::new(&variant.name_as_var()),
cli,
);
} else if enum_gen_info.names.is_empty() {
variant_deser_code
.content
.line(&format!("Ok({}::{})", name, variant.name));
} else {
enum_gen_info.generate_constructor(
&mut variant_deser_code.content,
"Ok(",
")",
None,
);
}
}
variant_deser_code
}
Expand All @@ -6923,7 +6957,8 @@ fn generate_enum(
),
};
let cbor_types_str = variant
.cbor_types(types)
.cbor_types_inner(types, rep)
.expect("Already checked above")
.into_iter()
.map(cbor_type_code_str)
.collect::<Vec<_>>()
Expand Down
60 changes: 54 additions & 6 deletions src/intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1443,6 +1443,7 @@ impl ConceptualRustType {
}

pub fn directly_wasm_exposable(&self, types: &IntermediateTypes) -> bool {
println!("{self:?}.directly_wasm_exposable()");
match self {
Self::Fixed(_) => false,
Self::Primitive(_) => true,
Expand Down Expand Up @@ -2045,13 +2046,60 @@ impl EnumVariant {
}
}

pub fn cbor_types(&self, types: &IntermediateTypes) -> Vec<CBORType> {
/// Gets the next CBOR type after the passed in rep (array/map) tag
/// Returns None if this is not possible and brute-force deserialization
/// trying every variant should be used instead
pub fn cbor_types_inner(
&self,
types: &IntermediateTypes,
outer_rep: Option<Representation>,
) -> Option<Vec<CBORType>> {
match &self.data {
EnumVariantData::RustType(ty) => ty.cbor_types(types),
EnumVariantData::Inlined(record) => match record.rep {
Representation::Array => vec![CBORType::Array],
Representation::Map => vec![CBORType::Map],
},
EnumVariantData::RustType(ty) => {
if ty.encodings.is_empty() && outer_rep.is_some() {
if let ConceptualRustType::Rust(ident) =
ty.conceptual_type.resolve_alias_shallow()
{
match types.rust_struct(ident).unwrap().variant() {
// we can't know this unless there's a way to provide this info
RustStructType::Extern => None,
RustStructType::Record(record) => {
let mut ret = vec![];
for field in record.fields.iter() {
ret.extend(field.rust_type.cbor_types(types));
if !field.optional {
break;
}
}
Some(ret)
}
RustStructType::GroupChoice { .. } => None,
_ => Some(ty.cbor_types(types)),
}
} else {
Some(ty.cbor_types(types))
}
} else {
Some(ty.cbor_types(types))
}
}
EnumVariantData::Inlined(record) => {
if outer_rep.is_some() {
let mut ret = vec![];
for field in record.fields.iter() {
ret.extend(field.rust_type.cbor_types(types));
if !field.optional {
break;
}
}
Some(ret)
} else {
Some(match record.rep {
Representation::Array => vec![CBORType::Array],
Representation::Map => vec![CBORType::Map],
})
}
}
}
}

Expand Down
11 changes: 8 additions & 3 deletions src/parsing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use std::collections::BTreeMap;

use crate::comment_ast::{merge_metadata, metadata_from_comments, RuleMetadata};
use crate::intermediate::{
AliasInfo, CDDLIdent, ConceptualRustType, EnumVariant, FixedValue, GenericDef, GenericInstance,
IntermediateTypes, ModuleScope, Primitive, Representation, RustField, RustIdent, RustRecord,
RustStruct, RustStructType, RustType, VariantIdent,
AliasInfo, CBOREncodingOperation, CDDLIdent, ConceptualRustType, EnumVariant, FixedValue,
GenericDef, GenericInstance, IntermediateTypes, ModuleScope, Primitive, Representation,
RustField, RustIdent, RustRecord, RustStruct, RustStructType, RustType, VariantIdent,
};
use crate::utils::{
append_number_if_duplicate, convert_to_camel_case, convert_to_snake_case,
Expand Down Expand Up @@ -1518,7 +1518,12 @@ pub fn parse_group(
if let ConceptualRustType::Rust(ident) = &ty.conceptual_type {
// we might need to generate it if not used elsewhere
types.set_rep_if_plain_group(parent_visitor, ident, rep, cli);
// manual match in case we expand operaitons later
types.is_plain_group(ident)
&& !ty.encodings.iter().any(|enc| match enc {
CBOREncodingOperation::Tagged(_) => true,
CBOREncodingOperation::CBORBytes => true,
})
} else {
false
};
Expand Down
55 changes: 55 additions & 0 deletions tests/core/input.cddl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,61 @@ non_overlapping_type_choice_all = uint / nint / text / bytes / #6.30("hello worl

non_overlapping_type_choice_some = uint / nint / text

overlap_basic_embed = [
; @name identity
tag: 0 //
; @name x
tag: 1, hash: bytes .size 32
]

non_overlap_basic_embed = [
; @name first
x: uint, tag: 0 //
; @name second
y: text, tag: 1
]

non_overlap_basic_embed_multi_fields = [
; @name first
x: uint, z: uint //
; @name second
y: text, z: uint
]

non_overlap_basic_embed_mixed = [
; @name first
x: uint, tag: 0 //
; @name second
y: text, z: uint
]

bytes_uint = (bytes, uint)

non_overlap_basic_embed_mixed_explicit = [
; @name first
x: uint, tag: 0 //
; @name second
y: text, z: uint //
; @name third
bytes_uint
]

basic = (uint, text)

basic_arr = [basic]

; not overlap since double array for second
non_overlap_basic_not_basic = [
; @name group
basic //
; @name group_arr
basic_arr //
; @name group_tagged
#6.11(basic) //
; @name group_bytes
bytes .cbor basic
]

enums = [
c_enum,
type_choice,
Expand Down
39 changes: 39 additions & 0 deletions tests/core/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,45 @@ mod tests {
deser_test(&NonOverlappingTypeChoiceSome::N64(10000));
deser_test(&NonOverlappingTypeChoiceSome::Text("Hello, World!".into()));
}

#[test]
fn overlap_basic_embed() {
deser_test(&OverlapBasicEmbed::new_identity());
deser_test(&OverlapBasicEmbed::new_x(vec![85; 32]).unwrap());
}

#[test]
fn non_overlap_basic_embed() {
deser_test(&NonOverlapBasicEmbed::new_first(100));
deser_test(&NonOverlapBasicEmbed::new_second("cddl".to_owned()));
}

#[test]
fn non_overlap_basic_embed_multi_fields() {
deser_test(&NonOverlapBasicEmbedMultiFields::new_first(100, 1_000_000));
deser_test(&NonOverlapBasicEmbedMultiFields::new_second("cddl".to_owned(), 0));
}

#[test]
fn non_overlap_basic_embed_mixed() {
deser_test(&NonOverlapBasicEmbedMixed::new_first(100));
deser_test(&NonOverlapBasicEmbedMixed::new_second("cddl".to_owned(), 0));
}

#[test]
fn non_overlap_basic_embed_mixed_explicit() {
deser_test(&NonOverlapBasicEmbedMixedExplicit::new_first(100));
deser_test(&NonOverlapBasicEmbedMixedExplicit::new_second("cddl".to_owned(), 0));
deser_test(&NonOverlapBasicEmbedMixedExplicit::new_third(vec![0xBA, 0xAD, 0xF0, 0x0D], 4));
}

#[test]
fn non_overlap_basic_not_basic() {
deser_test(&NonOverlapBasicNotBasic::new_group(4, "basic".to_owned()));
deser_test(&NonOverlapBasicNotBasic::new_group_arr(Basic::new(4, "".to_owned())));
deser_test(&NonOverlapBasicNotBasic::new_group_tagged(0, " T A G G E D ".to_owned()));
deser_test(&NonOverlapBasicNotBasic::new_group_bytes(u64::MAX, "bytes .cbor basic".to_owned()));
}

#[test]
fn array_opt_fields() {
Expand Down
Loading
Loading