Skip to content

Commit

Permalink
WIP use bounds attribute for skipping type params
Browse files Browse the repository at this point in the history
  • Loading branch information
ascjones committed Jun 14, 2021
1 parent 3678114 commit 4cdf0c5
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 67 deletions.
26 changes: 10 additions & 16 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,34 +63,30 @@ fn generate_type(input: TokenStream2) -> Result<TokenStream2> {

utils::check_attributes(&ast)?;

utils::check_attributes(&ast)?;

let scale_info = crate_name_ident("scale-info")?;
let parity_scale_codec = crate_name_ident("parity-scale-codec")?;

let ident = &ast.ident;

let type_params: Vec<_> =
if let Some(skip_type_params) = utils::skipped_type_params(&ast.attrs) {
ast.generics
.type_params()
.filter(|tp| !skip_type_params.iter().any(|skip| skip.ident == tp.ident))
.cloned()
.collect()
} else {
ast.generics.type_params().into_iter().cloned().collect()
};
let mut type_params = ast.generics.type_params().into_iter().cloned().collect::<Vec<_>>();

let where_clause = if let Some(custom_bounds) = utils::custom_trait_bounds(&ast.attrs)
{
// remove type params which are not part of the custom where clause
let bound_type_idents = custom_bounds.bound_type_path_idents();
type_params.retain(|tp|
bound_type_idents.iter().any(|id| id == &tp.ident)
);

// todo: [AJ] add 'static bounds to skipped type params??? why do we need it?

let where_clause = ast.generics.make_where_clause();
where_clause.predicates.extend(custom_bounds);
where_clause.predicates.extend(custom_bounds.bounds());
where_clause.clone()
} else {
trait_bounds::make_where_clause(
ident,
&ast.generics,
&type_params,
&ast.data,
&scale_info,
&parity_scale_codec,
Expand All @@ -99,8 +95,6 @@ fn generate_type(input: TokenStream2) -> Result<TokenStream2> {

let (impl_generics, ty_generics, _) = ast.generics.split_for_impl();

let (impl_generics, ty_generics, _) = ast.generics.split_for_impl();

let type_params_meta_types = type_params.iter().map(|ty| {
let ty_ident = &ty.ident;
quote! {
Expand Down
8 changes: 1 addition & 7 deletions derive/src/trait_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use syn::{
Generics,
Result,
Type,
TypeParam,
TypePath,
WhereClause,
};
Expand All @@ -35,12 +34,9 @@ use crate::utils;
/// Generates a where clause for a `TypeInfo` impl, adding `TypeInfo + 'static` bounds to all
/// relevant generic types including associated types (e.g. `T::A: TypeInfo`), correctly dealing
/// with self-referential types.
///
/// Ignores any type parameters not included in `type_params`.
pub fn make_where_clause<'a>(
input_ident: &'a Ident,
generics: &'a Generics,
type_params: &[TypeParam],
data: &'a syn::Data,
scale_info: &Ident,
parity_scale_codec: &Ident,
Expand Down Expand Up @@ -85,9 +81,7 @@ pub fn make_where_clause<'a>(
generics.type_params().into_iter().for_each(|type_param| {
let ident = type_param.ident.clone();
let mut bounds = type_param.bounds.clone();
if type_params.iter().any(|tp| *tp == *type_param) {
bounds.push(parse_quote!(:: #scale_info ::TypeInfo));
}
bounds.push(parse_quote!(:: #scale_info ::TypeInfo));
bounds.push(parse_quote!('static));
where_clause
.predicates
Expand Down
66 changes: 28 additions & 38 deletions derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub fn get_doc_literals(attrs: &[syn::Attribute]) -> Vec<syn::Lit> {
pub type TraitBounds = Punctuated<syn::WherePredicate, token::Comma>;

/// Parse `name(T: Bound, N: Bound)` as a custom trait bound.
struct CustomTraitBound<N> {
pub struct CustomTraitBound<N> {
_name: N,
_paren_token: token::Paren,
bounds: TraitBounds,
Expand All @@ -83,49 +83,39 @@ impl<N: Parse> Parse for CustomTraitBound<N> {
}
}

syn::custom_keyword!(bounds);

/// Look for a `#[scale_info(bounds(…))]`in the given attributes.
///
/// If found, use the given trait bounds when deriving the `TypeInfo` trait.
pub fn custom_trait_bounds(attrs: &[Attribute]) -> Option<TraitBounds> {
scale_info_meta_item(attrs.iter(), |meta: CustomTraitBound<bounds>| {
Some(meta.bounds)
})
}

/// Trait bounds.
pub type TypeParams = Punctuated<syn::TypeParam, token::Comma>;

/// Parse `name(T, N)` as a custom trait bound.
struct SkipTypeParams<N> {
_name: N,
_paren_token: token::Paren,
params: TypeParams,
}
impl<N: Parse> CustomTraitBound<N> {
/// Returns all bound types which consist of a single non-parameterized path, which includes
/// the generic type parameters e.g. the `T` in `T: TypeInfo`.
pub fn bound_type_path_idents(&self) -> Vec<syn::Ident> {
self.bounds
.iter()
.filter_map(|bound| {
if let syn::WherePredicate::Type(ty) = bound {
if let syn::Type::Path(ref path) = ty.bounded_ty {
path.path.get_ident().cloned()
} else {
None
}
} else {
None
}
})
.collect()
}

impl<N: Parse> Parse for SkipTypeParams<N> {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let content;
let _name = input.parse()?;
let _paren_token = syn::parenthesized!(content in input);
let params = content.parse_terminated(syn::TypeParam::parse)?;
Ok(Self {
_name,
_paren_token,
params,
})
pub fn bounds(self) -> TraitBounds {
self.bounds
}
}

syn::custom_keyword!(skip_type_params);
syn::custom_keyword!(bounds);

/// Look for a `#[scale_info(skip_type_params(…))]`in the given attributes.
/// Look for a `#[scale_info(bounds(…))]`in the given attributes.
///
/// If found, do not register the given type params or require `TypeInfo` bounds for them.
pub fn skipped_type_params(attrs: &[Attribute]) -> Option<TypeParams> {
scale_info_meta_item(attrs.iter(), |meta: SkipTypeParams<skip_type_params>| {
Some(meta.params)
/// If found, use the given trait bounds when deriving the `TypeInfo` trait.
pub fn custom_trait_bounds(attrs: &[Attribute]) -> Option<CustomTraitBound<bounds>> {
scale_info_meta_item(attrs.iter(), |meta: CustomTraitBound<bounds>| {
Some(meta)
})
}

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ pub trait TypeInfo {
/// This is used to uniquely identify a type via [`core::any::TypeId::of`]. In most cases it
/// will just be `Self`, but can be used to unify different types which have the same encoded
/// representation e.g. reference types `Box<T>`, `&T` and `&mut T`.
type Identity: ?Sized + 'static;
type Identity: ?Sized;

/// Returns the static type identifier for `Self`.
fn type_info() -> Type;
Expand Down
10 changes: 5 additions & 5 deletions test_suite/tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,14 +613,14 @@ fn doc_capture_works() {
fn skip_type_params_nested() {
#[allow(unused)]
#[derive(TypeInfo)]
#[scale_info(skip_type_params(T))]
#[scale_info(bounds(U: TypeInfo + 'static))]
struct SkipTypeParamsNested<T, U> {
a: Nested<T>,
b: U,
}

#[derive(TypeInfo)]
#[scale_info(skip_type_params(T))]
#[scale_info(bounds())]
struct Nested<T> {
marker: PhantomData<T>,
}
Expand All @@ -647,7 +647,7 @@ fn skip_type_params_nested() {
fn skip_all_type_params() {
#[allow(unused)]
#[derive(TypeInfo)]
#[scale_info(skip_type_params(T, U))]
#[scale_info(bounds())]
struct SkipAllTypeParams<T, U> {
a: PhantomData<T>,
b: PhantomData<U>,
Expand Down Expand Up @@ -682,7 +682,7 @@ fn skip_type_params_with_associated_types() {

#[allow(unused)]
#[derive(TypeInfo)]
#[scale_info(skip_type_params(T, U))]
#[scale_info(bounds(T::A: TypeInfo + 'static))]
struct SkipTypeParamsForTraitImpl<T>
where
T: Trait,
Expand Down Expand Up @@ -716,7 +716,7 @@ fn skip_type_params_with_associated_types() {
fn skip_type_params_with_defaults() {
#[allow(unused)]
#[derive(TypeInfo)]
#[scale_info(skip_type_params(T, U))]
#[scale_info(bounds())]
struct SkipAllTypeParamsWithDefaults<T = (), U = ()> {
a: PhantomData<T>,
b: PhantomData<U>,
Expand Down

0 comments on commit 4cdf0c5

Please sign in to comment.