Skip to content

Commit

Permalink
Add custom derives for specific generated types (#520)
Browse files Browse the repository at this point in the history
* WIP implement custom derives per type

* WIP wiring up specific type derives

* Fmt

* Rename GeneratedTypeDerives to Derives

* Fmt

* Fix errors

* Fix test runtime

* Make derives appear in alphabetic order

* Clippy

* Add derive_for_type attribute to example

* Add docs to example

* Rename GeneratedTypeDerive

* Rename ty to type in attribute

* Update darling

* Update codegen/src/types/derives.rs

Co-authored-by: Alexandru Vasile <[email protected]>

* Update codegen/src/types/mod.rs

Co-authored-by: Alexandru Vasile <[email protected]>

* Update codegen/src/types/mod.rs

Co-authored-by: Alexandru Vasile <[email protected]>

* review: update method name

* Add unit tests for combined derives

* Remove out of date docs

* Add macro usage docs

Co-authored-by: Alexandru Vasile <[email protected]>
  • Loading branch information
ascjones and lexnv authored Apr 28, 2022
1 parent 1fd1eee commit 24317b4
Show file tree
Hide file tree
Showing 13 changed files with 372 additions and 118 deletions.
6 changes: 3 additions & 3 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use std::{
path::PathBuf,
};
use structopt::StructOpt;
use subxt_codegen::GeneratedTypeDerives;
use subxt_codegen::DerivesRegistry;
use subxt_metadata::{
get_metadata_hash,
get_pallet_hash,
Expand Down Expand Up @@ -288,8 +288,8 @@ fn codegen<I: Input>(
.iter()
.map(|raw| syn::parse_str(raw))
.collect::<Result<Vec<_>, _>>()?;
let mut derives = GeneratedTypeDerives::default();
derives.append(p.into_iter());
let mut derives = DerivesRegistry::default();
derives.extend_for_all(p.into_iter());

let runtime_api = generator.generate_runtime(item_mod, derives);
println!("{}", runtime_api);
Expand Down
2 changes: 1 addition & 1 deletion codegen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ description = "Generate an API for interacting with a substrate node from FRAME
[dependencies]
async-trait = "0.1.49"
codec = { package = "parity-scale-codec", version = "3.0.0", default-features = false, features = ["derive", "full", "bit-vec"] }
darling = "0.13.0"
darling = "0.14.0"
frame-metadata = "15.0.0"
heck = "0.4.0"
proc-macro2 = "1.0.24"
Expand Down
35 changes: 7 additions & 28 deletions codegen/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,6 @@
// along with subxt. If not, see <http://www.gnu.org/licenses/>.

//! Generate code for submitting extrinsics and query storage of a Substrate runtime.
//!
//! ## Note
//!
//! By default the codegen will search for the `System` pallet's `Account` storage item, which is
//! the conventional location where an account's index (aka nonce) is stored.
//!
//! If this `System::Account` storage item is discovered, then it is assumed that:
//!
//! 1. The type of the storage item is a `struct` (aka a composite type)
//! 2. There exists a field called `nonce` which contains the account index.
//!
//! These assumptions are based on the fact that the `frame_system::AccountInfo` type is the default
//! configured type, and that the vast majority of chain configurations will use this.
//!
//! If either of these conditions are not satisfied, the codegen will fail.
mod calls;
mod constants;
Expand All @@ -39,7 +24,7 @@ mod storage;

use subxt_metadata::get_metadata_per_pallet_hash;

use super::GeneratedTypeDerives;
use super::DerivesRegistry;
use crate::{
ir,
types::{
Expand Down Expand Up @@ -68,15 +53,12 @@ use std::{
path,
string::ToString,
};
use syn::{
parse_quote,
punctuated::Punctuated,
};
use syn::parse_quote;

pub fn generate_runtime_api<P>(
item_mod: syn::ItemMod,
path: P,
generated_type_derives: Option<Punctuated<syn::Path, syn::Token![,]>>,
derives: DerivesRegistry,
) -> TokenStream2
where
P: AsRef<path::Path>,
Expand All @@ -92,11 +74,6 @@ where
let metadata = frame_metadata::RuntimeMetadataPrefixed::decode(&mut &bytes[..])
.unwrap_or_else(|e| abort_call_site!("Failed to decode metadata: {}", e));

let mut derives = GeneratedTypeDerives::default();
if let Some(user_derives) = generated_type_derives {
derives.append(user_derives.iter().cloned())
}

let generator = RuntimeGenerator::new(metadata);
generator.generate_runtime(item_mod, derives)
}
Expand All @@ -116,9 +93,10 @@ impl RuntimeGenerator {
pub fn generate_runtime(
&self,
item_mod: syn::ItemMod,
derives: GeneratedTypeDerives,
derives: DerivesRegistry,
) -> TokenStream2 {
let item_mod_ir = ir::ItemMod::from(item_mod);
let default_derives = derives.default_derives();

// some hardcoded default type substitutes, can be overridden by user
let mut type_substitutes = [
Expand Down Expand Up @@ -265,7 +243,7 @@ impl RuntimeGenerator {
});

let outer_event = quote! {
#derives
#default_derives
pub enum Event {
#( #outer_event_variants )*
}
Expand Down Expand Up @@ -448,6 +426,7 @@ where
type_gen,
);
let struct_def = CompositeDef::struct_def(
&ty,
struct_name.as_ref(),
Default::default(),
fields,
Expand Down
3 changes: 2 additions & 1 deletion codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ pub use self::{
RuntimeGenerator,
},
types::{
GeneratedTypeDerives,
Derives,
DerivesRegistry,
Module,
TypeGenerator,
},
Expand Down
11 changes: 7 additions & 4 deletions codegen/src/types/composite_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
// along with subxt. If not, see <http://www.gnu.org/licenses/>.

use super::{
Derives,
Field,
GeneratedTypeDerives,
TypeDefParameters,
TypeGenerator,
TypeParameter,
Expand All @@ -29,6 +29,8 @@ use quote::{
quote,
};
use scale_info::{
form::PortableForm,
Type,
TypeDef,
TypeDefPrimitive,
};
Expand All @@ -52,14 +54,15 @@ pub struct CompositeDef {
impl CompositeDef {
/// Construct a definition which will generate code for a standalone `struct`.
pub fn struct_def(
ty: &Type<PortableForm>,
ident: &str,
type_params: TypeDefParameters,
fields_def: CompositeDefFields,
field_visibility: Option<syn::Visibility>,
type_gen: &TypeGenerator,
docs: &[String],
) -> Self {
let mut derives = type_gen.derives().clone();
let mut derives = type_gen.type_derives(ty);
let fields: Vec<_> = fields_def.field_types().collect();

if fields.len() == 1 {
Expand All @@ -82,7 +85,7 @@ impl CompositeDef {
| TypeDefPrimitive::U128
)
) {
derives.push_codec_compact_as()
derives.insert_codec_compact_as()
}
}
}
Expand Down Expand Up @@ -165,7 +168,7 @@ impl quote::ToTokens for CompositeDef {
pub enum CompositeDefKind {
/// Composite type comprising a Rust `struct`.
Struct {
derives: GeneratedTypeDerives,
derives: Derives,
type_params: TypeDefParameters,
field_visibility: Option<syn::Visibility>,
},
Expand Down
93 changes: 76 additions & 17 deletions codegen/src/types/derives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,107 @@
use syn::{
parse_quote,
punctuated::Punctuated,
Path,
};

use std::collections::{
HashMap,
HashSet,
};

#[derive(Debug, Default, Clone)]
pub struct DerivesRegistry {
default_derives: Derives,
specific_type_derives: HashMap<syn::TypePath, Derives>,
}

impl DerivesRegistry {
/// Insert derives to be applied to all generated types.
pub fn extend_for_all(&mut self, derives: impl IntoIterator<Item = syn::Path>) {
self.default_derives.derives.extend(derives)
}

/// Insert derives to be applied to a specific generated type.
pub fn extend_for_type(
&mut self,
ty: syn::TypePath,
derives: impl IntoIterator<Item = syn::Path>,
) {
let type_derives = self
.specific_type_derives
.entry(ty)
.or_insert_with(Derives::default);
type_derives.derives.extend(derives)
}

/// Returns the derives to be applied to all generated types.
pub fn default_derives(&self) -> &Derives {
&self.default_derives
}

/// Resolve the derives for a generated type. Includes:
/// - The default derives for all types e.g. `scale::Encode, scale::Decode`
/// - Any user-defined derives for all types via `generated_type_derives`
/// - Any user-defined derives for this specific type
pub fn resolve(&self, ty: &syn::TypePath) -> Derives {
let mut defaults = self.default_derives.derives.clone();
if let Some(specific) = self.specific_type_derives.get(ty) {
defaults.extend(specific.derives.iter().cloned());
}
Derives { derives: defaults }
}
}

#[derive(Debug, Clone)]
pub struct GeneratedTypeDerives {
derives: Punctuated<syn::Path, syn::Token![,]>,
pub struct Derives {
derives: HashSet<syn::Path>,
}

impl GeneratedTypeDerives {
pub fn new(derives: Punctuated<syn::Path, syn::Token!(,)>) -> Self {
impl FromIterator<syn::Path> for Derives {
fn from_iter<T: IntoIterator<Item = Path>>(iter: T) -> Self {
let derives = iter.into_iter().collect();
Self { derives }
}
}

impl Derives {
/// Add `::subxt::codec::CompactAs` to the derives.
pub fn push_codec_compact_as(&mut self) {
self.derives.push(parse_quote!(::subxt::codec::CompactAs));
pub fn insert_codec_compact_as(&mut self) {
self.insert(parse_quote!(::subxt::codec::CompactAs));
}

pub fn append(&mut self, derives: impl Iterator<Item = syn::Path>) {
for derive in derives {
self.derives.push(derive)
self.insert(derive)
}
}

pub fn push(&mut self, derive: syn::Path) {
self.derives.push(derive);
pub fn insert(&mut self, derive: syn::Path) {
self.derives.insert(derive);
}
}

impl Default for GeneratedTypeDerives {
impl Default for Derives {
fn default() -> Self {
let mut derives = Punctuated::new();
derives.push(syn::parse_quote!(::subxt::codec::Encode));
derives.push(syn::parse_quote!(::subxt::codec::Decode));
derives.push(syn::parse_quote!(Debug));
Self::new(derives)
let mut derives = HashSet::new();
derives.insert(syn::parse_quote!(::subxt::codec::Encode));
derives.insert(syn::parse_quote!(::subxt::codec::Decode));
derives.insert(syn::parse_quote!(Debug));
Self { derives }
}
}

impl quote::ToTokens for GeneratedTypeDerives {
impl quote::ToTokens for Derives {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
if !self.derives.is_empty() {
let derives = &self.derives;
let mut sorted = self.derives.iter().cloned().collect::<Vec<_>>();
sorted.sort_by(|a, b| {
quote::quote!(#a)
.to_string()
.cmp(&quote::quote!(#b).to_string())
});
let derives: Punctuated<syn::Path, syn::Token![,]> =
sorted.iter().cloned().collect();
tokens.extend(quote::quote! {
#[derive(#derives)]
})
Expand Down
Loading

0 comments on commit 24317b4

Please sign in to comment.