No rule for mean(f, x; dims) #1318

Closed
cossio opened this issue Oct 8, 2022 · 5 comments
Labels
ChainRules (adjoint -> rrule, and further integration), needs adjoint (missing rule)

Comments


cossio commented Oct 8, 2022

using Statistics, Zygote
X = randn(4,5,6)
gradient(X) do X
    sum(mean(abs, X; dims=(1,2)))
end

I get a mutation error.

ERROR: Mutating arrays is not supported -- called copyto!(Array{Float64, 3}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Array{Float64, 3})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/lib/array.jl:68
  [3] (::Zygote.var"#389#390"{Array{Float64, 3}})(#unused#::FillArrays.Fill{Float64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/lib/array.jl:83
  [4] (::Zygote.var"#2322#back#391"{Zygote.var"#389#390"{Array{Float64, 3}}})(Δ::FillArrays.Fill{Float64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [5] Pullback
    @ ./broadcast.jl:871 [inlined]
  [6] Pullback
    @ ./broadcast.jl:868 [inlined]
  [7] Pullback
    @ ./broadcast.jl:864 [inlined]
  [8] (::typeof(∂(materialize!)))(Δ::FillArrays.Fill{Float64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/juliaup/julia-1.8.2+0.aarch64/share/julia/stdlib/v1.8/Statistics/src/Statistics.jl:181 [inlined]
 [10] (::typeof(∂(_mean)))(Δ::FillArrays.Fill{Float64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/juliaup/julia-1.8.2+0.aarch64/share/julia/stdlib/v1.8/Statistics/src/Statistics.jl:104 [inlined]
 [12] (::typeof(∂(#mean#1)))(Δ::FillArrays.Fill{Float64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/juliaup/julia-1.8.2+0.aarch64/share/julia/stdlib/v1.8/Statistics/src/Statistics.jl:104 [inlined]
 [14] (::typeof(∂(mean##kw)))(Δ::FillArrays.Fill{Float64, 3, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/GitHub/JointDiffRBMs.jl/test/regularize.jl:65 [inlined]
 [16] (::typeof(∂(#45)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [17] (::Zygote.var"#60#61"{typeof(∂(#45))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:45
 [18] gradient(f::Function, args::Array{Float64, 3})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:97
 [19] top-level scope

cossio commented Oct 8, 2022

The culprit seems to be mean. The same code with sum in its place works fine:

X = randn(4,5,6)
gradient(X) do X
    sum(sum(abs, X; dims=(1,2)))
end


cossio commented Oct 8, 2022

Seems specific to mean. The following works fine:

X = randn(4,5,6)
gradient(X) do X
    sum(sum(abs, X; dims=(1,2)))
end

ToucheSir changed the title from "Unwarranted array mutation error" to "No rule for mean(f, x; dims)" on Oct 8, 2022
ToucheSir commented

Neither Zygote nor ChainRules has a rule for mean(f, x; dims). Probably wouldn't be too hard to add to the latter.
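
Until such a rule exists, one possible workaround is to avoid the mean(f, x; dims) method altogether. The sketch below shows two rewrites of the loss from the original report: broadcasting abs first and then calling mean(x; dims), or reusing the sum(abs, X; dims) form reported to work above and dividing by the number of reduced elements. The names loss_broadcast and loss_sum are hypothetical, and the sketch assumes a Zygote version in which mean(x; dims) and sum(f, x; dims) are differentiable.

```julia
using Statistics, Zygote

X = randn(4, 5, 6)

# Rewrite 1: apply abs by broadcasting, then take the mean of the resulting array.
# This allocates a temporary array the size of X.
loss_broadcast(X) = sum(mean(abs.(X); dims=(1, 2)))

# Rewrite 2: use sum(f, x; dims) and divide by the number of elements
# reduced over, here size(X, 1) * size(X, 2).
loss_sum(X) = sum(sum(abs, X; dims=(1, 2)) ./ (size(X, 1) * size(X, 2)))

g_broadcast = gradient(loss_broadcast, X)[1]
g_sum = gradient(loss_sum, X)[1]

# Both rewrites compute the same loss, so their gradients should agree.
println(g_broadcast ≈ g_sum)
```

The first rewrite is the smaller change to the original code; the second avoids the temporary array created by abs.(X).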

ToucheSir added the needs adjoint (missing rule) and ChainRules (adjoint -> rrule, and further integration) labels on Oct 8, 2022

mcabbott commented Oct 8, 2022

This is JuliaDiff/ChainRules.jl#85, which was fixed by JuliaDiff/ChainRules.jl#615 but then reverted in JuliaDiff/ChainRules.jl#619 due to #836.

CarloLucibello commented

Closing as a duplicate of #1128.
