diff --git a/src/generation.rs b/src/generation.rs index faf1768..1739a5e 100644 --- a/src/generation.rs +++ b/src/generation.rs @@ -713,7 +713,7 @@ impl GenerationScope { start_len(body, Representation::Array, serializer_use, &encoding_var, &format!("{}.len() as u64", config.expr)); let elem_var_name = format!("{}_elem", config.var_name); let elem_encs = if CLI_ARGS.preserve_encodings { - encoding_fields(&elem_var_name, (&ty.clone().resolve_aliases()).into()) + encoding_fields(&elem_var_name, &ty.clone().resolve_aliases(), false) } else { vec![] }; @@ -739,8 +739,8 @@ impl GenerationScope { SerializingRustType::Root(ConceptualRustType::Map(key, value)) => { start_len(body, Representation::Map, serializer_use, &encoding_var, &format!("{}.len() as u64", config.expr)); let ser_loop = if CLI_ARGS.preserve_encodings { - let key_enc_fields = encoding_fields(&format!("{}_key", config.var_name), (&key.clone().resolve_aliases()).into()); - let value_enc_fields = encoding_fields(&format!("{}_value", config.var_name), (&value.clone().resolve_aliases()).into()); + let key_enc_fields = encoding_fields(&format!("{}_key", config.var_name), &key.clone().resolve_aliases(), false); + let value_enc_fields = encoding_fields(&format!("{}_value", config.var_name), &value.clone().resolve_aliases(), false); let mut ser_loop = if CLI_ARGS.canonical_form { let mut key_order = codegen::Block::new(&format!("let mut key_order = {}.iter().map(|(k, v)|", config.expr)); key_order @@ -1074,7 +1074,7 @@ impl GenerationScope { } } let ty_enc_fields = if CLI_ARGS.preserve_encodings { - encoding_fields(var_name, (&ty.clone().resolve_aliases()).into()) + encoding_fields(var_name, &ty.clone().resolve_aliases(), false) } else { vec![] }; @@ -1140,7 +1140,7 @@ impl GenerationScope { body.line(&format!("let mut {} = Vec::new();", arr_var_name)); let elem_var_name = format!("{}_elem", var_name); let elem_encs = if CLI_ARGS.preserve_encodings { - encoding_fields(&elem_var_name, (&ty.clone().resolve_aliases()).into()) + encoding_fields(&elem_var_name, &ty.clone().resolve_aliases(), false) } else { vec![] }; @@ -1200,12 +1200,12 @@ impl GenerationScope { let key_var_name = format!("{}_key", var_name); let value_var_name = format!("{}_value", var_name); let key_encs = if CLI_ARGS.preserve_encodings { - encoding_fields(&key_var_name, (&key_type.clone().resolve_aliases()).into()) + encoding_fields(&key_var_name, &key_type.clone().resolve_aliases(), false) } else { vec![] }; let value_encs = if CLI_ARGS.preserve_encodings { - encoding_fields(&value_var_name, (&value_type.clone().resolve_aliases()).into()) + encoding_fields(&value_var_name, &value_type.clone().resolve_aliases(), false) } else { vec![] }; @@ -1356,11 +1356,6 @@ impl GenerationScope { // new for variant in variants.iter() { let variant_arg = variant.name_as_var(); - let enc_fields = if CLI_ARGS.preserve_encodings { - encoding_fields(&variant_arg, (&variant.rust_type.clone().resolve_aliases()).into()) - } else { - vec![] - }; let mut new_func = codegen::Function::new(&format!("new_{}", variant_arg)); new_func.vis("pub"); let can_fail = match &variant.name { @@ -2154,7 +2149,22 @@ fn key_encoding_field(name: &str, key: &FixedValue) -> EncodingField { } } -fn encoding_fields(name: &str, ty: SerializingRustType) -> Vec { +fn encoding_fields(name: &str, ty: &RustType, include_default: bool) -> Vec { + assert!(CLI_ARGS.preserve_encodings); + // TODO: how do we handle defaults for nested things? e.g. inside of a ConceptualRustType::Map + let mut encs = encoding_fields_impl(name, ty.into()); + if include_default && ty.default.is_some() { + encs.push(EncodingField { + field_name: format!("{}_default_present", name), + type_name: "bool".to_owned(), + default_expr: "false", + inner: Vec::new(), + }); + } + encs +} + +fn encoding_fields_impl(name: &str, ty: SerializingRustType) -> Vec { assert!(CLI_ARGS.preserve_encodings); match ty { SerializingRustType::Root(ConceptualRustType::Array(elem_ty)) => { @@ -2164,7 +2174,7 @@ fn encoding_fields(name: &str, ty: SerializingRustType) -> Vec { default_expr: "LenEncoding::default()", inner: Vec::new(), }; - let inner_encs = encoding_fields(&format!("{}_elem", name), (&**elem_ty).into()); + let inner_encs = encoding_fields_impl(&format!("{}_elem", name), (&**elem_ty).into()); if inner_encs.is_empty() { vec![base] } else { @@ -2193,8 +2203,8 @@ fn encoding_fields(name: &str, ty: SerializingRustType) -> Vec { inner: Vec::new(), } ]; - let key_encs = encoding_fields(&format!("{}_key", name), (&**k).into()); - let val_encs = encoding_fields(&format!("{}_value", name), (&**v).into()); + let key_encs = encoding_fields_impl(&format!("{}_key", name), (&**k).into()); + let val_encs = encoding_fields_impl(&format!("{}_value", name), (&**v).into()); if !key_encs.is_empty() { let type_name_value = if key_encs.len() == 1 { @@ -2256,21 +2266,21 @@ fn encoding_fields(name: &str, ty: SerializingRustType) -> Vec { SerializingRustType::Root(ConceptualRustType::Fixed(f)) => match f { FixedValue::Bool(_) | FixedValue::Null => vec![], - FixedValue::Nint(_) => encoding_fields(name, (&ConceptualRustType::Primitive(Primitive::I64)).into()), - FixedValue::Uint(_) => encoding_fields(name, (&ConceptualRustType::Primitive(Primitive::U64)).into()), - FixedValue::Text(_) => encoding_fields(name, (&ConceptualRustType::Primitive(Primitive::Str)).into()), + FixedValue::Nint(_) => encoding_fields_impl(name, (&ConceptualRustType::Primitive(Primitive::I64)).into()), + FixedValue::Uint(_) => encoding_fields_impl(name, (&ConceptualRustType::Primitive(Primitive::U64)).into()), + FixedValue::Text(_) => encoding_fields_impl(name, (&ConceptualRustType::Primitive(Primitive::Str)).into()), }, SerializingRustType::Root(ConceptualRustType::Alias(_, _)) => panic!("resolve types before calling this"), - SerializingRustType::Root(ConceptualRustType::Optional(ty)) => encoding_fields(name, (&**ty).into()), + SerializingRustType::Root(ConceptualRustType::Optional(ty)) => encoding_fields(name, ty, false), SerializingRustType::Root(ConceptualRustType::Rust(_)) => vec![], SerializingRustType::EncodingOperation(CBOREncodingOperation::Tagged(tag), child) => { - let mut encs = encoding_fields(&format!("{}_tag", name), (&ConceptualRustType::Fixed(FixedValue::Uint(*tag))).into()); - encs.append(&mut encoding_fields(name, *child)); + let mut encs = encoding_fields_impl(&format!("{}_tag", name), (&ConceptualRustType::Fixed(FixedValue::Uint(*tag))).into()); + encs.append(&mut encoding_fields_impl(name, *child)); encs }, SerializingRustType::EncodingOperation(CBOREncodingOperation::CBORBytes, child) => { - let mut encs = encoding_fields(&format!("{}_bytes", name), (&ConceptualRustType::Primitive(Primitive::Bytes)).into()); - encs.append(&mut encoding_fields(name, *child)); + let mut encs = encoding_fields_impl(&format!("{}_bytes", name), (&ConceptualRustType::Primitive(Primitive::Bytes)).into()); + encs.append(&mut encoding_fields_impl(name, *child)); encs }, } @@ -2284,7 +2294,7 @@ fn encoding_var_names_str(field_name: &str, rust_type: &RustType) -> String { } else { vec![field_name.to_owned()] }; - for enc in encoding_fields(field_name, (&resolved_rust_type).into()).into_iter() { + for enc in encoding_fields(field_name, &resolved_rust_type, false).into_iter() { var_names.push(enc.field_name); } let var_names_str = if var_names.len() > 1 { @@ -2321,20 +2331,35 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na setter .arg_mut_self() .arg(&field.name, &field.rust_type.for_wasm_param()) - .vis("pub") - .line(format!( + .vis("pub"); + if field.rust_type.default.is_some() { + setter.line(format!( + "self.0.{} = {}", + field.name, + ToWasmBoundaryOperations::format(field.rust_type.from_wasm_boundary_clone(&field.name, false).into_iter()))); + } else { + setter.line(format!( "self.0.{} = Some({})", field.name, ToWasmBoundaryOperations::format(field.rust_type.from_wasm_boundary_clone(&field.name, false).into_iter()))); + } + // ^ TODO: check types.can_new_fail(&field.name) wrapper.s_impl.push_fn(setter); // getter let mut getter = codegen::Function::new(&field.name); getter .arg_ref_self() - .ret(format!("Option<{}>", field.rust_type.for_wasm_return())) - .vis("pub") - .line(field.rust_type.to_wasm_boundary_optional(&format!("self.0.{}", field.name), false)); + .vis("pub"); + if field.rust_type.default.is_some() { + getter + .ret(field.rust_type.for_wasm_return()) + .line(field.rust_type.to_wasm_boundary(&format!("self.0.{}", field.name), false)); + } else { + getter + .ret(format!("Option<{}>", field.rust_type.for_wasm_return())) + .line(field.rust_type.to_wasm_boundary_optional(&format!("self.0.{}", field.name), false)); + } wrapper.s_impl.push_fn(getter); } else { // new @@ -2374,7 +2399,12 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na } // Fixed values only exist in (de)serialization code (outside of preserve-encodings=true) if !field.rust_type.is_fixed_value() { - if field.optional { + if let Some(default_value) = &field.rust_type.default { + // field + native_struct.field(&format!("pub {}", field.name), field.rust_type.for_rust_member(false)); + // new + native_new_block.line(format!("{}: {},", field.name, default_value.to_primitive_str_assign())); + } else if field.optional { // field native_struct.field(&format!("pub {}", field.name), format!("Option<{}>", field.rust_type.for_rust_member(false))); // new @@ -2403,7 +2433,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na } for field in &record.fields { // even fixed values still need to keep track of their encodings - for field_enc in encoding_fields(&field.name, (&field.rust_type.clone().resolve_aliases()).into()) { + for field_enc in encoding_fields(&field.name, &field.rust_type.clone().resolve_aliases(), true) { encoding_struct.field(&format!("pub {}", field_enc.field_name), field_enc.type_name); } if record.rep == Representation::Map { @@ -2493,12 +2523,21 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na let mut deser_ret = Block::new(&format!("Ok({}", name)); for field in record.fields.iter() { if field.optional { - let mut optional_array_ser_block = Block::new(&format!("if let Some(field) = &self.{}", field.name)); + let (optional_field_check, field_expr, expr_is_ref) = if let Some(default_value) = &field.rust_type.default { + (if CLI_ARGS.preserve_encodings { + format!("if self.{} != {} || self.encodings.map(|encs| encs.{}_default_present).unwrap_or(false)", field.name, default_value.to_primitive_str_compare(), field.name) + } else { + format!("if self.{} != {}", field.name, default_value.to_primitive_str_compare()) + }, format!("self.{}", field.name), false) + } else { + (format!("if let Some(field) = &self.{}", field.name), "field".to_owned(), true) + }; + let mut optional_array_ser_block = Block::new(&optional_field_check); gen_scope.generate_serialize( types, (&field.rust_type).into(), &mut optional_array_ser_block, - SerializeConfig::new("field", &field.name).expr_is_ref(true).encoding_var_in_option_struct("self.encodings")); + SerializeConfig::new(&field_expr, &field.name).expr_is_ref(expr_is_ref).encoding_var_in_option_struct("self.encodings")); ser_func.push_block(optional_array_ser_block); gen_scope.dont_generate_deserialize(name, format!("Array with optional field {}: {}", field.name, field.rust_type.for_rust_member(false))); } else { @@ -2552,7 +2591,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na for field in record.fields.iter() { // we don't support deserialization for optional fields so don't even bother if !field.optional { - for field_enc in encoding_fields(&field.name, (&field.rust_type.clone().resolve_aliases()).into()) { + for field_enc in encoding_fields(&field.name, &field.rust_type.clone().resolve_aliases(), true) { encoding_ctor.line(format!("{},", field_enc.field_name)); } } @@ -2589,7 +2628,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na } // declare variables for deser loop if CLI_ARGS.preserve_encodings { - for field_enc in encoding_fields(&field.name, (&field.rust_type.clone().resolve_aliases()).into()) { + for field_enc in encoding_fields(&field.name, &field.rust_type.clone().resolve_aliases(), true) { deser_body.line(&format!("let mut {} = {};", field_enc.field_name, field_enc.default_expr)); } let key_enc = key_encoding_field(&field.name, &field.key.as_ref().unwrap()); @@ -2600,7 +2639,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na } else { deser_body.line(&format!("let mut {} = None;", field.name)); } - let (data_name, expr_is_ref) = if field.optional { + let (data_name, expr_is_ref) = if field.optional && field.rust_type.default.is_none() { (String::from("field"), true) } else { (format!("self.{}", field.name), false) @@ -2634,7 +2673,6 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na let temp_var_prefix = format!("tmp_{}", field.name); let var_names_str = encoding_var_names_str(&temp_var_prefix, &field.rust_type); - let needs_vars = !var_names_str.is_empty(); let (before, after) = if var_names_str.is_empty() { ("".to_owned(), "?") } else { @@ -2659,7 +2697,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na } else { deser_block.line(format!("{} = Some(tmp_{});", field.name, field.name)); } - for enc_field in encoding_fields(&field.name, (&field.rust_type.clone().resolve_aliases()).into()) { + for enc_field in encoding_fields(&field.name, &field.rust_type.clone().resolve_aliases(), false) { deser_block.line(format!("{} = tmp_{};", enc_field.field_name, enc_field.field_name)); } } else { @@ -2757,7 +2795,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na // ser_loop_match.line(format!("{} => {},")); //} else { //} - let mut field_ser_block = if field.optional { + let mut field_ser_block = if field.optional && field.rust_type.default.is_none() { Block::new(&format!("{} => if let Some(field) = &self.{}", field_index, field.name)) } else { Block::new(&format!("{} =>", field_index)) @@ -2773,7 +2811,12 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na } else { for (_field_index, field, content) in ser_content.into_iter() { if field.optional { - let mut optional_ser_field = Block::new(&format!("if let Some(field) = &self.{}", field.name)); + let optional_ser_field_check = if let Some(default_value) = &field.rust_type.default { + format!("if self.{} != {}", field.name, default_value.to_primitive_str_compare()) + } else { + format!("if let Some(field) = &self.{}", field.name) + }; + let mut optional_ser_field = Block::new(&optional_ser_field_check); optional_ser_field.push_all(content); ser_func.push_block(optional_ser_field); } else { @@ -2854,11 +2897,18 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na } else { let mut mandatory_field_check = Block::new(&format!("let {} = match {}", field.name, field.name)); mandatory_field_check.line("Some(x) => x,"); - + mandatory_field_check.line(format!("None => return Err(DeserializeFailure::MandatoryFieldMissing({}).into()),", key)); mandatory_field_check.after(";"); deser_body.push_block(mandatory_field_check); } + } else if let Some(default_value) = &field.rust_type.default { + if CLI_ARGS.preserve_encodings { + let mut default_present_check = Block::new(&format!("if {} == Some({})", field.name, default_value.to_primitive_str_assign())); + default_present_check.line(format!("{}_default_present = true;", field.name)); + deser_body.push_block(default_present_check); + } + deser_body.line(&format!("let {} = {}.unwrap_or({});", field.name, field.name, default_value.to_primitive_str_assign())); } if !field.rust_type.is_fixed_value() { ctor_block.line(format!("{},", field.name)); @@ -2875,7 +2925,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na for field in record.fields.iter() { let key_enc = key_encoding_field(&field.name, field.key.as_ref().unwrap()); encoding_ctor.line(format!("{},", key_enc.field_name)); - for field_enc in encoding_fields(&field.name, (&field.rust_type.clone().resolve_aliases()).into()) { + for field_enc in encoding_fields(&field.name, &field.rust_type.clone().resolve_aliases(), true) { encoding_ctor.line(format!("{},", field_enc.field_name)); } } @@ -3078,7 +3128,7 @@ impl EnumVariantInRust { fn new(variant: &EnumVariant, rep: Option) -> Self { let name = variant.name_as_var(); let mut enc_fields = if CLI_ARGS.preserve_encodings { - encoding_fields(&name, (&variant.rust_type.clone().resolve_aliases()).into()) + encoding_fields(&name, &variant.rust_type.clone().resolve_aliases(), true) } else { vec![] }; @@ -3547,7 +3597,7 @@ fn generate_wrapper_struct(gen_scope: &mut GenerationScope, types: &Intermediate let encoding_name = RustIdent::new(CDDLIdent::new(format!("{}Encoding", type_name))); let enc_fields = if CLI_ARGS.preserve_encodings { s.field("pub inner", field_type.for_rust_member(false)); - let enc_fields = encoding_fields("inner", (&field_type.clone().resolve_aliases()).into()); + let enc_fields = encoding_fields("inner", &field_type.clone().resolve_aliases(), true); if !enc_fields.is_empty() { s.field(&format!("{}pub encodings", encoding_var_macros(types.used_as_key(type_name))), format!("Option<{}>", encoding_name)); diff --git a/src/intermediate.rs b/src/intermediate.rs index 631e599..8ab418c 100644 --- a/src/intermediate.rs +++ b/src/intermediate.rs @@ -435,6 +435,27 @@ impl FixedValue { }.expect("Unable to serialize key for canonical ordering"); buf.finalize() } + + /// Converts a literal to a valid rust expression capable of initializing a Primitive + /// e.g. Text is an actual String, etc + pub fn to_primitive_str_assign(&self) -> String { + match self { + FixedValue::Null => "None".to_owned(), + FixedValue::Bool(b) => b.to_string(), + FixedValue::Nint(i) => i.to_string(), + FixedValue::Uint(u) => u.to_string(), + FixedValue::Text(s) => format!("\"{}\".to_owned()", s), + } + } + + /// Converts a literal to a valid rust comparison valid for comparisons + /// e.g. Text can be &str to avoid creating a String + pub fn to_primitive_str_compare(&self) -> String { + match self { + FixedValue::Text(s) => format!("\"{}\"", s), + _=> self.to_primitive_str_assign(), + } + } } #[derive(Clone, Debug, Eq, PartialEq)] @@ -498,6 +519,7 @@ impl Primitive { }) } + /// All POSSIBLE outermost CBOR types this can encode to pub fn cbor_types(&self) -> Vec { match self { Primitive::Bool => vec![CBORType::Special], @@ -689,7 +711,20 @@ impl RustType { pub fn default(mut self, default_value: FixedValue) -> Self { assert!(self.default.is_none()); - // TODO: verify that the fixed value makes sense for the conceptual_type + let matches = if let ConceptualRustType::Primitive(p) = self.conceptual_type.clone().resolve_aliases() { + match &default_value { + FixedValue::Bool(_) => p == Primitive::Bool, + FixedValue::Nint(_) => p.cbor_types().contains(&CBORType::NegativeInteger), + FixedValue::Uint(_) => p.cbor_types().contains(&CBORType::UnsignedInteger), + FixedValue::Null => false, + FixedValue::Text(_) => p == Primitive::Str, + } + } else { + false + }; + if !matches { + panic!(".default {:?} invalid for type {:?}", default_value, self.conceptual_type); + } self.default = Some(default_value); self } @@ -707,6 +742,7 @@ impl RustType { } } + /// All POSSIBLE outermost CBOR types this can encode to pub fn cbor_types(&self) -> Vec { match self.encodings.last() { Some(CBOREncodingOperation::Tagged(_)) => vec![CBORType::Tag], @@ -970,7 +1006,7 @@ impl ConceptualRustType { } } - /// IDENTIFIER for an enum variant. (Use for_rust_member() for the ) + /// IDENTIFIER for an enum variant. (Use for_rust_member() for the enum value) pub fn for_variant(&self) -> VariantIdent { match self { Self::Fixed(f) => f.for_variant(), @@ -991,17 +1027,7 @@ impl ConceptualRustType { /// can_fail is for cases where checks (e.g. range checks) are done if there /// is a type transformation (i.e. wrapper types) like text (wasm) -> #6.14(text) (rust) pub fn from_wasm_boundary_clone(&self, expr: &str, can_fail: bool) -> Vec { - //assert!(matches!(self, Self::Tagged(_, _)) || !can_fail); - match self { - // Self::Tagged(_tag, ty) => { - // let mut inner = ty.from_wasm_boundary_clone(expr, can_fail); - // if can_fail { - // inner.push(ToWasmBoundaryOperations::TryInto); - // } else { - // inner.push(ToWasmBoundaryOperations::Into); - // } - // inner - // }, + let mut ops = match self { Self::Rust(_ident) => vec![ ToWasmBoundaryOperations::Code(format!("{}.clone()", expr)), ToWasmBoundaryOperations::Into, @@ -1021,22 +1047,16 @@ impl ConceptualRustType { ToWasmBoundaryOperations::Into, ], _ => vec![ToWasmBoundaryOperations::Code(expr.to_owned())], + }; + if can_fail { + ops.push(ToWasmBoundaryOperations::TryInto); } + ops } fn from_wasm_boundary_clone_optional(&self, expr: &str, can_fail: bool) -> Vec { - //assert!(matches!(self, Self::Tagged(_, _)) || !can_fail); - match self { + let mut ops = match self { Self::Primitive(_p) => vec![ToWasmBoundaryOperations::Code(expr.to_owned())], - // Self::Tagged(_tag, ty) => { - // let mut inner = ty.from_wasm_boundary_clone_optional(expr, can_fail); - // if can_fail { - // inner.push(ToWasmBoundaryOperations::TryInto); - // } else { - // inner.push(ToWasmBoundaryOperations::Into); - // } - // inner - // }, Self::Alias(_ident, ty) => ty.from_wasm_boundary_clone_optional(expr, can_fail), Self::Array(..) | Self::Rust(..) | @@ -1049,7 +1069,11 @@ impl ConceptualRustType { }, ], _ => panic!("unsupported or unexpected"), + }; + if can_fail { + ops.push(ToWasmBoundaryOperations::TryInto); } + ops } /// for non-owning parameter TYPES from wasm @@ -1612,7 +1636,24 @@ impl RustRecord { // maps are defined by their keys instead (although they shouldn't have multi-length values either...) Representation::Map => ("_", String::from("1")), }; - conditional_field_expr.push_str(&format!("match &self.{} {{ Some({}) => {}, None => 0 }}", field.name, field_expr, field_contribution)); + if let Some(default_value) = &field.rust_type.default { + if CLI_ARGS.preserve_encodings { + conditional_field_expr.push_str(&format!( + "if self.{} != {} || self.encodings.as_ref().map(|encs| encs.{}_default_present).unwrap_or(false) {{ {} }} else {{ 0 }}", + field.name, + default_value.to_primitive_str_compare(), + field.name, + field_contribution)); + } else { + conditional_field_expr.push_str(&format!( + "if self.{} != {} {{ {} }} else {{ 0 }}", + field.name, + default_value.to_primitive_str_compare(), + field_contribution)); + } + } else { + conditional_field_expr.push_str(&format!("match &self.{} {{ Some({}) => {}, None => 0 }}", field.name, field_expr, field_contribution)); + } } else { match self.rep { Representation::Array => match field.rust_type.conceptual_type.expanded_field_count(types) { diff --git a/src/parsing.rs b/src/parsing.rs index e10fdaf..0e7c9f9 100644 --- a/src/parsing.rs +++ b/src/parsing.rs @@ -32,6 +32,7 @@ use crate::utils::{ enum ControlOperator { Range((Option, Option)), CBOR(RustType), + Default(FixedValue), } struct Type2AndParent<'a> { @@ -143,6 +144,15 @@ fn type2_to_number_literal(type2: &Type2) -> isize { } } +fn type2_to_fixed_value(type2: &Type2) -> FixedValue { + match type2 { + Type2::UintValue{ value, .. } => FixedValue::Uint(*value), + Type2::IntValue{ value, .. } => FixedValue::Nint(*value), + Type2::TextValue{ value, .. } => FixedValue::Text(value.to_string()), + _ => panic!("Type2: {:?} does not correspond to a supported FixedValue", type2), + } +} + fn parse_control_operator(types: &mut IntermediateTypes, parent: &Type2AndParent, operator: &Operator) -> ControlOperator { let lower_bound = match parent.type2 { Type2::Typename{ ident, .. } if ident.to_string() == "uint" => Some(0), @@ -165,10 +175,10 @@ fn parse_control_operator(types: &mut IntermediateTypes, parent: &Type2AndParent ControlOperator::Range((Some(range_start as i128), Some(if is_inclusive { range_end as i128 } else { (range_end + 1) as i128 }))) }, RangeCtlOp::CtlOp{ ctrl, .. } => match ctrl { - cddl::token::ControlOperator::DEFAULT | cddl::token::ControlOperator::CBORSEQ | cddl::token::ControlOperator::WITHIN | cddl::token::ControlOperator::AND => todo!("control operator {} not supported", ctrl), + cddl::token::ControlOperator::DEFAULT => ControlOperator::Default(type2_to_fixed_value(&operator.type2)), cddl::token::ControlOperator::CBOR => ControlOperator::CBOR(rust_type_from_type2(types, &Type2AndParent { type2: &operator.type2, parent: parent.parent, })), cddl::token::ControlOperator::EQ => ControlOperator::Range((Some(type2_to_number_literal(&operator.type2) as i128), Some(type2_to_number_literal(&operator.type2) as i128))), // TODO: this would be MUCH nicer (for error displaying, etc) to handle this in its own dedicated way @@ -286,6 +296,11 @@ fn parse_type(types: &mut IntermediateTypes, type_name: &RustIdent, type_choice: }, _ => panic!(".cbor is only allowed on bytes as per CDDL spec"), }, + ControlOperator::Default(default_value) => { + let default_type = rust_type_from_type2(types, &Type2AndParent { type2: &type1.type2, parent: &type1 }) + .default(default_value); + types.register_type_alias(type_name.clone(), default_type, true, true); + }, } }, None => { @@ -365,6 +380,12 @@ fn parse_type(types: &mut IntermediateTypes, type_name: &RustIdent, type_choice: _ => types.register_rust_struct(RustStruct::new_wrapper(type_name.clone(), *tag, new_type, Some(min_max))) } }, + Some(ControlOperator::Default(default_value)) => { + let default_tagged_type = rust_type_from_type2(types, &Type2AndParent { parent: &inner_type.type1, type2: &inner_type.type1.type2 }) + .default(default_value) + .tag(tag_unwrap); + types.register_type_alias(type_name.clone(), default_tagged_type, true, true); + }, None => { // TODO: this would be fixed if we ordered definitions via a dependency graph to begin with // which would also allow us to do a single pass instead of many like we do now @@ -599,19 +620,24 @@ fn group_entry_to_raw_field_name(entry: &GroupEntry) -> Option { fn rust_type_from_type1(types: &mut IntermediateTypes, type1: &Type1) -> RustType { let control = type1.operator.as_ref().map(|op| parse_control_operator(types, &Type2AndParent { parent: type1, type2: &type1.type2 }, op)); + let base_type = rust_type_from_type2(types, &Type2AndParent { type2: &type1.type2, parent: type1, }); // println!("type1: {:#?}", type1); match control { - Some(ControlOperator::CBOR(ty)) => ty.as_bytes(), + Some(ControlOperator::CBOR(ty)) => { + assert!(matches!(base_type.conceptual_type.resolve_aliases(), ConceptualRustType::Primitive(Primitive::Bytes))); + ty.as_bytes() + }, Some(ControlOperator::Range(min_max)) => { match &type1.type2 { Type2::Typename{ ident, .. } if ident.to_string() == "uint" || ident.to_string() == "int" => match range_to_primitive(min_max.0, min_max.1) { Some(t) => t.into(), None => panic!("unsupported range for {:?}: {:?}", ident.to_string().as_str(), control) }, - _ => rust_type_from_type2(types, &Type2AndParent { type2: &type1.type2, parent: type1, }) + _ => base_type } }, - _ => rust_type_from_type2(types, &Type2AndParent { type2: &type1.type2, parent: type1, }) + Some(ControlOperator::Default(default_value)) => base_type.default(default_value), + None => base_type, } } diff --git a/tests/core/input.cddl b/tests/core/input.cddl index f2f34dd..d5a100c 100644 --- a/tests/core/input.cddl +++ b/tests/core/input.cddl @@ -61,4 +61,11 @@ signed_ints = [ ; The fix would be ideal as even though the true min in CBOR would be -u64::MAX ; we can't test that since isize::BITS is never > 64 in any normal system and likely never will be i64_min: -9223372036854775808 -] \ No newline at end of file +] + +default_uint = uint .default 1337 + +map_with_defaults = { + ? 1 : default_uint + ? 2 : text .default "two" +} \ No newline at end of file diff --git a/tests/core/tests.rs b/tests/core/tests.rs index 512c9cb..324f4ec 100644 --- a/tests/core/tests.rs +++ b/tests/core/tests.rs @@ -126,4 +126,14 @@ mod tests { let max = SignedInts::new(u8::MAX, u16::MAX, u32::MAX, u64::MAX, i8::MAX, i16::MAX, i32::MAX, i64::MAX, u64::MAX); deser_test(&max); } + + #[test] + fn defaults() { + let mut md = MapWithDefaults::new(); + deser_test(&md); + md.key_1 = 0; + deser_test(&md); + md.key_2 = "not two".into(); + deser_test(&md); + } } diff --git a/tests/preserve-encodings/input.cddl b/tests/preserve-encodings/input.cddl index 70c279e..94696d2 100644 --- a/tests/preserve-encodings/input.cddl +++ b/tests/preserve-encodings/input.cddl @@ -56,4 +56,11 @@ signed_ints = [ ; The fix would be ideal as even though the true min in CBOR would be -u64::MAX ; we can't test that since isize::BITS is never > 64 in any normal system and likely never will be i64_min: -9223372036854775808 -] \ No newline at end of file +] + +default_uint = uint .default 1337 + +map_with_defaults = { + ? 1 : default_uint + ? 2 : text .default "two" +} \ No newline at end of file diff --git a/tests/preserve-encodings/tests.rs b/tests/preserve-encodings/tests.rs index 7ddc9c5..2b86db3 100644 --- a/tests/preserve-encodings/tests.rs +++ b/tests/preserve-encodings/tests.rs @@ -506,4 +506,48 @@ mod tests { assert_eq!(irregular_bytes_max, irregular_max.to_bytes()); } } + + #[test] + fn defaults() { + let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight]; + let str_3_encodings = vec![ + StringLenSz::Len(Sz::Eight), + StringLenSz::Len(Sz::Inline), + StringLenSz::Indefinite(vec![(1, Sz::Two), (2, Sz::One)]), + StringLenSz::Indefinite(vec![(2, Sz::Inline), (0, Sz::Inline), (1, Sz::Four)]), + ]; + let bools = [(false, true), (true, false), (true, true)]; + for str_enc in &str_3_encodings { + for def_enc in &def_encodings { + for ((key_1_present, key_1_default), (key_2_present, key_2_default)) in bools.iter().zip(bools.iter()) { + let value_1: u64 = if *key_1_default { 1337 } else { 2 }; + let value_2 = if *key_2_default { "two" } else { "one" }; + let irregular_bytes = vec![ + vec![MAP_INDEF], + if *key_1_present { + vec![ + cbor_int(1, *def_enc), + cbor_int(value_1 as i128, Sz::Two), + ].into_iter().flatten().clone().collect::>() + } else { + vec![] + }, + if *key_2_present { + vec![ + cbor_int(2, *def_enc), + cbor_str_sz(value_2, str_enc.clone()), + ].into_iter().flatten().clone().collect::>() + } else { + vec![] + }, + vec![BREAK], + ].into_iter().flatten().clone().collect::>(); + let irregular = MapWithDefaults::from_bytes(irregular_bytes.clone()).unwrap(); + assert_eq!(irregular_bytes, irregular.to_bytes()); + assert_eq!(irregular.key_1, value_1); + assert_eq!(irregular.key_2, value_2); + } + } + } + } } \ No newline at end of file