Optimise a subset of parameters #35

Closed
mcabbott opened this issue Jan 28, 2022 · 7 comments · Fixed by #36

Comments

@mcabbott (Member) commented Jan 28, 2022

Flux's trainable works like this:

julia> Flux.trainable(BatchNorm(2, relu))  # this avoids half the parameters
(Float32[0.0, 0.0], Float32[1.0, 1.0])

julia> Functors.children(BatchNorm(2, relu))   # this sees them all, for |> gpu
(λ = NNlib.relu, β = Float32[0.0, 0.0], γ = Float32[1.0, 1.0], μ = Float32[0.0, 0.0], σ² = Float32[1.0, 1.0], ϵ = 1.0f-5, momentum = 0.1f0, affine = true, track_stats = true, active = nothing, chs = 2)

This doesn't seem great: it relies on objectid to know which parameters those really are. So this:

function _trainable_walk(f, x)
  func, re = functor(x)    # NamedTuple of all children, plus the re-constructor
  nb = trainable(x)        # only the trainable children, as values
  re(map(c -> c in nb ? f(c) : c, func))  # apply f only to children found in nb
end

will not work correctly when, say, β === SA[0.0, 0.0] === μ.
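
A minimal sketch (not from the original comment) of that failure mode, assuming StaticArrays is available:

using StaticArrays

β = SA[0.0, 0.0]
μ = SA[0.0, 0.0]            # separate literals, yet β === μ, since SVectors are isbits
nb = (β, SA[1.0, 1.0])      # pretend these are the trainable children (β, γ)
map(c -> c in nb, (β, μ))   # (true, true): μ would wrongly be treated as trainable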

How should it work?

  • One idea would be to clone the @functor macro to have @trainable BatchNorm (β, γ)? In fact this case is even worse: it checks a runtime value (the affine flag) here, but we could probably move affine into the type.

  • Another idea would be just to have trainable(::BatchNorm) = (:β, :γ), returning the field names as Symbols. That's much easier to write and perhaps less mysterious. Might be slower; do we care? Or it might not be, if the symbols are known from the type. It would be easy here to allow Flux-style tuples as a fallback, detecting NTuple{Symbol} etc., making it easier to have both old- and new-style at once. (A sketch of how a walk might consume the Symbols follows this list.)
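
A hypothetical sketch (the name _trainable_walk_by_name and its body are not from the thread) of how a Symbol-returning trainable could be consumed, assuming functor(x) returns a NamedTuple of children as it does for @functor structs:

# assumes `using Functors: functor` and the Symbol-returning trainable above
function _trainable_walk_by_name(f, x)
  func, re = functor(x)          # NamedTuple of all children, plus the rebuilder
  names = trainable(x)           # e.g. (:β, :γ)
  mapped = map(keys(func)) do k
    k in names ? f(func[k]) : func[k]   # transform only the named children
  end
  re(NamedTuple{keys(func)}(mapped))
end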

This would be used during setup, just one pass. After that, the tree of optimiser states should tell you whether or not to update a given array, so update need never call this.

What might call it more often is destructure, which I think should walk only the trainable parameters, and which will sometimes be called in a loop.

@darsnack (Member)

_trainable_walk seems correct assuming fmap's behavior for shared parameters. I feel like this is more evidence that fmap should not be caching nodes for the sake of sharing.

@mcabbott (Member, Author) commented Jan 28, 2022

It should work fine with Arrays. With SArrays... right now fmap will conclude that β and μ are shared, but I think the map(c -> ...) runs on the children before it's checked that they're shared... and then I don't quite know, I guess it may depend on the order of the fields in the struct?

But anyway the scope here is narrower. I think you are agreeing that this trainable doesn't capture the right information. So what should replace it?

@darsnack (Member)

You're right, I completely misread the issue. We definitely agree on what's wrong here.

Why not just have trainable return a NamedTuple instead? Or the NTuple{Symbol} option is okay too.

@darsnack (Member)

> This would be used during setup, just one pass. After that, the tree of optimiser states should tell you whether or not to update a given array, so update need never call this.

That's an option, though I'd suggest we don't introduce state we can get for free. What I mean is that if trainable returns the correct thing, then fmap(..., walk = walk_subset(trainable)) should work (a rough sketch follows). We don't use fmap now, but eventually I think we should get it to the place where we can.
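
A rough sketch of that idea; walk_subset is hypothetical, not an existing function, and it assumes the Functors 0.2-era fmap(f, x; walk) keyword where walk is called as walk(recurse, x), plus the Symbol-returning trainable sketched above:

walk_subset(trainable) = (recurse, x) -> begin
  func, re = functor(x)                 # NamedTuple of children for @functor structs
  names = trainable(x)                  # e.g. (:β, :γ)
  mapped = map(k -> k in names ? recurse(func[k]) : func[k], keys(func))
  re(NamedTuple{keys(func)}(mapped))
end

# usage, roughly:
# state = fmap(x -> init(rule, x), model; walk = walk_subset(trainable))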

@mcabbott (Member, Author)

> Why not just have trainable return a NamedTuple instead?

We could do this. Flux's rules can be updated to do that without breaking anything there.

I am picturing that, in a not-too-distant version of Flux, it will depend on Optimisers and provide methods which work here, while still being usable in its old way.

If user code returns a Tuple, we can convert it (using objectid) and print a warning.
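
A hypothetical sketch of that (the helper name _trainables and the warning text are not from the thread): a NamedTuple-returning trainable, plus an objectid-based conversion of old-style Tuples with a warning:

trainable(bn::BatchNorm) = (β = bn.β, γ = bn.γ)   # new-style: keep the field names

function _trainables(x)                 # normalise whatever trainable returns
  tr = trainable(x)
  tr isa NamedTuple && return tr        # new-style: already carries the names
  @warn "trainable(::$(typeof(x))) returned a Tuple; please return a NamedTuple" maxlog=1
  func, _ = functor(x)
  ids = map(objectid, tr)               # old-style: match children by objectid
  NamedTuple(k => c for (k, c) in pairs(func) if objectid(c) in ids)   # Julia ≥ 1.6
end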

@mcabbott (Member, Author)

> don't introduce state we can get for free

You can write setup in one line with fmap and some trainable-walk. But whether this is better than trusting the tree of optimiser states, I don't know. That tree is an object we already make and pass along, and it must be right.

I guess that they start to differ more once we handle tied parameters. I picture the tree of states also having (at its root) some lenses or something telling us about what transformations to perform before starting; these things are figured out once during setup. They could also be done every time, but I think that would need another pass over the model before the update one.

@ToucheSir (Member)

But setup is the function that constructs that tree of optimizer states? So it dictates the traversal behaviour for all subsequent operations. The root issue here is that Functors.functor(::Type{T}, x) currently has a monopoly over reconstruction of T from its functored representation. We can work around this downstream well enough by subsetting as discussed here and merging the changes back in afterwards before reconstruction, but longer term this (the inflexibility of the re closure and the fact that you have to carry around a closure at all) ought to be addressed.
