The following code defines part of a neural net that can be differentiated with Zygote 0.6.21. After upgrading to Zygote 0.6.22, differentiating it throws the following error: ERROR: MethodError: no method matching JuliennedArrays.Slices(::ChainRulesCore.NoTangent, ::Int64)
using Flux
using Flux.Zygote
using SliceMap
using JuliennedArrays
# Define a version of `Recur` that returns the whole hidden state
mutable struct HiddenRecur{T,S}
    cell::T
    state::S
end

HiddenRecur(m) = HiddenRecur(m, m.state0)
function (m::HiddenRecur)(xs...)
    h, y = m.cell(m.state, xs...)
    m.state = h
    return h
end

# Feed arrays sequentially to recurrent neural nets
mutable struct Seq{T}
    chain::T
    state
end

Seq(chain) = Seq(chain, [0.0f0])
(l::Seq)(x) = l(l.chain.state, l.chain, x)
function (l::Seq)(::Tuple, _, x)
    tuples = map(l.chain, Slices(x, True(), False()))
    l.state = [Align(map(x -> dropdims(x[i], dims=2), tuples), 1) for i in 1:length(l.chain.state)]
    return l.state
end

# Quick gradient
function bw(m, ip)
    gs = gradient((m, x) -> sum(m(x)), m, ip)
end

# Actual code
inp = rand(Float32, 10, 50)
encoder = Seq(HiddenRecur(Flux.LSTMCell(10, 5)))
encoder(inp) # forward pass works - returns an array of 2 arrays with the hidden state
bw(x -> encoder(x)[1], inp) # works in Zygote 0.6.21, errors in 0.6.22
This package should really move to using ChainRulesCore.jl, but when I started on that, I was reminded that it's a slightly bigger job than I thought.
However, this also sounds like one more place where ChainRules types are leaking into Zygote. I didn't isolate this one, but with FluxML/Zygote.jl#1104 it seems to at least run without error.
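To make the "leaking" concrete: ChainRules represents a missing gradient with structural zeros such as NoTangent(), while Zygote's own convention is to use nothing. If a NoTangent() reaches code that expects an array cotangent (for instance a pullback that calls Slices on it), there is no matching method, which is consistent with the MethodError above. The sketch below is a hypothetical guard, not the change made in FluxML/Zygote.jl#1104 and not part of SliceMap or JuliennedArrays. The names to_zygote and slice_cotangent are illustrative only, and Base's eachslice stands in for Slices so that the example depends only on ChainRulesCore.

```julia
using ChainRulesCore

# Hypothetical helper (not a Zygote/SliceMap API): translate ChainRules'
# structural zeros (NoTangent, ZeroTangent, ...) into Zygote's `nothing`.
to_zygote(dx::ChainRulesCore.AbstractZero) = nothing
to_zygote(dx) = dx

# Hypothetical guard for a pullback that slices its incoming cotangent.
# `eachslice` is used here as a stand-in for JuliennedArrays.Slices.
function slice_cotangent(dy, dim::Int)
    dy = to_zygote(dy)
    # No gradient flowed in, so there is nothing to slice.
    dy === nothing && return nothing
    return collect(eachslice(dy; dims=dim))
end
```

With this kind of check in place, a NoTangent() cotangent short-circuits to nothing instead of being passed on to a method that only accepts arrays.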