
Get the full run of a recurrent cell using Lux #282

Closed
linusheck opened this issue Mar 13, 2023 · 17 comments · Fixed by #287
Labels
enhancement New feature or request

Comments


linusheck commented Mar 13, 2023

Hi! It's not clear to me how to mimic Flux's behaviour of calling a recurrent cell with a list of values like this:

m = LSTM(10, 5)
m.(seq) # returns a list of 5-element vectors

m = Chain(LSTM(10, 15), Dense(15, 5))
m.(seq) # works even when we've chained recurrent layers into a larger model

What I want is the full run of the recurrent cell, i.e. a list of all output values.
It also seems not entirely straightforward to just accumulate them myself, because the naive version mutates array state, which can't be reverse-differentiated.

As far as I can tell from the source code, the internal Lux code "throws away" the intermediate results when updating a Recurrence or a StatefulRecurrentCell. Is that correct, or am I just missing something?

Thanks :)

@linusheck
Author

Code example:

using Lux, Random

grucell = Lux.Recurrence(Lux.GRUCell(1 => 3))
rng = Random.default_rng()
ps, st = Lux.setup(rng, grucell)
function gruit(input, ps)
    grucell(input, ps, st)
end
gruit(reshape(1:3, 1, 1, :), ps)
gruit(reshape(3:3, 1, 1, :), ps)

I would expect the third element of the return value of the first call to differ from the return value of the second call, but they are the same, meaning the values were passed to the GRUCell individually rather than in sequence.

@avik-pal avik-pal added the enhancement New feature or request label Mar 14, 2023
@avik-pal
Member

That is not really true. In your example you are just batching the input differently, so it is effectively 3 vs 6 independent inputs in the two cases. What you want to do is:

gruit(reshape(1:3, 1, :, 1), ps)
gruit(reshape(3:3, 1, :, 1), ps)

But you are right, currently there is no way to access the intermediate values. However, it should be quite easy to patch that.
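The difference between the two layouts is visible from the shapes alone; a plain-Julia check (no Lux required, and the comments about which axis is time reflect the correction above):

```julia
# The two reshapes put the 3 values on different axes. Whichever
# dimension Recurrence iterates over as time determines whether a call
# is one 3-step sequence or 3 independent (batched) inputs.
x_time_last = reshape(1:3, 1, 1, :)  # size (1, 1, 3)
x_time_mid  = reshape(1:3, 1, :, 1)  # size (1, 3, 1)

size(x_time_last), size(x_time_mid)  # ((1, 1, 3), (1, 3, 1))
```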

@linusheck
Author

linusheck commented Mar 14, 2023

Hi Avik! That makes sense :) The version with the colon in the last dimension was just how the Flux API works.

Just accumulating the state into a vector is of course not AD-friendly. Flux seems to have code to make this work:

# Type stable and AD-friendly helper for iterating over the last dimension of an array

As someone new to AD, this doesn't seem super easy to me :D I might look into it later, or just convert my current project to Flux, depending on which seems more reasonable.
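For reference, the mutation-free accumulation pattern can be sketched in plain Julia. The toy `cell` below is a stand-in function `(hidden, input) -> hidden`, not an actual Lux cell; the point is only that the outputs are collected by rebuilding a tuple instead of `push!`-ing into a vector, so there is no array mutation for reverse-mode AD to reject:

```julia
# Fold over the time steps, growing a tuple of hidden states.
# No mutation: each iteration builds a fresh tuple.
function scan(cell, h0, xs)
    hs = ()
    h = h0
    for x in xs
        h = cell(h, x)
        hs = (hs..., h)
    end
    return hs
end

# Toy "cell": exponential moving average of the inputs.
cell(h, x) = 0.5h + 0.5x

scan(cell, 0.0, [1.0, 2.0, 3.0])  # (0.5, 1.25, 2.125)
```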

@avik-pal
Member

Once #287 gets merged, `grucell = Lux.Recurrence(Lux.GRUCell(1 => 3); return_sequence=true)` will be able to do what you need.

@linusheck
Author

Super cool, thanks so much!! :)

@linusheck
Author

Is there a way to make this compatible with Chains? Like:

encoder = Lux.Chain(Lux.Recurrence(Lux.LSTMCell(data_dims => context_size); return_sequence=true), Lux.Dense(context_size => context_size))

@avik-pal
Member

avik-pal commented Apr 18, 2023

This is a special case for Parallel (needs to be implemented). Current behavior:

  • case 1: x is a Tuple: pass each element to the corresponding layer
  • case 2: otherwise: pass x to every layer

Updated behavior:

  • x is a tuple and its length matches the number of layers: pass each element to the corresponding layer
  • x is a tuple, the lengths don't match, and there is exactly one layer: pass each element of x through that layer
  • otherwise: pass x to every layer
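The three rules above can be sketched with plain functions standing in for layers (`apply_parallel` is a hypothetical helper for illustration, not actual Lux code):

```julia
# Dispatch rules for the updated Parallel behavior, with functions as layers.
function apply_parallel(layers::Tuple, x)
    if x isa Tuple && length(x) == length(layers)
        # lengths match: one input per layer
        return map((l, xi) -> l(xi), layers, x)
    elseif x isa Tuple && length(layers) == 1
        # single layer: broadcast it over every element of x
        return map(only(layers), x)
    else
        # otherwise: same input to every layer
        return map(l -> l(x), layers)
    end
end

apply_parallel((sqrt, abs), (4.0, -2.0))  # (2.0, 2.0)
apply_parallel((x -> 2x,), (1, 2, 3))     # (2, 4, 6)
apply_parallel((sqrt, abs), 9.0)          # (3.0, 9.0)
```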

For your case, you will need:

encoder = Lux.Chain(Lux.Recurrence(Lux.LSTMCell(data_dims => context_size); return_sequence=true), 
                    Lux.WrappedFunction(Tuple),
                    Lux.Parallel(nothing, Lux.Dense(context_size => context_size)))

@linusheck
Author

linusheck commented Apr 18, 2023

Thanks! :) Sadly this just passes the first element of the tuple to the Dense and returns a one-element tuple of the result.

Edit: Oh, that's what you said, I get it.

@avik-pal avik-pal reopened this Apr 18, 2023
@linusheck
Author

Are you sure this should be a Parallel? Maybe there should be a separate thing for it, like a Broadcast.

@avik-pal
Member

You are right, Broadcast seems a more appropriate name.

@avik-pal
Member

A simple implementation:

using Lux

struct BroadcastLayer{T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)}
    layers::T
end

function BroadcastLayer(layers...)
    for l in layers
        if !iszero(statelength(l))
            throw(ArgumentError("Stateful layer `$l` is not supported by `BroadcastLayer`."))
        end
    end
    names = ntuple(i -> Symbol("layer_$i"), length(layers))
    return BroadcastLayer(NamedTuple{names}(layers))
end

BroadcastLayer(; kwargs...) = BroadcastLayer((; kwargs...))

function (m::BroadcastLayer)(x, ps, st::NamedTuple)
    results = (first ∘ Lux.apply).(values(m.layers), x, values(ps), values(st))
    return results, st
end

Base.keys(m::BroadcastLayer) = Base.keys(getfield(m, :layers))
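To see what the broadcast in the `results = ...` line does, here is the same pattern with plain functions in place of layers. The `apply` below is a toy stand-in returning an `(output, state)` pair, not `Lux.apply`; `first ∘ apply` keeps only the output, and broadcasting pairs the i-th layer with the i-th element of x (a single layer gets broadcast over every element):

```julia
# Toy stand-in: returns an (output, state) pair like a Lux layer call.
apply(f, x, ps, st) = (f(x) + ps, st)

layers = (sin, cos)
x  = (0.0, 0.0)
ps = (1.0, 2.0)
st = (nothing, nothing)

# Two layers, two inputs: element-wise pairing.
(first ∘ apply).(layers, x, ps, st)  # (1.0, 3.0)

# One layer, two inputs: the length-1 tuple is broadcast over x.
(first ∘ apply).((sin,), (0.0, 0.0), (1.0,), (nothing,))  # (1.0, 1.0)
```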

@linusheck
Author

linusheck commented Apr 18, 2023

This doesn't work for me yet; it throws `type Float64 has no field weight`. I think it tries to broadcast (with the dot) through my single object in `ps`.

Using it like:

encoder = Lux.Chain(Lux.Recurrence(Lux.LSTMCell(data_dims => context_size); return_sequence=true), BroadcastLayer(Lux.Dense(context_size => context_size)))

What is it supposed to do with multiple layers? I'm not sure I'm reading the results = ... line correctly

Thanks for your support :)

@avik-pal
Member

Ah, a LuxCore breaking release is needed to fix the default parameter setting. For now define:

initialparameters(rng, l::BroadcastLayer) = NamedTuple{keys(l.layers)}(initialparameters.((rng,), values(l.layers)))

@linusheck
Author

the issue persists with this. did I do it right?

begin
    struct BroadcastLayer{T <: NamedTuple} <: LuxCore.AbstractExplicitContainerLayer{(:layers,)}
        layers::T
    end

    function BroadcastLayer(layers...)
        for l in layers
            if !iszero(LuxCore.statelength(l))
                throw(ArgumentError("Stateful layer `$l` is not supported by `BroadcastLayer`."))
            end
        end
        names = ntuple(i -> Symbol("layer_$i"), length(layers))
        return BroadcastLayer(NamedTuple{names}(layers))
    end

    BroadcastLayer(; kwargs...) = BroadcastLayer((; kwargs...))

    function (m::BroadcastLayer)(x, ps, st::NamedTuple)
        results = (first ∘ Lux.apply).(values(m.layers), x, values(ps), values(st))
        return results, st
    end

    initialparameters(rng, l::BroadcastLayer) = NamedTuple{keys(l.layers)}(initialparameters.((rng,), values(l.layers)))

    Base.keys(m::BroadcastLayer) = Base.keys(getfield(m, :layers))
end

@avik-pal
Member

`LuxCore.initial...`: you need to add a method to the predefined function.

@linusheck
Author

oof. thanks

@linusheck
Author

Sorry for more annoyance. It seems like the issue actually comes from `values` not working for ComponentArrays as expected: the `values` function just gives you back the ComponentArray.
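This is plain Base behavior: `values` splits a `NamedTuple` into its fields, but is the identity on arrays, and a `ComponentArray` is an `AbstractArray` (illustrated below with a plain `Vector`; no ComponentArrays dependency needed):

```julia
# values on a NamedTuple: one entry per field, as BroadcastLayer expects.
nt = (layer_1 = 1.0, layer_2 = 2.0)
values(nt)       # (1.0, 2.0)

# values on an array: the array itself, unchanged. A ComponentArray is
# an AbstractArray, so it falls into this case and is not split up.
v = [1.0, 2.0]
values(v) === v  # true
```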
