diff --git a/crates/cairo-lang-plugins/src/plugins/derive.rs b/crates/cairo-lang-plugins/src/plugins/derive.rs index 99ddff3060d..e372052787a 100644 --- a/crates/cairo-lang-plugins/src/plugins/derive.rs +++ b/crates/cairo-lang-plugins/src/plugins/derive.rs @@ -7,10 +7,12 @@ use cairo_lang_semantic::plugin::{AsDynMacroPlugin, SemanticPlugin, TrivialPlugi use cairo_lang_syntax::attribute::structured::{ AttributeArg, AttributeArgVariant, AttributeStructurize, }; -use cairo_lang_syntax::node::ast::{AttributeList, MemberList}; +use cairo_lang_syntax::node::ast::{ + AttributeList, ItemStruct, MemberList, OptionWrappedGenericParamList, +}; use cairo_lang_syntax::node::db::SyntaxGroup; use cairo_lang_syntax::node::helpers::QueryAttrs; -use cairo_lang_syntax::node::{ast, Terminal}; +use cairo_lang_syntax::node::{ast, Terminal, TypedSyntaxNode}; use indoc::formatdoc; use itertools::Itertools; use smol_str::SmolStr; @@ -26,7 +28,7 @@ impl MacroPlugin for DerivePlugin { db, struct_ast.name(db), struct_ast.attributes(db), - ExtraInfo::Struct(member_names(db, struct_ast.members(db))), + extract_struct_extra_info(db, &struct_ast), ), ast::Item::Enum(enum_ast) => generate_derive_code_for_type( db, @@ -56,7 +58,7 @@ impl SemanticPlugin for DerivePlugin {} enum ExtraInfo { Enum(Vec), - Struct(Vec), + Struct { members: Vec, type_generics: Vec, other_generics: Vec }, Extern, } @@ -64,6 +66,53 @@ fn member_names(db: &dyn SyntaxGroup, members: MemberList) -> Vec { members.elements(db).into_iter().map(|member| member.name(db).text(db)).collect() } +fn extract_struct_extra_info(db: &dyn SyntaxGroup, struct_ast: &ItemStruct) -> ExtraInfo { + let members = member_names(db, struct_ast.members(db)); + let mut type_generics = vec![]; + let mut other_generics = vec![]; + match struct_ast.generic_params(db) { + OptionWrappedGenericParamList::WrappedGenericParamList(gens) => gens + .generic_params(db) + .elements(db) + .into_iter() + .map(|member| match member { + ast::GenericParam::Type(t) => { + type_generics.push(t.name(db).text(db)); + } + ast::GenericParam::Impl(i) => { + other_generics.push(i.as_syntax_node().get_text_without_trivia(db)) + } + ast::GenericParam::Const(c) => { + other_generics.push(c.as_syntax_node().get_text_without_trivia(db)) + } + }) + .collect(), + OptionWrappedGenericParamList::Empty(_) => vec![], + }; + ExtraInfo::Struct { members, type_generics, other_generics } +} + +fn format_generics_with_trait( + type_generics: &[SmolStr], + other_generics: &[String], + f: impl Fn(&SmolStr) -> String, +) -> String { + format!( + "<{}{}{}>", + type_generics.iter().map(|s| format!("{}, ", s)).collect::(), + other_generics.iter().map(|s| format!("{}, ", s)).collect::(), + type_generics.iter().map(f).join(", "), + ) +} + +fn format_generics(type_generics: &[SmolStr], other_generics: &[String]) -> String { + format!( + "<{}{}>", + type_generics.iter().map(|s| format!("{}, ", s)).collect::(), + other_generics.iter().map(|s| format!("{}, ", s)).collect::(), + ) +} + /// Adds an implementation for all requested derives for the type. fn generate_derive_code_for_type( db: &dyn SyntaxGroup, @@ -107,7 +156,7 @@ fn generate_derive_code_for_type( let name = ident.text(db); let derived = segment.ident(db).text(db); match derived.as_str() { - "Copy" | "Drop" => impls.push(get_empty_impl(&name, &derived)), + "Copy" | "Drop" => impls.push(get_empty_impl(&name, &derived, &extra_info)), "Clone" if !matches!(extra_info, ExtraInfo::Extern) => { impls.push(get_clone_impl(&name, &extra_info)) } @@ -163,18 +212,22 @@ fn get_clone_impl(name: &str, extra_info: &ExtraInfo) -> String { format!("{name}::{variant}(x) => {name}::{variant}(x.clone()),") }).join("\n ")} } - ExtraInfo::Struct(members) => { + ExtraInfo::Struct { members, type_generics, other_generics } => { formatdoc! {" - impl {name}Clone of Clone::<{name}> {{ - fn clone(self: @{name}) -> {name} {{ + impl {name}Clone{generics_impl} of Clone::<{name}{generics}> {{ + fn clone(self: @{name}{generics}) -> {name}{generics} {{ {name} {{ {} }} }} }} ", members.iter().map(|member| { - format!("{member}: self.{member}.clone(),") - }).join("\n ")} + format!("{member}: self.{member}.clone(),") + }).join("\n "), + generics = format_generics(type_generics, other_generics), + generics_impl = format_generics_with_trait(type_generics, other_generics, + |t| format!("impl {t}Clone: Clone<{t}>, impl {t}Destruct: Destruct<{t}>")) + } } ExtraInfo::Extern => unreachable!(), } @@ -195,16 +248,20 @@ fn get_destruct_impl(name: &str, extra_info: &ExtraInfo) -> String { format!("{name}::{variant}(x) => traits::Destruct::destruct(x),") }).join("\n ")} } - ExtraInfo::Struct(members) => { + ExtraInfo::Struct { members, type_generics, other_generics } => { formatdoc! {" - impl {name}Destruct of Destruct::<{name}> {{ - fn destruct(self: {name}) nopanic {{ + impl {name}Destruct{generics_impl} of Destruct::<{name}{generics}> {{ + fn destruct(self: {name}{generics}) nopanic {{ {} }} }} ", members.iter().map(|member| { - format!("traits::Destruct::destruct(self.{member});") - }).join("\n ")} + format!("traits::Destruct::destruct(self.{member});") + }).join("\n "), + generics = format_generics(type_generics, other_generics), + generics_impl = format_generics_with_trait(type_generics, other_generics, + |t| format!("impl {t}Destruct: Destruct<{t}>")) + } } ExtraInfo::Extern => unreachable!(), } @@ -238,23 +295,30 @@ fn get_partial_eq_impl(name: &str, extra_info: &ExtraInfo) -> String { ) }).join("\n ")} } - ExtraInfo::Struct(members) => { + ExtraInfo::Struct { members, type_generics, other_generics } => { formatdoc! {" - impl {name}PartialEq of PartialEq::<{name}> {{ + impl {name}PartialEq{generics_impl} of PartialEq::<{name}{generics}> {{ #[inline(always)] - fn eq(lhs: {name}, rhs: {name}) -> bool {{ + fn eq(lhs: {name}{generics}, rhs: {name}{generics}) -> bool {{ {} true }} #[inline(always)] - fn ne(lhs: {name}, rhs: {name}) -> bool {{ + fn ne(lhs: {name}{generics}, rhs: {name}{generics}) -> bool {{ !(lhs == rhs) }} }} ", members.iter().map(|member| { - // TODO(orizi): Use `&&` when supported. - format!("if lhs.{member} != rhs.{member} {{ return false; }}") - }).join("\n ")} + // TODO(orizi): Use `&&` when supported. + format!("if lhs.{member} != rhs.{member} {{ return false; }}") + }).join("\n "), + generics = format_generics(type_generics, other_generics), + // TODO(spapini): Remove the Destruct requirement by changing + // member borrowing logic to recognize snapshots. + generics_impl = format_generics_with_trait(type_generics, other_generics, + |t| format!("impl {t}PartialEq: PartialEq<{t}>, \ + impl {t}Destruct: Destruct<{t}>")) + } } ExtraInfo::Extern => unreachable!(), } @@ -281,7 +345,8 @@ fn get_serde_impl(name: &str, extra_info: &ExtraInfo) -> String { ", variants.iter().enumerate().map(|(idx, variant)| { format!( - "{name}::{variant}(x) => {{ serde::Serde::serialize(@{idx}, ref output); serde::Serde::serialize(x, ref output); }},", + "{name}::{variant}(x) => {{ serde::Serde::serialize(@{idx}, ref output); \ + serde::Serde::serialize(x, ref output); }},", ) }).join("\n "), variants.iter().enumerate().map(|(idx, variant)| { @@ -291,13 +356,13 @@ fn get_serde_impl(name: &str, extra_info: &ExtraInfo) -> String { }).join("\n else "), } } - ExtraInfo::Struct(members) => { + ExtraInfo::Struct { members, type_generics, other_generics } => { formatdoc! {" - impl {name}Serde of serde::Serde::<{name}> {{ - fn serialize(self: @{name}, ref output: array::Array) {{ + impl {name}Serde{generics_impl} of serde::Serde::<{name}{generics}> {{ + fn serialize(self: @{name}{generics}, ref output: array::Array) {{ {} }} - fn deserialize(ref serialized: array::Span) -> Option<{name}> {{ + fn deserialize(ref serialized: array::Span) -> Option<{name}{generics}> {{ Option::Some({name} {{ {} }}) @@ -306,12 +371,24 @@ fn get_serde_impl(name: &str, extra_info: &ExtraInfo) -> String { ", members.iter().map(|member| format!("serde::Serde::serialize(self.{member}, ref output)")).join(";\n "), members.iter().map(|member| format!("{member}: serde::Serde::deserialize(ref serialized)?,")).join("\n "), + generics = format_generics(type_generics, other_generics), + generics_impl = format_generics_with_trait(type_generics, other_generics, + |t| format!("impl {t}Serde: serde::Serde<{t}>, impl {t}Destruct: Destruct<{t}>")) } } ExtraInfo::Extern => unreachable!(), } } -fn get_empty_impl(name: &str, derived_trait: &str) -> String { - format!("impl {name}{derived_trait} of {derived_trait}::<{name}>;\n") +fn get_empty_impl(name: &str, derived_trait: &str, extra_info: &ExtraInfo) -> String { + match extra_info { + ExtraInfo::Struct { type_generics, other_generics, .. } => format!( + "impl {name}{derived_trait}{generics_impl} of {derived_trait}::<{name}{generics}>;\n", + generics = format_generics(type_generics, other_generics), + generics_impl = format_generics_with_trait(type_generics, other_generics, |t| format!( + "impl {t}{derived_trait}: {derived_trait}<{t}>" + )) + ), + _ => format!("impl {name}{derived_trait} of {derived_trait}::<{name}>;\n"), + } } diff --git a/crates/cairo-lang-plugins/src/test_data/derive b/crates/cairo-lang-plugins/src/test_data/derive index 15e6d7eda44..f49a08c7d1c 100644 --- a/crates/cairo-lang-plugins/src/test_data/derive +++ b/crates/cairo-lang-plugins/src/test_data/derive @@ -16,6 +16,20 @@ struct TwoMemberStruct { b: B, } +#[derive(Copy, Destruct)] +struct GenericStruct { + a: T, +} + + +trait SomeTrait {} + +#[derive(Drop, Clone, PartialEq, Serde)] +struct TwoMemberGenericStruct> { + a: T, + b: U, +} + #[derive(Clone, Destruct, PartialEq, Serde)] enum TwoVariantEnum { First: A, @@ -29,15 +43,15 @@ extern type ExternType; #[derive(Copy, Drop)] struct A{} -impl ACopy of Copy::; -impl ADrop of Drop::; +impl ACopy<> of Copy::>; +impl ADrop<> of Drop::>; #[derive(Copy, Drop)] struct B{} -impl BCopy of Copy::; -impl BDrop of Drop::; +impl BCopy<> of Copy::>; +impl BDrop<> of Drop::>; #[derive(Clone, Destruct, PartialEq, Serde)] @@ -46,38 +60,38 @@ struct TwoMemberStruct { b: B, } -impl TwoMemberStructClone of Clone:: { - fn clone(self: @TwoMemberStruct) -> TwoMemberStruct { +impl TwoMemberStructClone<> of Clone::> { + fn clone(self: @TwoMemberStruct<>) -> TwoMemberStruct<> { TwoMemberStruct { a: self.a.clone(), b: self.b.clone(), } } } -impl TwoMemberStructDestruct of Destruct:: { - fn destruct(self: TwoMemberStruct) nopanic { +impl TwoMemberStructDestruct<> of Destruct::> { + fn destruct(self: TwoMemberStruct<>) nopanic { traits::Destruct::destruct(self.a); traits::Destruct::destruct(self.b); } } -impl TwoMemberStructPartialEq of PartialEq:: { +impl TwoMemberStructPartialEq<> of PartialEq::> { #[inline(always)] - fn eq(lhs: TwoMemberStruct, rhs: TwoMemberStruct) -> bool { + fn eq(lhs: TwoMemberStruct<>, rhs: TwoMemberStruct<>) -> bool { if lhs.a != rhs.a { return false; } if lhs.b != rhs.b { return false; } true } #[inline(always)] - fn ne(lhs: TwoMemberStruct, rhs: TwoMemberStruct) -> bool { + fn ne(lhs: TwoMemberStruct<>, rhs: TwoMemberStruct<>) -> bool { !(lhs == rhs) } } -impl TwoMemberStructSerde of serde::Serde:: { - fn serialize(self: @TwoMemberStruct, ref output: array::Array) { +impl TwoMemberStructSerde<> of serde::Serde::> { + fn serialize(self: @TwoMemberStruct<>, ref output: array::Array) { serde::Serde::serialize(self.a, ref output); serde::Serde::serialize(self.b, ref output) } - fn deserialize(ref serialized: array::Span) -> Option { + fn deserialize(ref serialized: array::Span) -> Option> { Option::Some(TwoMemberStruct { a: serde::Serde::deserialize(ref serialized)?, b: serde::Serde::deserialize(ref serialized)?, @@ -86,6 +100,64 @@ impl TwoMemberStructSerde of serde::Serde:: { } +#[derive(Copy, Destruct)] +struct GenericStruct { + a: T, +} + +impl GenericStructCopy> of Copy::>; +impl GenericStructDestruct> of Destruct::> { + fn destruct(self: GenericStruct) nopanic { + traits::Destruct::destruct(self.a); + } +} + + + +trait SomeTrait {} + + +#[derive(Drop, Clone, PartialEq, Serde)] +struct TwoMemberGenericStruct> { + a: T, + b: U, +} + +impl TwoMemberGenericStructDrop, impl TDrop: Drop, impl UDrop: Drop> of Drop::, >>; +impl TwoMemberGenericStructClone, impl TClone: Clone, impl TDestruct: Destruct, impl UClone: Clone, impl UDestruct: Destruct> of Clone::, >> { + fn clone(self: @TwoMemberGenericStruct, >) -> TwoMemberGenericStruct, > { + TwoMemberGenericStruct { + a: self.a.clone(), + b: self.b.clone(), + } + } +} +impl TwoMemberGenericStructPartialEq, impl TPartialEq: PartialEq, impl TDestruct: Destruct, impl UPartialEq: PartialEq, impl UDestruct: Destruct> of PartialEq::, >> { + #[inline(always)] + fn eq(lhs: TwoMemberGenericStruct, >, rhs: TwoMemberGenericStruct, >) -> bool { + if lhs.a != rhs.a { return false; } + if lhs.b != rhs.b { return false; } + true + } + #[inline(always)] + fn ne(lhs: TwoMemberGenericStruct, >, rhs: TwoMemberGenericStruct, >) -> bool { + !(lhs == rhs) + } +} +impl TwoMemberGenericStructSerde, impl TSerde: serde::Serde, impl TDestruct: Destruct, impl USerde: serde::Serde, impl UDestruct: Destruct> of serde::Serde::, >> { + fn serialize(self: @TwoMemberGenericStruct, >, ref output: array::Array) { + serde::Serde::serialize(self.a, ref output); + serde::Serde::serialize(self.b, ref output) + } + fn deserialize(ref serialized: array::Span) -> Option, >> { + Option::Some(TwoMemberGenericStruct { + a: serde::Serde::deserialize(ref serialized)?, + b: serde::Serde::deserialize(ref serialized)?, + }) + } +} + + #[derive(Clone, Destruct, PartialEq, Serde)] enum TwoVariantEnum { First: A, diff --git a/tests/bug_samples/issue2964.cairo b/tests/bug_samples/issue2964.cairo new file mode 100644 index 00000000000..4831b9eb1b1 --- /dev/null +++ b/tests/bug_samples/issue2964.cairo @@ -0,0 +1,30 @@ +use serde::Serde; +use clone::Clone; +use array::ArrayTrait; +use option::OptionTrait; + +#[derive(Copy, Drop, Serde, PartialEq)] +struct SimpleStruct { + x: felt252, + y: felt252, +} + +#[derive(Copy, Clone, Destruct, Serde, PartialEq)] +struct GenericStruct { + x: T, + y: U, +} + +#[test] +fn main() { + // This assumes that Drop implies Destruct and Copy implies Clone + let mut a = GenericStruct { x: SimpleStruct { x: 1, y: 2 }, y: SimpleStruct { x: 1, y: 2 } }; + a.x.x = 34; + a.y.y = 5; + let mut serialized = ArrayTrait::::new(); + a.serialize(ref serialized); + let mut as_span = serialized.span(); + let deserialized = serde::Serde::>::deserialize(ref as_span).unwrap(); + assert(a == deserialized, 'Bad Serde'); +} diff --git a/tests/bug_samples/lib.cairo b/tests/bug_samples/lib.cairo index 8cd37629162..3ed64755f4b 100644 --- a/tests/bug_samples/lib.cairo +++ b/tests/bug_samples/lib.cairo @@ -14,6 +14,7 @@ mod issue2820; mod issue2932; mod issue2939; mod issue2961; +mod issue2964; mod loop_only_change; mod inconsistent_gas; mod partial_param_local;