Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support generic types in Derive statements #2964

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 106 additions & 29 deletions crates/cairo-lang-plugins/src/plugins/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use cairo_lang_semantic::plugin::{AsDynMacroPlugin, SemanticPlugin, TrivialPlugi
use cairo_lang_syntax::attribute::structured::{
AttributeArg, AttributeArgVariant, AttributeStructurize,
};
use cairo_lang_syntax::node::ast::{AttributeList, MemberList};
use cairo_lang_syntax::node::ast::{
AttributeList, ItemStruct, MemberList, OptionWrappedGenericParamList,
};
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::helpers::QueryAttrs;
use cairo_lang_syntax::node::{ast, Terminal};
use cairo_lang_syntax::node::{ast, Terminal, TypedSyntaxNode};
use indoc::formatdoc;
use itertools::Itertools;
use smol_str::SmolStr;
Expand All @@ -26,7 +28,7 @@ impl MacroPlugin for DerivePlugin {
db,
struct_ast.name(db),
struct_ast.attributes(db),
ExtraInfo::Struct(member_names(db, struct_ast.members(db))),
extract_struct_extra_info(db, &struct_ast),
),
ast::Item::Enum(enum_ast) => generate_derive_code_for_type(
db,
Expand Down Expand Up @@ -56,14 +58,61 @@ impl SemanticPlugin for DerivePlugin {}

enum ExtraInfo {
Enum(Vec<SmolStr>),
Struct(Vec<SmolStr>),
Struct { members: Vec<SmolStr>, type_generics: Vec<SmolStr>, other_generics: Vec<String> },
Extern,
}

fn member_names(db: &dyn SyntaxGroup, members: MemberList) -> Vec<SmolStr> {
members.elements(db).into_iter().map(|member| member.name(db).text(db)).collect()
}

fn extract_struct_extra_info(db: &dyn SyntaxGroup, struct_ast: &ItemStruct) -> ExtraInfo {
let members = member_names(db, struct_ast.members(db));
let mut type_generics = vec![];
let mut other_generics = vec![];
match struct_ast.generic_params(db) {
OptionWrappedGenericParamList::WrappedGenericParamList(gens) => gens
.generic_params(db)
.elements(db)
.into_iter()
.map(|member| match member {
ast::GenericParam::Type(t) => {
type_generics.push(t.name(db).text(db));
}
ast::GenericParam::Impl(i) => {
other_generics.push(i.as_syntax_node().get_text_without_trivia(db))
}
ast::GenericParam::Const(c) => {
other_generics.push(c.as_syntax_node().get_text_without_trivia(db))
}
})
.collect(),
OptionWrappedGenericParamList::Empty(_) => vec![],
};
ExtraInfo::Struct { members, type_generics, other_generics }
}

fn format_generics_with_trait(
type_generics: &[SmolStr],
other_generics: &[String],
f: impl Fn(&SmolStr) -> String,
) -> String {
format!(
"<{}{}{}>",
type_generics.iter().map(|s| format!("{}, ", s)).collect::<String>(),
other_generics.iter().map(|s| format!("{}, ", s)).collect::<String>(),
type_generics.iter().map(f).join(", "),
)
}

fn format_generics(type_generics: &[SmolStr], other_generics: &[String]) -> String {
format!(
"<{}{}>",
type_generics.iter().map(|s| format!("{}, ", s)).collect::<String>(),
other_generics.iter().map(|s| format!("{}, ", s)).collect::<String>(),
)
}

/// Adds an implementation for all requested derives for the type.
fn generate_derive_code_for_type(
db: &dyn SyntaxGroup,
Expand Down Expand Up @@ -107,7 +156,7 @@ fn generate_derive_code_for_type(
let name = ident.text(db);
let derived = segment.ident(db).text(db);
match derived.as_str() {
"Copy" | "Drop" => impls.push(get_empty_impl(&name, &derived)),
"Copy" | "Drop" => impls.push(get_empty_impl(&name, &derived, &extra_info)),
"Clone" if !matches!(extra_info, ExtraInfo::Extern) => {
impls.push(get_clone_impl(&name, &extra_info))
}
Expand Down Expand Up @@ -163,18 +212,22 @@ fn get_clone_impl(name: &str, extra_info: &ExtraInfo) -> String {
format!("{name}::{variant}(x) => {name}::{variant}(x.clone()),")
}).join("\n ")}
}
ExtraInfo::Struct(members) => {
ExtraInfo::Struct { members, type_generics, other_generics } => {
formatdoc! {"
impl {name}Clone of Clone::<{name}> {{
fn clone(self: @{name}) -> {name} {{
impl {name}Clone{generics_impl} of Clone::<{name}{generics}> {{
fn clone(self: @{name}{generics}) -> {name}{generics} {{
{name} {{
{}
}}
}}
}}
", members.iter().map(|member| {
format!("{member}: self.{member}.clone(),")
}).join("\n ")}
format!("{member}: self.{member}.clone(),")
}).join("\n "),
generics = format_generics(type_generics, other_generics),
generics_impl = format_generics_with_trait(type_generics, other_generics,
|t| format!("impl {t}Clone: Clone<{t}>, impl {t}Destruct: Destruct<{t}>"))
}
}
ExtraInfo::Extern => unreachable!(),
}
Expand All @@ -195,16 +248,20 @@ fn get_destruct_impl(name: &str, extra_info: &ExtraInfo) -> String {
format!("{name}::{variant}(x) => traits::Destruct::destruct(x),")
}).join("\n ")}
}
ExtraInfo::Struct(members) => {
ExtraInfo::Struct { members, type_generics, other_generics } => {
formatdoc! {"
impl {name}Destruct of Destruct::<{name}> {{
fn destruct(self: {name}) nopanic {{
impl {name}Destruct{generics_impl} of Destruct::<{name}{generics}> {{
fn destruct(self: {name}{generics}) nopanic {{
{}
}}
}}
", members.iter().map(|member| {
format!("traits::Destruct::destruct(self.{member});")
}).join("\n ")}
format!("traits::Destruct::destruct(self.{member});")
}).join("\n "),
generics = format_generics(type_generics, other_generics),
generics_impl = format_generics_with_trait(type_generics, other_generics,
|t| format!("impl {t}Destruct: Destruct<{t}>"))
}
}
ExtraInfo::Extern => unreachable!(),
}
Expand Down Expand Up @@ -238,23 +295,30 @@ fn get_partial_eq_impl(name: &str, extra_info: &ExtraInfo) -> String {
)
}).join("\n ")}
}
ExtraInfo::Struct(members) => {
ExtraInfo::Struct { members, type_generics, other_generics } => {
formatdoc! {"
impl {name}PartialEq of PartialEq::<{name}> {{
impl {name}PartialEq{generics_impl} of PartialEq::<{name}{generics}> {{
#[inline(always)]
fn eq(lhs: {name}, rhs: {name}) -> bool {{
fn eq(lhs: {name}{generics}, rhs: {name}{generics}) -> bool {{
{}
true
}}
#[inline(always)]
fn ne(lhs: {name}, rhs: {name}) -> bool {{
fn ne(lhs: {name}{generics}, rhs: {name}{generics}) -> bool {{
!(lhs == rhs)
}}
}}
", members.iter().map(|member| {
// TODO(orizi): Use `&&` when supported.
format!("if lhs.{member} != rhs.{member} {{ return false; }}")
}).join("\n ")}
// TODO(orizi): Use `&&` when supported.
format!("if lhs.{member} != rhs.{member} {{ return false; }}")
}).join("\n "),
generics = format_generics(type_generics, other_generics),
// TODO(spapini): Remove the Destruct requirement by changing
// member borrowing logic to recognize snapshots.
generics_impl = format_generics_with_trait(type_generics, other_generics,
|t| format!("impl {t}PartialEq: PartialEq<{t}>, \
impl {t}Destruct: Destruct<{t}>"))
}
}
ExtraInfo::Extern => unreachable!(),
}
Expand All @@ -281,7 +345,8 @@ fn get_serde_impl(name: &str, extra_info: &ExtraInfo) -> String {
",
variants.iter().enumerate().map(|(idx, variant)| {
format!(
"{name}::{variant}(x) => {{ serde::Serde::serialize(@{idx}, ref output); serde::Serde::serialize(x, ref output); }},",
"{name}::{variant}(x) => {{ serde::Serde::serialize(@{idx}, ref output); \
serde::Serde::serialize(x, ref output); }},",
)
}).join("\n "),
variants.iter().enumerate().map(|(idx, variant)| {
Expand All @@ -291,13 +356,13 @@ fn get_serde_impl(name: &str, extra_info: &ExtraInfo) -> String {
}).join("\n else "),
}
}
ExtraInfo::Struct(members) => {
ExtraInfo::Struct { members, type_generics, other_generics } => {
formatdoc! {"
impl {name}Serde of serde::Serde::<{name}> {{
fn serialize(self: @{name}, ref output: array::Array<felt252>) {{
impl {name}Serde{generics_impl} of serde::Serde::<{name}{generics}> {{
fn serialize(self: @{name}{generics}, ref output: array::Array<felt252>) {{
{}
}}
fn deserialize(ref serialized: array::Span<felt252>) -> Option<{name}> {{
fn deserialize(ref serialized: array::Span<felt252>) -> Option<{name}{generics}> {{
Option::Some({name} {{
{}
}})
Expand All @@ -306,12 +371,24 @@ fn get_serde_impl(name: &str, extra_info: &ExtraInfo) -> String {
",
members.iter().map(|member| format!("serde::Serde::serialize(self.{member}, ref output)")).join(";\n "),
members.iter().map(|member| format!("{member}: serde::Serde::deserialize(ref serialized)?,")).join("\n "),
generics = format_generics(type_generics, other_generics),
generics_impl = format_generics_with_trait(type_generics, other_generics,
|t| format!("impl {t}Serde: serde::Serde<{t}>, impl {t}Destruct: Destruct<{t}>"))
}
}
ExtraInfo::Extern => unreachable!(),
}
}

fn get_empty_impl(name: &str, derived_trait: &str) -> String {
format!("impl {name}{derived_trait} of {derived_trait}::<{name}>;\n")
fn get_empty_impl(name: &str, derived_trait: &str, extra_info: &ExtraInfo) -> String {
match extra_info {
ExtraInfo::Struct { type_generics, other_generics, .. } => format!(
"impl {name}{derived_trait}{generics_impl} of {derived_trait}::<{name}{generics}>;\n",
generics = format_generics(type_generics, other_generics),
generics_impl = format_generics_with_trait(type_generics, other_generics, |t| format!(
"impl {t}{derived_trait}: {derived_trait}<{t}>"
))
),
_ => format!("impl {name}{derived_trait} of {derived_trait}::<{name}>;\n"),
}
}
100 changes: 86 additions & 14 deletions crates/cairo-lang-plugins/src/test_data/derive
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@ struct TwoMemberStruct {
b: B,
}

#[derive(Copy, Destruct)]
struct GenericStruct<T> {
a: T,
}


trait SomeTrait<T, U> {}

#[derive(Drop, Clone, PartialEq, Serde)]
struct TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>> {
a: T,
b: U,
}

#[derive(Clone, Destruct, PartialEq, Serde)]
enum TwoVariantEnum {
First: A,
Expand All @@ -29,15 +43,15 @@ extern type ExternType;
#[derive(Copy, Drop)]
struct A{}

impl ACopy of Copy::<A>;
impl ADrop of Drop::<A>;
impl ACopy<> of Copy::<A<>>;
impl ADrop<> of Drop::<A<>>;


#[derive(Copy, Drop)]
struct B{}

impl BCopy of Copy::<B>;
impl BDrop of Drop::<B>;
impl BCopy<> of Copy::<B<>>;
impl BDrop<> of Drop::<B<>>;


#[derive(Clone, Destruct, PartialEq, Serde)]
Expand All @@ -46,38 +60,38 @@ struct TwoMemberStruct {
b: B,
}

impl TwoMemberStructClone of Clone::<TwoMemberStruct> {
fn clone(self: @TwoMemberStruct) -> TwoMemberStruct {
impl TwoMemberStructClone<> of Clone::<TwoMemberStruct<>> {
fn clone(self: @TwoMemberStruct<>) -> TwoMemberStruct<> {
TwoMemberStruct {
a: self.a.clone(),
b: self.b.clone(),
}
}
}
impl TwoMemberStructDestruct of Destruct::<TwoMemberStruct> {
fn destruct(self: TwoMemberStruct) nopanic {
impl TwoMemberStructDestruct<> of Destruct::<TwoMemberStruct<>> {
fn destruct(self: TwoMemberStruct<>) nopanic {
traits::Destruct::destruct(self.a);
traits::Destruct::destruct(self.b);
}
}
impl TwoMemberStructPartialEq of PartialEq::<TwoMemberStruct> {
impl TwoMemberStructPartialEq<> of PartialEq::<TwoMemberStruct<>> {
#[inline(always)]
fn eq(lhs: TwoMemberStruct, rhs: TwoMemberStruct) -> bool {
fn eq(lhs: TwoMemberStruct<>, rhs: TwoMemberStruct<>) -> bool {
if lhs.a != rhs.a { return false; }
if lhs.b != rhs.b { return false; }
true
}
#[inline(always)]
fn ne(lhs: TwoMemberStruct, rhs: TwoMemberStruct) -> bool {
fn ne(lhs: TwoMemberStruct<>, rhs: TwoMemberStruct<>) -> bool {
!(lhs == rhs)
}
}
impl TwoMemberStructSerde of serde::Serde::<TwoMemberStruct> {
fn serialize(self: @TwoMemberStruct, ref output: array::Array<felt252>) {
impl TwoMemberStructSerde<> of serde::Serde::<TwoMemberStruct<>> {
fn serialize(self: @TwoMemberStruct<>, ref output: array::Array<felt252>) {
serde::Serde::serialize(self.a, ref output);
serde::Serde::serialize(self.b, ref output)
}
fn deserialize(ref serialized: array::Span<felt252>) -> Option<TwoMemberStruct> {
fn deserialize(ref serialized: array::Span<felt252>) -> Option<TwoMemberStruct<>> {
Option::Some(TwoMemberStruct {
a: serde::Serde::deserialize(ref serialized)?,
b: serde::Serde::deserialize(ref serialized)?,
Expand All @@ -86,6 +100,64 @@ impl TwoMemberStructSerde of serde::Serde::<TwoMemberStruct> {
}


#[derive(Copy, Destruct)]
struct GenericStruct<T> {
a: T,
}

impl GenericStructCopy<T, impl TCopy: Copy<T>> of Copy::<GenericStruct<T, >>;
impl GenericStructDestruct<T, impl TDestruct: Destruct<T>> of Destruct::<GenericStruct<T, >> {
fn destruct(self: GenericStruct<T, >) nopanic {
traits::Destruct::destruct(self.a);
}
}



trait SomeTrait<T, U> {}


#[derive(Drop, Clone, PartialEq, Serde)]
struct TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>> {
a: T,
b: U,
}

impl TwoMemberGenericStructDrop<T, U, impl USomeTrait: SomeTrait<U, T>, impl TDrop: Drop<T>, impl UDrop: Drop<U>> of Drop::<TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>, >>;
impl TwoMemberGenericStructClone<T, U, impl USomeTrait: SomeTrait<U, T>, impl TClone: Clone<T>, impl TDestruct: Destruct<T>, impl UClone: Clone<U>, impl UDestruct: Destruct<U>> of Clone::<TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>, >> {
fn clone(self: @TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>, >) -> TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>, > {
TwoMemberGenericStruct {
a: self.a.clone(),
b: self.b.clone(),
}
}
}
impl TwoMemberGenericStructPartialEq<T, U, impl USomeTrait: SomeTrait<U, T>, impl TPartialEq: PartialEq<T>, impl TDestruct: Destruct<T>, impl UPartialEq: PartialEq<U>, impl UDestruct: Destruct<U>> of PartialEq::<TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>, >> {
#[inline(always)]
fn eq(lhs: TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>, >, rhs: TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>, >) -> bool {
if lhs.a != rhs.a { return false; }
if lhs.b != rhs.b { return false; }
true
}
#[inline(always)]
fn ne(lhs: TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>, >, rhs: TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>, >) -> bool {
!(lhs == rhs)
}
}
impl TwoMemberGenericStructSerde<T, U, impl USomeTrait: SomeTrait<U, T>, impl TSerde: serde::Serde<T>, impl TDestruct: Destruct<T>, impl USerde: serde::Serde<U>, impl UDestruct: Destruct<U>> of serde::Serde::<TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>, >> {
fn serialize(self: @TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>, >, ref output: array::Array<felt252>) {
serde::Serde::serialize(self.a, ref output);
serde::Serde::serialize(self.b, ref output)
}
fn deserialize(ref serialized: array::Span<felt252>) -> Option<TwoMemberGenericStruct<T, U, impl USomeTrait: SomeTrait<U, T>, >> {
Option::Some(TwoMemberGenericStruct {
a: serde::Serde::deserialize(ref serialized)?,
b: serde::Serde::deserialize(ref serialized)?,
})
}
}


#[derive(Clone, Destruct, PartialEq, Serde)]
enum TwoVariantEnum {
First: A,
Expand Down
Loading