diff --git a/src/expand.rs b/src/expand.rs index fe3f530..85c39c6 100644 --- a/src/expand.rs +++ b/src/expand.rs @@ -8,8 +8,8 @@ use syn::punctuated::Punctuated; use syn::visit_mut::{self, VisitMut}; use syn::{ parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericParam, Generics, Ident, - ImplItem, Lifetime, Pat, PatIdent, Receiver, ReturnType, Signature, Stmt, Token, TraitItem, - Type, TypeParamBound, TypePath, WhereClause, + ImplItem, Lifetime, LifetimeDef, Pat, PatIdent, Receiver, ReturnType, Signature, Stmt, Token, + TraitItem, Type, TypeParamBound, TypePath, WhereClause, }; impl ToTokens for Item { @@ -34,17 +34,18 @@ enum Context<'a> { } impl Context<'_> { - fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator { + fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator { let generics = match self { Context::Trait { generics, .. } => generics, Context::Impl { impl_generics, .. } => impl_generics, }; - generics.params.iter().filter(move |param| { + generics.params.iter().filter_map(move |param| { if let GenericParam::Lifetime(param) = param { - used.contains(¶m.lifetime) - } else { - false + if used.contains(¶m.lifetime) { + return Some(param); + } } + None }) } } @@ -178,12 +179,7 @@ fn transform_sig( } } - for param in sig - .generics - .params - .iter() - .chain(context.lifetimes(&lifetimes.explicit)) - { + for param in &sig.generics.params { match param { GenericParam::Type(param) => { let param = ¶m.ident; @@ -203,6 +199,14 @@ fn transform_sig( } } + for param in context.lifetimes(&lifetimes.explicit) { + let param = ¶m.lifetime; + let span = param.span(); + where_clause_or_default(&mut sig.generics.where_clause) + .predicates + .push(parse_quote_spanned!(span=> #param: 'async_trait)); + } + if sig.generics.lt_token.is_none() { sig.generics.lt_token = Some(Token![<](sig.ident.span())); }