diff --git a/autograph_derive/src/lib.rs b/autograph_derive/src/lib.rs index 02d1e2b2..704a3d29 100644 --- a/autograph_derive/src/lib.rs +++ b/autograph_derive/src/lib.rs @@ -146,9 +146,15 @@ impl Layers { fn iter(&self, method: Ident) -> TokenStream2 { match self { Self::Struct(layers) => { - quote! { - ::std::iter::empty() - #(.chain(self.#layers.#method()))* + if let Some((first, layers)) = layers.split_first() { + quote! { + self.#first.#method() + #(.chain(self.#layers.#method()))* + } + } else { + quote! { + ::std::iter::empty() + } } } Self::Enum(layers) => { @@ -168,11 +174,17 @@ impl Layers { fn try_iter_mut(&self, method: Ident) -> TokenStream2 { match self { Self::Struct(layers) => { - quote! { - Ok( - ::std::iter::empty() - #(.chain(self.#layers.#method()?))* - ) + if let Some((first, layers)) = layers.split_first() { + quote! { + Ok( + self.#first.#method()? + #(.chain(self.#layers.#method()?))* + ) + } + } else { + quote! { + Ok(::std::iter::empty()) + } } } Self::Enum(layers) => { @@ -287,7 +299,7 @@ fn layer_impl(input: TokenStream2) -> Result { fn parameter_iter(&self) -> impl ::std::iter::Iterator + '_ { #parameter_iter } - fn make_parameter_iter_mut(&mut self) -> #autograph::anyhow::Result + '_> { + fn make_parameter_iter_mut(&mut self) -> #autograph::anyhow::Result + '_> { #make_parameter_iter_mut } fn set_training(&mut self, training: bool) -> #autograph::anyhow::Result<()> { @@ -326,8 +338,15 @@ fn forward_impl(input: TokenStream2) -> Result { let forward = match layers { Layers::Struct(layers) => { - quote! { - Ok(input #(.forward(&self.#layers)?)*) + if let Some((last, layers)) = layers.split_last() { + quote! { + input #(.forward(&self.#layers)?)* + .forward(&self.#last) + } + } else { + quote! { + Ok(()) + } } } Layers::Enum(layers) => {