gradient() fails on array mutation for mean(f, x; dims)
#1128
Comments
You could use |
The low-tech way to implement this is to turn it into broadcasting, as is currently done here: https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L280-L283
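A minimal sketch of what "turning it into broadcasting" can look like from the user side, assuming `abs2` as the element-wise function and `dims=1` (both chosen here for illustration): apply `f` with a broadcast first, then call `mean` with only the `dims` keyword.

```julia
using Zygote
using Statistics: mean

x = rand(3, 4)

# Instead of mean(abs2, x; dims=1), broadcast abs2 first and then
# take the plain mean over dims, which avoids the mean(f, x; dims) method.
loss(x) = sum(mean(abs2.(x); dims=1))

g, = gradient(loss, x)

# Analytically, d/dx of sum(mean(x.^2; dims=1)) is 2x / size(x, 1).
expected = 2 .* x ./ size(x, 1)
```

This trades one extra temporary array (`abs2.(x)`) for a code path that only needs the broadcast and `mean(x; dims)` adjoints.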
Reopened, as I had to revert the fix.
If you provide both an element-wise function `f` and a dimension specification, `mean()` apparently causes array mutation, which breaks Zygote's ability to differentiate:

Looking through the adjoints for `mean()` defined in `lib/array.jl`, I would guess that the fact that I'm passing `abs2` in for `f` causes Zygote's implementation to be skipped altogether, and then the `dims` kwarg causes us to go down a bad path that involves array mutation. I was going to submit a PR to create a new `@adjoint` definition for one that includes `f`, but I don't know how to get the adjoint of a user-provided function.
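On the last point, one possible way to obtain the derivative of a user-provided function is `Zygote.pullback`, which returns the function's value together with a closure that maps an output cotangent to input cotangents. The sketch below only illustrates that call, with `abs2` standing in for `f`; it is not a proposed `@adjoint` definition and not necessarily how a fix would be written inside Zygote.

```julia
using Zygote

f = abs2  # stand-in for the user-provided function

# Scalar case: y is f(3.0), back maps a cotangent of y to a tuple of
# cotangents, one per argument of f.
y, back = Zygote.pullback(f, 3.0)
dx = back(1.0)[1]

# Array case: differentiate the broadcast of f over x in one go.
x = [1.0, 2.0, 3.0]
ys, back_all = Zygote.pullback(v -> f.(v), x)
dxs = back_all(ones(length(x)))[1]
```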