gradient() fails on array mutation for mean(f, x; dims) #1128

Closed
staticfloat opened this issue Dec 1, 2021 · 3 comments · Fixed by JuliaDiff/ChainRules.jl#615
Labels: needs adjoint, missing rule

Comments

@staticfloat (Contributor)

If you provide both an element-wise function f and a dimension specification, mean() apparently causes array mutation, which breaks Zygote's ability to differentiate:

julia> using Zygote, Statistics
       x = randn(3, 3)
       Zygote.gradient(Params([x])) do
           sum(mean(abs2, x, dims=1))
       end
ERROR: Mutating arrays is not supported -- called copyto!(::Matrix{Float64}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#441#442"{Matrix{Float64}})(#unused#::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/lib/array.jl:74
  [3] (::Zygote.var"#2330#back#443"{Zygote.var"#441#442"{Matrix{Float64}}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ./broadcast.jl:894 [inlined]
  [5] Pullback
    @ ./broadcast.jl:891 [inlined]
  [6] Pullback
    @ ./broadcast.jl:887 [inlined]
  [7] (::typeof((materialize!)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
  [8] Pullback
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:181 [inlined]
  [9] (::typeof((_mean)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [10] Pullback
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:104 [inlined]
 [11] (::typeof((#mean#1)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [12] Pullback
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:104 [inlined]
 [13] (::typeof((mean##kw)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [14] Pullback
    @ ./REPL[14]:4 [inlined]
 [15] (::typeof((#17)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [16] (::Zygote.var"#89#90"{Params, typeof((#17)), Zygote.Context})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:356
 [17] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:76
 [18] top-level scope
    @ REPL[14]:3

Looking through the adjoints for mean() defined in lib/array.jl, my guess is that passing abs2 as f means none of Zygote's mean() adjoints apply, so Zygote differentiates straight through the Statistics implementation instead. With the dims kwarg, that implementation (Statistics._mean, per the stack trace above) writes into an existing array via an in-place broadcast (materialize! calling copyto!), which is exactly the mutation Zygote rejects. I was going to submit a PR adding a new @adjoint definition for the method that takes f, but I don't know how to get the adjoint of a user-provided function.

@mzgubic (Collaborator)

mzgubic commented Dec 1, 2021

> but I don't know how to get the adjoint of a user-provided function

You could use Zygote.pullback to AD through it, which will give you its adjoint if one exists. A PR to ChainRules would be well received; see JuliaDiff/ChainRules.jl#85.
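A minimal sketch of the Zygote.pullback suggestion, assuming a hypothetical wrapper `mymean` (so the adjoint is not pirated onto `Statistics.mean`) and a differentiable `f` with no trainable fields of its own. The forward pass calls `Zygote.pullback(f, xi)` on each element and keeps the per-element pullbacks; the reverse pass spreads the output cotangent back to the shape of `x`, scales it by 1/n (n being the number of elements averaged into each output), and feeds it through those pullbacks. It is not necessarily how the fix in JuliaDiff/ChainRules.jl#615 was written.

```julia
using Statistics, Zygote

# Hypothetical wrapper; defining the adjoint on our own function avoids
# type piracy on Statistics.mean.
mymean(f, x::AbstractArray; dims=:) = mean(f, x; dims=dims)

Zygote.@adjoint function mymean(f, x::AbstractArray; dims=:)
    # Differentiate f once per element, keeping each element's pullback.
    ys_backs = map(xi -> Zygote.pullback(f, xi), x)
    ys = first.(ys_backs)
    Ω = mean(ys; dims=dims)
    # Number of elements averaged into each output entry (length(scalar) == 1).
    n = length(x) ÷ length(Ω)
    function mymean_pullback(Ω̄)
        # Broadcast the output cotangent to the shape of x and scale by 1/n.
        dys = zero(ys) .+ Ω̄ ./ n
        # Each element pullback returns a 1-tuple holding the gradient wrt xi.
        x̄ = map((yb, dy) -> last(yb)(dy)[1], ys_backs, dys)
        # Assumption: f has no trainable fields, so its gradient is `nothing`.
        return (nothing, x̄)
    end
    return Ω, mymean_pullback
end

x = randn(3, 3)
g = Zygote.gradient(x -> sum(mymean(abs2, x; dims=1)), x)[1]
@assert g ≈ 2 .* x ./ 3   # each column mean of x.^2 has derivative 2x/3
```

Because everything Zygote sees inside the rule is a scalar call to `f`, no array is mutated; the cost is one stored pullback closure per element.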

@mcabbott (Member)

mcabbott commented Dec 1, 2021

The low-tech way to implement this is to turn it into broadcasting, as is currently done for sum(::Function, ::CuArray) here:

https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L280-L283
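A user-side workaround in the same spirit, assuming a hypothetical helper `mean_bcast` (not part of Zygote or Statistics): apply `f` with broadcasting, which Zygote differentiates without mutation, and then take the plain `mean(...; dims)`, which already has an adjoint. This is only a sketch of the idea, not the code at the linked lines.

```julia
using Statistics, Zygote

# Hypothetical helper: broadcast f first, then reduce with the plain
# mean(A; dims), avoiding the in-place path taken by mean(f, A; dims).
mean_bcast(f, x::AbstractArray; dims=:) = mean(f.(x); dims=dims)

x = randn(3, 3)
g = Zygote.gradient(x -> sum(mean_bcast(abs2, x; dims=1)), x)[1]
@assert g ≈ 2 .* x ./ 3   # each column mean of x.^2 has derivative 2x/3
```

The trade-off is that `f.(x)` is materialized as a full temporary array, which is usually acceptable and keeps the rule trivial.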

@oxinabox (Member)

Reopened, as I had to revert the fix.
