Skip to content

Commit

Permalink
derive refactor to avoid clippy warnings
Browse files · Browse the repository at this point in the history
  • Loading branch information
charles-r-earp committed Mar 10, 2024
1 parent b5b0d77 commit 4eeb6b8
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions autograph_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,15 @@ impl Layers {
fn iter(&self, method: Ident) -> TokenStream2 {
match self {
Self::Struct(layers) => {
quote! {
::std::iter::empty()
#(.chain(self.#layers.#method()))*
if let Some((first, layers)) = layers.split_first() {
quote! {
self.#first.#method()
#(.chain(self.#layers.#method()))*
}
} else {
quote! {
::std::iter::empty()
}
}
}
Self::Enum(layers) => {
Expand All @@ -168,11 +174,17 @@ impl Layers {
fn try_iter_mut(&self, method: Ident) -> TokenStream2 {
match self {
Self::Struct(layers) => {
quote! {
Ok(
::std::iter::empty()
#(.chain(self.#layers.#method()?))*
)
if let Some((first, layers)) = layers.split_first() {
quote! {
Ok(
self.#first.#method()?
#(.chain(self.#layers.#method()?))*
)
}
} else {
quote! {
Ok(::std::iter::empty())
}
}
}
Self::Enum(layers) => {
Expand Down Expand Up @@ -287,7 +299,7 @@ fn layer_impl(input: TokenStream2) -> Result<TokenStream2> {
fn parameter_iter(&self) -> impl ::std::iter::Iterator<Item=#autograph::learn::neural_network::autograd::ParameterD> + '_ {
#parameter_iter
}
fn make_parameter_iter_mut(&mut self) -> #autograph::anyhow::Result<impl ::std::iter::Iterator<Item= #autograph::learn::neural_network::autograd::ParameterViewMutD> + '_> {
fn make_parameter_iter_mut(&mut self) -> #autograph::anyhow::Result<impl ::std::iter::Iterator<Item=#autograph::learn::neural_network::autograd::ParameterViewMutD> + '_> {
#make_parameter_iter_mut
}
fn set_training(&mut self, training: bool) -> #autograph::anyhow::Result<()> {
Expand Down Expand Up @@ -326,8 +338,15 @@ fn forward_impl(input: TokenStream2) -> Result<TokenStream2> {

let forward = match layers {
Layers::Struct(layers) => {
quote! {
Ok(input #(.forward(&self.#layers)?)*)
if let Some((last, layers)) = layers.split_last() {
quote! {
input #(.forward(&self.#layers)?)*
.forward(&self.#last)
}
} else {
quote! {
Ok(())
}
}
}
Layers::Enum(layers) => {
Expand Down

0 comments on commit 4eeb6b8

Please sign in to comment.