Skip to content

Commit

Permalink
Add skip_type_params attribute (#96)
Browse files Browse the repository at this point in the history
* Add new top-level attribute `scale_info(bounds(T: SomeTrait + OtherTrait))`

* Fmt

* cleanup

* Skip type params prototype

* Fmt

* Clippy

* Satisfy clippy

* Fix skip type params

* Fix UI test

* Add some more tests

* Add failing test for type params with default values

* Fix for type params with defaults, compare on ident

* WIP use bounds attribute for skipping type params

* Add required 'static bounds to all type params

* Fmt

* Satisfy clippy

* Add UI test for skipping bounds

* WIP docs

* Revert "Use bounds attribute for skipping type params"

This reverts commit 4cdf0c5.

* WIP dual attribute parsing

* Use new attribute parsing

* Fix up attribute parsing

* Fmt

* Reorder impls

* Refactor attribute handling

* Fmt

* Add docs to attributes

* Add `'static` bounds for all type params, add some ui tests

* Check for duplicate attributes

* Add helpful error message for type params in not in cuatom bounds and not skipped

* Improve error message where a type param is missing from the bounds, and not skipped

* Fix test and fmt

* Refactor and validate missing skip type params

* Update validation UI test

* Error message formatting

* Fix compilation after merge

* Add TypeParameter struct and optional type

* Add ui test for skipping type params

* Add example to named_type_params macro

* Type hint for named_type_params with MetaForm

* Add bounds attribute docs

* Add some docs attribute usage

* Fmt

Co-authored-by: David Palm <[email protected]>
  • Loading branch information
ascjones and dvdplm authored Jun 29, 2021
1 parent 43a1fad commit e7975d2
Show file tree
Hide file tree
Showing 18 changed files with 785 additions and 185 deletions.
202 changes: 202 additions & 0 deletions derive/src/attr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
// Copyright 2019-2021 Parity Technologies (UK) Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use syn::{
parse::{
Parse,
ParseBuffer,
},
punctuated::Punctuated,
spanned::Spanned,
Token,
};

const SCALE_INFO: &str = "scale_info";

mod keywords {
syn::custom_keyword!(scale_info);
syn::custom_keyword!(bounds);
syn::custom_keyword!(skip_type_params);
}

/// Parsed and validated set of `#[scale_info(...)]` attributes for an item.
pub struct Attributes {
bounds: Option<BoundsAttr>,
skip_type_params: Option<SkipTypeParamsAttr>,
}

impl Attributes {
/// Extract out `#[scale_info(...)]` attributes from an item.
pub fn from_ast(item: &syn::DeriveInput) -> syn::Result<Self> {
let mut bounds = None;
let mut skip_type_params = None;

let attributes_parser = |input: &ParseBuffer| {
let attrs: Punctuated<ScaleInfoAttr, Token![,]> =
input.parse_terminated(ScaleInfoAttr::parse)?;
Ok(attrs)
};

for attr in &item.attrs {
if !attr.path.is_ident(SCALE_INFO) {
continue
}
let scale_info_attrs = attr.parse_args_with(attributes_parser)?;

for scale_info_attr in scale_info_attrs {
// check for duplicates
match scale_info_attr {
ScaleInfoAttr::Bounds(parsed_bounds) => {
if bounds.is_some() {
return Err(syn::Error::new(
attr.span(),
"Duplicate `bounds` attributes",
))
}
bounds = Some(parsed_bounds);
}
ScaleInfoAttr::SkipTypeParams(parsed_skip_type_params) => {
if skip_type_params.is_some() {
return Err(syn::Error::new(
attr.span(),
"Duplicate `skip_type_params` attributes",
))
}
skip_type_params = Some(parsed_skip_type_params);
}
}
}
}

// validate type params which do not appear in custom bounds but are not skipped.
if let Some(ref bounds) = bounds {
for type_param in item.generics.type_params() {
if !bounds.contains_type_param(type_param) {
let type_param_skipped = skip_type_params
.as_ref()
.map(|skip| skip.skip(type_param))
.unwrap_or(false);
if !type_param_skipped {
let msg = format!(
"Type parameter requires a `TypeInfo` bound, so either: \n \
- add it to `#[scale_info(bounds({}: TypeInfo))]` \n \
- skip it with `#[scale_info(skip_type_params({}))]`",
type_param.ident, type_param.ident
);
return Err(syn::Error::new(type_param.span(), msg))
}
}
}
}

Ok(Self {
bounds,
skip_type_params,
})
}

/// Get the `#[scale_info(bounds(...))]` attribute, if present.
pub fn bounds(&self) -> Option<&BoundsAttr> {
self.bounds.as_ref()
}

/// Get the `#[scale_info(skip_type_params(...))]` attribute, if present.
pub fn skip_type_params(&self) -> Option<&SkipTypeParamsAttr> {
self.skip_type_params.as_ref()
}
}

/// Parsed representation of the `#[scale_info(bounds(...))]` attribute.
#[derive(Clone)]
pub struct BoundsAttr {
predicates: Punctuated<syn::WherePredicate, Token![,]>,
}

impl Parse for BoundsAttr {
fn parse(input: &ParseBuffer) -> syn::Result<Self> {
input.parse::<keywords::bounds>()?;
let content;
syn::parenthesized!(content in input);
let predicates = content.parse_terminated(syn::WherePredicate::parse)?;
Ok(Self { predicates })
}
}

impl BoundsAttr {
/// Add the predicates defined in this attribute to the given `where` clause.
pub fn extend_where_clause(&self, where_clause: &mut syn::WhereClause) {
where_clause.predicates.extend(self.predicates.clone());
}

/// Returns true if the given type parameter appears in the custom bounds attribute.
pub fn contains_type_param(&self, type_param: &syn::TypeParam) -> bool {
self.predicates.iter().any(|p| {
if let syn::WherePredicate::Type(ty) = p {
if let syn::Type::Path(ref path) = ty.bounded_ty {
path.path.get_ident() == Some(&type_param.ident)
} else {
false
}
} else {
false
}
})
}
}

/// Parsed representation of the `#[scale_info(skip_type_params(...))]` attribute.
#[derive(Clone)]
pub struct SkipTypeParamsAttr {
type_params: Punctuated<syn::TypeParam, Token![,]>,
}

impl Parse for SkipTypeParamsAttr {
fn parse(input: &ParseBuffer) -> syn::Result<Self> {
input.parse::<keywords::skip_type_params>()?;
let content;
syn::parenthesized!(content in input);
let type_params = content.parse_terminated(syn::TypeParam::parse)?;
Ok(Self { type_params })
}
}

impl SkipTypeParamsAttr {
/// Returns `true` if the given type parameter should be skipped.
pub fn skip(&self, type_param: &syn::TypeParam) -> bool {
self.type_params
.iter()
.any(|tp| tp.ident == type_param.ident)
}
}

/// Parsed representation of one of the `#[scale_info(..)]` attributes.
pub enum ScaleInfoAttr {
Bounds(BoundsAttr),
SkipTypeParams(SkipTypeParamsAttr),
}

impl Parse for ScaleInfoAttr {
fn parse(input: &ParseBuffer) -> syn::Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(keywords::bounds) {
let bounds = input.parse()?;
Ok(Self::Bounds(bounds))
} else if lookahead.peek(keywords::skip_type_params) {
let skip_type_params = input.parse()?;
Ok(Self::SkipTypeParams(skip_type_params))
} else {
Err(input.error("Expected either `bounds` or `skip_type_params`"))
}
}
}
47 changes: 20 additions & 27 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,10 @@
extern crate alloc;
extern crate proc_macro;

mod attr;
mod trait_bounds;
mod utils;

use alloc::{
string::{
String,
ToString,
},
vec::Vec,
};
use proc_macro::TokenStream;
use proc_macro2::{
Span,
Expand Down Expand Up @@ -66,36 +60,35 @@ fn generate(input: TokenStream2) -> Result<TokenStream2> {
}

fn generate_type(input: TokenStream2) -> Result<TokenStream2> {
let mut ast: DeriveInput = syn::parse2(input.clone())?;
let ast: DeriveInput = syn::parse2(input.clone())?;

utils::check_attributes(&ast)?;
let attrs = attr::Attributes::from_ast(&ast)?;

let scale_info = crate_name_ident("scale-info")?;
let parity_scale_codec = crate_name_ident("parity-scale-codec")?;

let ident = &ast.ident;

let where_clause = if let Some(custom_bounds) = utils::custom_trait_bounds(&ast.attrs)
{
let where_clause = ast.generics.make_where_clause();
where_clause.predicates.extend(custom_bounds);
where_clause.clone()
} else {
trait_bounds::make_where_clause(
ident,
&ast.generics,
&ast.data,
&scale_info,
&parity_scale_codec,
)?
};
let where_clause = trait_bounds::make_where_clause(
&attrs,
ident,
&ast.generics,
&ast.data,
&scale_info,
&parity_scale_codec,
)?;

let (impl_generics, ty_generics, _) = ast.generics.split_for_impl();

let generic_type_ids = ast.generics.type_params().map(|ty| {
let ty_ident = &ty.ident;
let type_params = ast.generics.type_params().map(|tp| {
let ty_ident = &tp.ident;
let ty = if attrs.skip_type_params().map_or(true, |skip| !skip.skip(tp)) {
quote! { Some(:: #scale_info ::meta_type::<#ty_ident>()) }
} else {
quote! { None }
};
quote! {
:: #scale_info ::meta_type::<#ty_ident>()
:: #scale_info ::TypeParameter::new(::core::stringify!(#ty_ident), #ty)
}
});

Expand All @@ -112,7 +105,7 @@ fn generate_type(input: TokenStream2) -> Result<TokenStream2> {
fn type_info() -> :: #scale_info ::Type {
:: #scale_info ::Type::builder()
.path(:: #scale_info ::Path::new(::core::stringify!(#ident), ::core::module_path!()))
.type_params(:: #scale_info ::prelude::vec![ #( #generic_type_ids ),* ])
.type_params(:: #scale_info ::prelude::vec![ #( #type_params ),* ])
#docs
.#build_type
}
Expand Down
42 changes: 38 additions & 4 deletions derive/src/trait_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,26 @@ use syn::{
WhereClause,
};

use crate::utils;
use crate::{
attr::Attributes,
utils,
};

/// Generates a where clause for a `TypeInfo` impl, adding `TypeInfo + 'static` bounds to all
/// relevant generic types including associated types (e.g. `T::A: TypeInfo`), correctly dealing
/// with self-referential types.
///
/// # Effect of attributes
///
/// `#[scale_info(skip_type_params(..))]`
///
/// Will not add `TypeInfo` bounds for any type parameters skipped via this attribute.
///
/// `#[scale_info(bounds(..))]`
///
/// Replaces *all* auto-generated trait bounds with the user-defined ones.
pub fn make_where_clause<'a>(
attrs: &'a Attributes,
input_ident: &'a Ident,
generics: &'a Generics,
data: &'a syn::Data,
Expand All @@ -47,14 +61,29 @@ pub fn make_where_clause<'a>(
predicates: Punctuated::new(),
}
});

// Use custom bounds as where clause.
if let Some(custom_bounds) = attrs.bounds() {
custom_bounds.extend_where_clause(&mut where_clause);

// `'static` lifetime bounds are always required for type parameters, because of the
// requirement on `std::any::TypeId::of` for any field type constructor.
for type_param in generics.type_params() {
let ident = &type_param.ident;
where_clause.predicates.push(parse_quote!(#ident: 'static))
}

return Ok(where_clause)
}

for lifetime in generics.lifetimes() {
where_clause
.predicates
.push(parse_quote!(#lifetime: 'static))
}

let type_params = generics.type_params();
let ty_params_ids = type_params
let ty_params_ids = generics
.type_params()
.map(|type_param| type_param.ident.clone())
.collect::<Vec<Ident>>();

Expand All @@ -79,7 +108,12 @@ pub fn make_where_clause<'a>(
generics.type_params().into_iter().for_each(|type_param| {
let ident = type_param.ident.clone();
let mut bounds = type_param.bounds.clone();
bounds.push(parse_quote!(:: #scale_info ::TypeInfo));
if attrs
.skip_type_params()
.map_or(true, |skip| !skip.skip(type_param))
{
bounds.push(parse_quote!(:: #scale_info ::TypeInfo));
}
bounds.push(parse_quote!('static));
where_clause
.predicates
Expand Down
Loading

0 comments on commit e7975d2

Please sign in to comment.