Skip to content

Commit

Permalink
Fix issue with overlapping inlined basic embedded enums (#228)
Browse files Browse the repository at this point in the history
* Fix issue with overlapping inlined basic embedded enums

Caused when the inlining would made cddl-codegen think that the types
were not overlapping since it was looking at the stored type not the
actual starting cbor type (e.g. when it was a fixed value).

This would cause a problem as the type matching introduced in #199
but only in very specific cases with basic groups starting with fixed
values.

* Fix #229

Possibly needs more tests covering combinations with non-basic groups
mixed with basic, tagged basic groups and optional fields (should work
though)

* Fix #230 #231 #232
  • Loading branch information
rooooooooob authored May 18, 2024
1 parent 200f1fe commit d7c8e19
Show file tree
Hide file tree
Showing 7 changed files with 480 additions and 53 deletions.
123 changes: 79 additions & 44 deletions src/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6892,22 +6892,30 @@ fn generate_enum(
// We avoid checking ALL variants if we can figure it out by instead checking the type.
// This only works when the variants don't have first types in common.
let mut non_overlapping_types_match = {
let mut first_types = BTreeSet::new();
let mut duplicates = false;
let mut all_first_types = BTreeSet::new();
let mut duplicates_or_unknown = false;
for variant in variants.iter() {
for first_type in variant.cbor_types(types) {
// to_byte(0) is used since cbor_event::Type doesn't implement
// Ord or Hash so we can't put it in a set. Since we fix the lenth
// to always 0 this still remains a 1-to-1 mapping to Type.
if !first_types.insert(first_type.to_byte(0)) {
duplicates = true;
match variant.cbor_types_inner(types, rep) {
Some(first_types) => {
for first_type in first_types.iter() {
// to_byte(0) is used since cbor_event::Type doesn't implement
// Ord or Hash so we can't put it in a set. Since we fix the lenth
// to always 0 this still remains a 1-to-1 mapping to Type.
if !all_first_types.insert(first_type.to_byte(0)) {
duplicates_or_unknown = true;
}
}
}
None => {
duplicates_or_unknown = true;
break;
}
}
}
if duplicates {
if duplicates_or_unknown {
None
} else {
let deser_covers_all_types = first_types.len() == 8;
let deser_covers_all_types = all_first_types.len() == 8;
Some((Block::new("match raw.cbor_type()?"), deser_covers_all_types))
}
};
Expand Down Expand Up @@ -6971,7 +6979,7 @@ fn generate_enum(
.iter()
.any(|field| field.rust_type.config.bounds.is_some());
// bounds checking should be handled by the called constructor here
let mut ctor = format!("{}::new(", variant.name);
let mut ctor = format!("{}::new(", ty.conceptual_type.for_variant());
for field in ctor_fields {
if output_comma {
ctor.push_str(", ");
Expand Down Expand Up @@ -7209,46 +7217,72 @@ fn generate_enum(
} else {
variant.name_as_var()
};
let (before, after) =
if cli.preserve_encodings || !variant.rust_type().is_fixed_value() {
(Cow::from(format!("let {var_names_str} = ")), ";")
} else {
(Cow::from(""), "")
};
let (before, after) = if cli.preserve_encodings
|| !variant.rust_type().is_fixed_value()
|| rep.is_some()
{
(Cow::from(format!("let {var_names_str} = ")), ";")
} else {
(Cow::from(""), "")
};
let mut variant_deser_code = gen_scope.generate_deserialize(
types,
(variant.rust_type()).into(),
DeserializeBeforeAfter::new(&before, after, false),
DeserializeConfig::new(&variant.name_as_var()),
cli,
);
let names_without_outer = enum_gen_info.names_without_outer();
// we can avoid this ugly block and directly do it as a line possibly
if variant_deser_code.content.as_single_line().is_some()
&& names_without_outer.len() == 1
{
variant_deser_code = gen_scope.generate_deserialize(
types,
(variant.rust_type()).into(),
DeserializeBeforeAfter::new(
&format!("Ok({}::{}(", name, variant.name),
"))",
false,
),
DeserializeConfig::new(&variant.name_as_var()),
cli,
);
} else if names_without_outer.is_empty() {
variant_deser_code
.content
.line(&format!("Ok({}::{})", name, variant.name));
if let Some(r) = rep {
let len_info = match ty.conceptual_type.resolve_alias_shallow() {
ConceptualRustType::Rust(ident) if types.is_plain_group(ident) => {
types.rust_struct(ident).unwrap().cbor_len_info(types)
}
_ => RustStructCBORLen::Fixed(1),
};
// this will never be 1 line so don't bother with the below cases
variant_deser_code =
surround_in_len_checks(variant_deser_code, len_info, r, cli);
if enum_gen_info.outer_vars == 0 {
variant_deser_code.content.line(&format!(
"Ok({}::{}({}))",
name, variant.name, var_names_str
));
} else {
enum_gen_info.generate_constructor(
&mut variant_deser_code.content,
"Ok(",
")",
None,
);
}
} else {
enum_gen_info.generate_constructor(
&mut variant_deser_code.content,
"Ok(",
")",
None,
);
// we can avoid this ugly block and directly do it as a line possibly
if variant_deser_code.content.as_single_line().is_some()
&& enum_gen_info.names.len() == 1
{
variant_deser_code = gen_scope.generate_deserialize(
types,
(variant.rust_type()).into(),
DeserializeBeforeAfter::new(
&format!("Ok({}::{}(", name, variant.name),
"))",
false,
),
DeserializeConfig::new(&variant.name_as_var()),
cli,
);
} else if enum_gen_info.names.is_empty() {
variant_deser_code
.content
.line(&format!("Ok({}::{})", name, variant.name));
} else {
enum_gen_info.generate_constructor(
&mut variant_deser_code.content,
"Ok(",
")",
None,
);
}
}
variant_deser_code
}
Expand All @@ -7263,7 +7297,8 @@ fn generate_enum(
),
};
let cbor_types_str = variant
.cbor_types(types)
.cbor_types_inner(types, rep)
.expect("Already checked above")
.into_iter()
.map(cbor_type_code_str)
.collect::<Vec<_>>()
Expand Down
60 changes: 54 additions & 6 deletions src/intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,7 @@ impl ConceptualRustType {
}

pub fn directly_wasm_exposable(&self, types: &IntermediateTypes) -> bool {
println!("{self:?}.directly_wasm_exposable()");
match self {
Self::Fixed(_) => false,
Self::Primitive(_) => true,
Expand Down Expand Up @@ -2104,13 +2105,60 @@ impl EnumVariant {
}
}

pub fn cbor_types(&self, types: &IntermediateTypes) -> Vec<CBORType> {
/// Gets the next CBOR type after the passed in rep (array/map) tag
/// Returns None if this is not possible and brute-force deserialization
/// trying every variant should be used instead
pub fn cbor_types_inner(
&self,
types: &IntermediateTypes,
outer_rep: Option<Representation>,
) -> Option<Vec<CBORType>> {
match &self.data {
EnumVariantData::RustType(ty) => ty.cbor_types(types),
EnumVariantData::Inlined(record) => match record.rep {
Representation::Array => vec![CBORType::Array],
Representation::Map => vec![CBORType::Map],
},
EnumVariantData::RustType(ty) => {
if ty.encodings.is_empty() && outer_rep.is_some() {
if let ConceptualRustType::Rust(ident) =
ty.conceptual_type.resolve_alias_shallow()
{
match types.rust_struct(ident).unwrap().variant() {
// we can't know this unless there's a way to provide this info
RustStructType::Extern => None,
RustStructType::Record(record) => {
let mut ret = vec![];
for field in record.fields.iter() {
ret.extend(field.rust_type.cbor_types(types));
if !field.optional {
break;
}
}
Some(ret)
}
RustStructType::GroupChoice { .. } => None,
_ => Some(ty.cbor_types(types)),
}
} else {
Some(ty.cbor_types(types))
}
} else {
Some(ty.cbor_types(types))
}
}
EnumVariantData::Inlined(record) => {
if outer_rep.is_some() {
let mut ret = vec![];
for field in record.fields.iter() {
ret.extend(field.rust_type.cbor_types(types));
if !field.optional {
break;
}
}
Some(ret)
} else {
Some(match record.rep {
Representation::Array => vec![CBORType::Array],
Representation::Map => vec![CBORType::Map],
})
}
}
}
}

Expand Down
11 changes: 8 additions & 3 deletions src/parsing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use std::collections::BTreeMap;

use crate::comment_ast::{merge_metadata, metadata_from_comments, RuleMetadata};
use crate::intermediate::{
AliasInfo, CDDLIdent, ConceptualRustType, EnumVariant, FixedValue, GenericDef, GenericInstance,
IntermediateTypes, ModuleScope, Primitive, Representation, RustField, RustIdent, RustRecord,
RustStruct, RustStructType, RustType, VariantIdent,
AliasInfo, CBOREncodingOperation, CDDLIdent, ConceptualRustType, EnumVariant, FixedValue,
GenericDef, GenericInstance, IntermediateTypes, ModuleScope, Primitive, Representation,
RustField, RustIdent, RustRecord, RustStruct, RustStructType, RustType, VariantIdent,
};
use crate::utils::{
append_number_if_duplicate, convert_to_camel_case, convert_to_snake_case,
Expand Down Expand Up @@ -1534,7 +1534,12 @@ pub fn parse_group(
if let ConceptualRustType::Rust(ident) = &ty.conceptual_type {
// we might need to generate it if not used elsewhere
types.set_rep_if_plain_group(parent_visitor, ident, rep, cli);
// manual match in case we expand operaitons later
types.is_plain_group(ident)
&& !ty.encodings.iter().any(|enc| match enc {
CBOREncodingOperation::Tagged(_) => true,
CBOREncodingOperation::CBORBytes => true,
})
} else {
false
};
Expand Down
55 changes: 55 additions & 0 deletions tests/core/input.cddl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,61 @@ non_overlapping_type_choice_all = uint / nint / text / bytes / #6.30("hello worl

non_overlapping_type_choice_some = uint / nint / text

overlap_basic_embed = [
; @name identity
tag: 0 //
; @name x
tag: 1, hash: bytes .size 32
]

non_overlap_basic_embed = [
; @name first
x: uint, tag: 0 //
; @name second
y: text, tag: 1
]

non_overlap_basic_embed_multi_fields = [
; @name first
x: uint, z: uint //
; @name second
y: text, z: uint
]

non_overlap_basic_embed_mixed = [
; @name first
x: uint, tag: 0 //
; @name second
y: text, z: uint
]

bytes_uint = (bytes, uint)

non_overlap_basic_embed_mixed_explicit = [
; @name first
x: uint, tag: 0 //
; @name second
y: text, z: uint //
; @name third
bytes_uint
]

basic = (uint, text)

basic_arr = [basic]

; not overlap since double array for second
non_overlap_basic_not_basic = [
; @name group
basic //
; @name group_arr
basic_arr //
; @name group_tagged
#6.11(basic) //
; @name group_bytes
bytes .cbor basic
]

enums = [
c_enum,
type_choice,
Expand Down
39 changes: 39 additions & 0 deletions tests/core/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,45 @@ mod tests {
deser_test(&NonOverlappingTypeChoiceSome::N64(10000));
deser_test(&NonOverlappingTypeChoiceSome::Text("Hello, World!".into()));
}

#[test]
fn overlap_basic_embed() {
deser_test(&OverlapBasicEmbed::new_identity());
deser_test(&OverlapBasicEmbed::new_x(vec![85; 32]).unwrap());
}

#[test]
fn non_overlap_basic_embed() {
deser_test(&NonOverlapBasicEmbed::new_first(100));
deser_test(&NonOverlapBasicEmbed::new_second("cddl".to_owned()));
}

#[test]
fn non_overlap_basic_embed_multi_fields() {
deser_test(&NonOverlapBasicEmbedMultiFields::new_first(100, 1_000_000));
deser_test(&NonOverlapBasicEmbedMultiFields::new_second("cddl".to_owned(), 0));
}

#[test]
fn non_overlap_basic_embed_mixed() {
deser_test(&NonOverlapBasicEmbedMixed::new_first(100));
deser_test(&NonOverlapBasicEmbedMixed::new_second("cddl".to_owned(), 0));
}

#[test]
fn non_overlap_basic_embed_mixed_explicit() {
deser_test(&NonOverlapBasicEmbedMixedExplicit::new_first(100));
deser_test(&NonOverlapBasicEmbedMixedExplicit::new_second("cddl".to_owned(), 0));
deser_test(&NonOverlapBasicEmbedMixedExplicit::new_third(vec![0xBA, 0xAD, 0xF0, 0x0D], 4));
}

#[test]
fn non_overlap_basic_not_basic() {
deser_test(&NonOverlapBasicNotBasic::new_group(4, "basic".to_owned()));
deser_test(&NonOverlapBasicNotBasic::new_group_arr(Basic::new(4, "".to_owned())));
deser_test(&NonOverlapBasicNotBasic::new_group_tagged(0, " T A G G E D ".to_owned()));
deser_test(&NonOverlapBasicNotBasic::new_group_bytes(u64::MAX, "bytes .cbor basic".to_owned()));
}

#[test]
fn array_opt_fields() {
Expand Down
Loading

0 comments on commit d7c8e19

Please sign in to comment.