Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom derives for specific generated types #520

Merged
merged 24 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use std::{
path::PathBuf,
};
use structopt::StructOpt;
use subxt_codegen::GeneratedTypeDerives;
use subxt_codegen::DerivesRegistry;

/// Utilities for working with substrate metadata for subxt.
#[derive(Debug, StructOpt)]
Expand Down Expand Up @@ -163,8 +163,8 @@ fn codegen<I: Input>(
.iter()
.map(|raw| syn::parse_str(raw))
.collect::<Result<Vec<_>, _>>()?;
let mut derives = GeneratedTypeDerives::default();
derives.append(p.into_iter());
let mut derives = DerivesRegistry::default();
derives.extend_for_all(p.into_iter());

let runtime_api = generator.generate_runtime(item_mod, derives);
println!("{}", runtime_api);
Expand Down
20 changes: 7 additions & 13 deletions codegen/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ mod errors;
mod events;
mod storage;

use super::GeneratedTypeDerives;
use super::DerivesRegistry;
use crate::{
ir,
types::{
Expand Down Expand Up @@ -66,15 +66,12 @@ use std::{
path,
string::ToString,
};
use syn::{
parse_quote,
punctuated::Punctuated,
};
use syn::parse_quote;

pub fn generate_runtime_api<P>(
item_mod: syn::ItemMod,
path: P,
generated_type_derives: Option<Punctuated<syn::Path, syn::Token![,]>>,
derives: DerivesRegistry,
) -> TokenStream2
where
P: AsRef<path::Path>,
Expand All @@ -90,11 +87,6 @@ where
let metadata = frame_metadata::RuntimeMetadataPrefixed::decode(&mut &bytes[..])
.unwrap_or_else(|e| abort_call_site!("Failed to decode metadata: {}", e));

let mut derives = GeneratedTypeDerives::default();
if let Some(user_derives) = generated_type_derives {
derives.append(user_derives.iter().cloned())
}

let generator = RuntimeGenerator::new(metadata);
generator.generate_runtime(item_mod, derives)
}
Expand All @@ -114,9 +106,10 @@ impl RuntimeGenerator {
pub fn generate_runtime(
&self,
item_mod: syn::ItemMod,
derives: GeneratedTypeDerives,
derives: DerivesRegistry,
) -> TokenStream2 {
let item_mod_ir = ir::ItemMod::from(item_mod);
let default_derives = derives.default_derives();

// some hardcoded default type substitutes, can be overridden by user
let mut type_substitutes = [
Expand Down Expand Up @@ -237,7 +230,7 @@ impl RuntimeGenerator {
});

let outer_event = quote! {
#derives
#default_derives
pub enum Event {
#( #outer_event_variants )*
}
Expand Down Expand Up @@ -407,6 +400,7 @@ where
type_gen,
);
CompositeDef::struct_def(
&ty,
struct_name.as_ref(),
Default::default(),
fields,
Expand Down
3 changes: 2 additions & 1 deletion codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ pub use self::{
RuntimeGenerator,
},
types::{
GeneratedTypeDerives,
Derives,
DerivesRegistry,
Module,
TypeGenerator,
},
Expand Down
9 changes: 6 additions & 3 deletions codegen/src/types/composite_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
// along with subxt. If not, see <http://www.gnu.org/licenses/>.

use super::{
Derives,
Field,
GeneratedTypeDerives,
TypeDefParameters,
TypeGenerator,
TypeParameter,
Expand All @@ -29,6 +29,8 @@ use quote::{
quote,
};
use scale_info::{
form::PortableForm,
Type,
TypeDef,
TypeDefPrimitive,
};
Expand All @@ -52,14 +54,15 @@ pub struct CompositeDef {
impl CompositeDef {
/// Construct a definition which will generate code for a standalone `struct`.
pub fn struct_def(
ty: &Type<PortableForm>,
ident: &str,
type_params: TypeDefParameters,
fields_def: CompositeDefFields,
field_visibility: Option<syn::Visibility>,
type_gen: &TypeGenerator,
docs: &[String],
) -> Self {
let mut derives = type_gen.derives().clone();
let mut derives = type_gen.type_derives(ty);
let fields: Vec<_> = fields_def.field_types().collect();

if fields.len() == 1 {
Expand Down Expand Up @@ -165,7 +168,7 @@ impl quote::ToTokens for CompositeDef {
pub enum CompositeDefKind {
/// Composite type comprising a Rust `struct`.
Struct {
derives: GeneratedTypeDerives,
derives: Derives,
type_params: TypeDefParameters,
field_visibility: Option<syn::Visibility>,
},
Expand Down
91 changes: 75 additions & 16 deletions codegen/src/types/derives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,107 @@
use syn::{
parse_quote,
punctuated::Punctuated,
Path,
};

use std::collections::{
HashMap,
HashSet,
};

#[derive(Debug, Default, Clone)]
pub struct DerivesRegistry {
default_derives: Derives,
specific_type_derives: HashMap<syn::TypePath, Derives>,
}

impl DerivesRegistry {
/// Insert derives to be applied to all generated types.
pub fn extend_for_all(&mut self, derives: impl Iterator<Item = syn::Path>) {
self.default_derives.derives.extend(derives)
}

/// Insert derives to be applied to a specific generated type.
pub fn extend_for_type(
&mut self,
ty: syn::TypePath,
derives: impl Iterator<Item = syn::Path>,
) {
let type_derives = self
.specific_type_derives
.entry(ty)
.or_insert_with(Derives::default);
type_derives.derives.extend(derives)
}

/// Returns a the derives to be applied to all generated types.
pub fn default_derives(&self) -> &Derives {
&self.default_derives
}

/// Resolve the derives for a generated type. Includes:
/// - The default derives for all types e.g. `scale::Encode, scale::Decode`
/// - Any user-defined derives for all types via `generated_type_derives`
/// - Any user-defined derives for this specific type
pub fn resolve(&self, ty: &syn::TypePath) -> Derives {
let mut defaults = self.default_derives.derives.clone();
if let Some(specific) = self.specific_type_derives.get(ty) {
defaults.extend(specific.derives.iter().cloned());
}
Derives { derives: defaults }
}
}

#[derive(Debug, Clone)]
pub struct GeneratedTypeDerives {
derives: Punctuated<syn::Path, syn::Token![,]>,
pub struct Derives {
derives: HashSet<syn::Path>,
}

impl GeneratedTypeDerives {
pub fn new(derives: Punctuated<syn::Path, syn::Token!(,)>) -> Self {
impl FromIterator<syn::Path> for Derives {
fn from_iter<T: IntoIterator<Item = Path>>(iter: T) -> Self {
let derives = iter.into_iter().collect();
Self { derives }
}
}

impl Derives {
/// Add `::subxt::codec::CompactAs` to the derives.
pub fn push_codec_compact_as(&mut self) {
self.derives.push(parse_quote!(::subxt::codec::CompactAs));
self.insert(parse_quote!(::subxt::codec::CompactAs));
}

pub fn append(&mut self, derives: impl Iterator<Item = syn::Path>) {
for derive in derives {
self.derives.push(derive)
self.insert(derive)
}
}

pub fn push(&mut self, derive: syn::Path) {
self.derives.push(derive);
pub fn insert(&mut self, derive: syn::Path) {
self.derives.insert(derive);
}
}

impl Default for GeneratedTypeDerives {
impl Default for Derives {
fn default() -> Self {
let mut derives = Punctuated::new();
derives.push(syn::parse_quote!(::subxt::codec::Encode));
derives.push(syn::parse_quote!(::subxt::codec::Decode));
derives.push(syn::parse_quote!(Debug));
Self::new(derives)
let mut derives = HashSet::new();
derives.insert(syn::parse_quote!(::subxt::codec::Encode));
derives.insert(syn::parse_quote!(::subxt::codec::Decode));
derives.insert(syn::parse_quote!(Debug));
Self { derives }
}
}

impl quote::ToTokens for GeneratedTypeDerives {
impl quote::ToTokens for Derives {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
if !self.derives.is_empty() {
let derives = &self.derives;
let mut sorted = self.derives.iter().cloned().collect::<Vec<_>>();
sorted.sort_by(|a, b| {
quote::quote!(#a)
.to_string()
.cmp(&quote::quote!(#b).to_string())
});
let derives: Punctuated<syn::Path, syn::Token![,]> =
sorted.iter().cloned().collect();
tokens.extend(quote::quote! {
#[derive(#derives)]
})
Expand Down
39 changes: 26 additions & 13 deletions codegen/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use proc_macro2::{
Span,
TokenStream,
};
use proc_macro_error::abort_call_site;
use quote::{
quote,
ToTokens,
Expand All @@ -48,7 +49,10 @@ pub use self::{
CompositeDefFieldType,
CompositeDefFields,
},
derives::GeneratedTypeDerives,
derives::{
Derives,
DerivesRegistry,
},
type_def::TypeDefGen,
type_def_params::TypeDefParameters,
type_path::{
Expand All @@ -71,7 +75,7 @@ pub struct TypeGenerator<'a> {
/// User defined overrides for generated types.
type_substitutes: HashMap<String, syn::TypePath>,
/// Set of derives with which to annotate generated types.
derives: GeneratedTypeDerives,
derives: DerivesRegistry,
}

impl<'a> TypeGenerator<'a> {
Expand All @@ -80,7 +84,7 @@ impl<'a> TypeGenerator<'a> {
type_registry: &'a PortableRegistry,
root_mod: &'static str,
type_substitutes: HashMap<String, syn::TypePath>,
derives: GeneratedTypeDerives,
derives: DerivesRegistry,
) -> Self {
let root_mod_ident = Ident::new(root_mod, Span::call_site());
Self {
Expand All @@ -92,7 +96,7 @@ impl<'a> TypeGenerator<'a> {
}

/// Generate a module containing all types defined in the supplied type registry.
pub fn generate_types_mod(&'a self) -> Module<'a> {
pub fn generate_types_mod(&self) -> Module {
let mut root_mod =
Module::new(self.types_mod_ident.clone(), self.types_mod_ident.clone());

Expand All @@ -119,7 +123,7 @@ impl<'a> TypeGenerator<'a> {
id: u32,
path: Vec<String>,
root_mod_ident: &Ident,
module: &mut Module<'a>,
module: &mut Module,
) {
let joined_path = path.join("::");
if self.type_substitutes.contains_key(&joined_path) {
Expand Down Expand Up @@ -215,22 +219,31 @@ impl<'a> TypeGenerator<'a> {
}
}

/// Returns the derives with which all generated type will be decorated.
pub fn derives(&self) -> &GeneratedTypeDerives {
&self.derives
/// Returns a the derives to be applied to all generated types.
pub fn default_derives(&self) -> &Derives {
self.derives.default_derives()
}

/// Returns a the derives to be applied to a generated type.
pub fn type_derives(&self, ty: &Type<PortableForm>) -> Derives {
let joined_path = ty.path().segments().join("::");
let ty_path: syn::TypePath = syn::parse_str(&joined_path).unwrap_or_else(|e| {
abort_call_site!("'{}' is an invalid type path: {:?}", joined_path, e,)
});
self.derives.resolve(&ty_path)
}
}

/// Represents a Rust `mod`, containing generated types and child `mod`s.
#[derive(Debug)]
pub struct Module<'a> {
pub struct Module {
name: Ident,
root_mod: Ident,
children: BTreeMap<Ident, Module<'a>>,
types: BTreeMap<scale_info::Path<scale_info::form::PortableForm>, TypeDefGen<'a>>,
children: BTreeMap<Ident, Module>,
types: BTreeMap<scale_info::Path<scale_info::form::PortableForm>, TypeDefGen>,
}

impl<'a> ToTokens for Module<'a> {
impl ToTokens for Module {
fn to_tokens(&self, tokens: &mut TokenStream) {
let name = &self.name;
let root_mod = &self.root_mod;
Expand All @@ -248,7 +261,7 @@ impl<'a> ToTokens for Module<'a> {
}
}

impl<'a> Module<'a> {
impl Module {
/// Create a new [`Module`], with a reference to the root `mod` for resolving type paths.
pub(crate) fn new(name: Ident, root_mod: Ident) -> Self {
Self {
Expand Down
Loading