-
-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow gradients in fmap #1
Conversation
Any chance of some documentation alongside this PR explaining what the purpose of the two-argument version of Additionally, why do we need a separate |
Its to make Optimisers.jl work by returning a slightly trained version of the model. The tuple function is so that we return a tuple of layers in a chain. I should revisit to see if its necessary still but otherwise this should be merged so we can release Optimisers.jl |
I agree that this should probably go in, but a comment of some kind explaining what you've said above would be great. |
bump! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes look get. Perhaps some tests and a patch bump?
Woops, meant to say LGTM once the tests land. |
Allows writing something like Functors.fmap(x, y) do x, y
isnothing(x) && return y
isnothing(y) && return x
f(x,y)
end where x and y are basically NamedTuples out of Zygote, or a mix of the actual object tree (think Chain, model struct etc) and the corresponding gradients. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
@@ -44,10 +44,29 @@ Equivalent to `functor(x)[1]`. | |||
""" | |||
children(x) = functor(x)[1] | |||
|
|||
function functor_tuple(f, x::Tuple, dx::Tuple) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think x
is a tuple
function functor_tuple(f, x::Tuple, dx::Tuple) | |
function functor_tuple(f, x, dx::Tuple) |
function _default_walk(f, x, dx) | ||
func, re = functor(x) | ||
map(func, dx) do x, x̄ | ||
# functor_tuple(f, x, x̄) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should work with the other change?
@DhairyaLGandhi this already has my approval, so feel free to merge any time :) |
No description provided.