Get the full run of a recurrent cell using Lux #282
Code example:
I would expect the third element of the return value of the first call to be different from the return value of the second call, but it's the same, meaning the values were passed to the GRUCell individually and not in sequence.
That is not really true. In your example, you are just batching the input differently, so it is effectively 3 vs. 6 independent inputs in each case. What you want to do is:

```julia
gruit(reshape(1:3, 1, :, 1), ps)
gruit(reshape(3:3, 1, :, 1), ps)
```

But you are right, currently there is no way to access the intermediate values. However, it should be quite easy to patch that.
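In the meantime, a manual unrolling sketch can collect every per-step output. This is only a sketch and assumes the Lux cell-calling convention: a first call `cell(x, ps, st)` returns `(y, carry), st`, and later steps are called as `cell((x, carry), ps, st)`; the `unroll` helper and all dimensions are made up for illustration (and note the `push!` accumulation is not reverse-AD friendly):

```julia
using Lux, Random

# Manually unroll a recurrent cell over a sequence, keeping every output.
# Assumes `cell(x, ps, st)` returns `(y, carry), st` on the first step and
# that later steps are called as `cell((x, carry), ps, st)`.
function unroll(cell, xs, ps, st)
    (y, carry), st = cell(xs[1], ps, st)   # first step creates the carry
    outputs = [y]
    for x in xs[2:end]
        (y, carry), st = cell((x, carry), ps, st)  # thread the carry through
        push!(outputs, y)                          # not AD-friendly; sketch only
    end
    return outputs, st
end

rng = Random.default_rng()
cell = Lux.GRUCell(1 => 4)
ps, st = Lux.setup(rng, cell)
xs = [reshape(Float32[t], 1, 1) for t in 1:3]  # three steps, 1 feature, batch 1
outputs, st = unroll(cell, xs, ps, st)          # one output per time step
```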
Hi Avik! That makes sense :) Using the colon in the last dimension was just how the Flux API works. Just accumulating the state into a vector is of course not AD-friendly. Flux seems to have code to make this work:
Once #287 gets merged,
Super cool, thanks so much!! :)
Is there a way to make this compatible with Chains? Like:
This is a special case for Parallel (needs to be implemented). Current behavior:
Updated behavior:
For your case, you will need:

```julia
encoder = Lux.Chain(Lux.Recurrence(Lux.LSTMCell(data_dims => context_size); return_sequence=true),
                    Lux.WrappedFunction(Tuple),
                    Lux.Parallel(nothing, Lux.Dense(context_size => context_size)))
```
Thanks! :) Sadly this just passes the first element of the tuple to the
Edit: Oh, that's what you said, I get it.
Are you sure this should be a
You are right,
A simple implementation:

```julia
struct BroadcastLayer{T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)}
    layers::T
end

function BroadcastLayer(layers...)
    for l in layers
        if !iszero(statelength(l))
            throw(ArgumentError("Stateful layer `$l` is not supported by `BroadcastLayer`."))
        end
    end
    names = ntuple(i -> Symbol("layer_$i"), length(layers))
    return BroadcastLayer(NamedTuple{names}(layers))
end

BroadcastLayer(; kwargs...) = BroadcastLayer((; kwargs...))

function (m::BroadcastLayer)(x, ps, st::NamedTuple{names}) where {names}
    results = (first ∘ Lux.apply).(values(m.layers), x, values(ps), values(st))
    return results, st
end

Base.keys(m::BroadcastLayer) = Base.keys(getfield(m, :layers))
```
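For illustration, a hypothetical usage of this `BroadcastLayer`, assuming parameter/state initialization works for the container layer; the layer sizes here are made up:

```julia
using Lux, Random

# One sub-layer per element of the input tuple.
layer = BroadcastLayer(Lux.Dense(2 => 3), Lux.Dense(2 => 3))
ps, st = Lux.setup(Random.default_rng(), layer)

# Each element of the input tuple is routed to the matching layer.
x = (rand(Float32, 2, 1), rand(Float32, 2, 1))
y, st = layer(x, ps, st)  # y is a 2-tuple of the per-layer outputs
```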
This doesn't work for me yet; it throws
Using it like:
What is it supposed to do with multiple layers? I'm not sure I'm reading the
Thanks for your support :)
Ah, a LuxCore breaking release is needed to fix the default parameter setting. For now define
The issue persists with this. Did I do it right?
oof. thanks |
Sorry for being annoying again. It seems like the issue actually comes from
Hi! It's not clear to me how to mimic Flux's behaviour of calling a recurrent cell with a list of values like this:
What I want is the full run of the recurrent cell, i.e. a list of all output values.
It also seems not entirely straightforward to just accumulate them myself, because in the trivial version I'm mutating array state, which can't be reverse-diffed.
As far as I can read the source code, the internal Lux code "throws away" the intermediate results when updating a Recurrence or a StatefulRecurrentCell. Is that correct, or am I just missing something?
Thanks :)
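On the AD-friendliness point: a framework-free sketch of non-mutating accumulation, using `Base.accumulate` to thread a hidden state through the sequence. The `step` function here is a made-up toy update, not a Lux cell:

```julia
# Toy recurrent update (hypothetical stand-in for a real cell).
step(h, x) = tanh.(0.5f0 .* h .+ x)

xs = [Float32[1.0], Float32[2.0], Float32[3.0]]  # three time steps
h0 = Float32[0.0]                                # initial hidden state

# `accumulate` returns every intermediate hidden state without in-place
# mutation, so reverse-mode AD can differentiate through it.
hs = accumulate(step, xs; init = h0)  # length-3 vector of hidden states
```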