Rule for permutedims #559

Merged
merged 5 commits into from
Jan 12, 2022

Conversation

mcabbott
Member

Closes #523

@willtebbutt
Member

@mcabbott could you add a test for / specialise on Diagonal matrices? There are two specialised methods, both of which are essentially the identity function, so it would be good to be confident that these rules work in that case.

Comment on lines +84 to +88
function rrule(::typeof(PermutedDimsArray), x::AbstractArray, perm)
    pr = ProjectTo(x)
    permutedims_back_3(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent())
    return PermutedDimsArray(x, perm), permutedims_back_3
end
Member Author

@mcabbott mcabbott Jan 12, 2022

BTW, the logic of doing permutedims in the reverse pass is this. Calling PermutedDimsArray forwards means you think the next operation will be OK with the lazy wrapper, something like broadcasting, not like reshape + BLAS. But the previous operation might not be: we don't know that its gradient is going to be happy with the lazy container. So I think it's safest to copy.

These could all have InplaceableThunks, with permutedims!, but perhaps we kick that can down the road until they are useful.
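
The reasoning above can be sketched with Base alone. This is only an illustration, not the rule itself: `x`, `perm`, `y`, `dy` and `dx` are made-up names, and `dy` stands in for whatever cotangent the next operation hands back.

```julia
# Base-only sketch; all variable names here are illustrative assumptions.
x = rand(2, 3, 4)
perm = (3, 1, 2)

y = PermutedDimsArray(x, perm)   # forward pass: lazy wrapper around x
dy = ones(size(y))               # stand-in for the incoming cotangent

# Reverse pass: eagerly undo the permutation, producing a plain Array
# rather than another lazy PermutedDimsArray.
dx = permutedims(dy, invperm(perm))

size(dx) == size(x)   # the cotangent has the shape of the original input
dx isa Array          # and is a dense copy, not a lazy container
```

Copying here costs an allocation, but the previous operation's pullback then only ever sees an ordinary dense array.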

@mcabbott
Member Author

> specialise on Diagonal matrices? There are two specialised methods, both of which are essentially the identity function,

I added a test. I think the one-arg permutedims(::Diagonal) will never call the rule, thus can't be tested. (Whereas permutedims(::Matrix) calls permutedims(::Matrix, (2,1)) and then hits the rule.)
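
A quick way to see the two cases, as a sketch that assumes a Julia version where `LinearAlgebra` ships the specialised `permutedims` methods for `Diagonal` mentioned above (`D` and `M` are illustrative names):

```julia
using LinearAlgebra

D = Diagonal([1.0, 2.0, 3.0])

# Both specialised methods act like the identity and keep the Diagonal type.
permutedims(D) isa Diagonal
permutedims(D, (2, 1)) isa Diagonal
permutedims(D) == D

# For a dense Matrix, the one-arg form gives the same result as the
# two-arg form with (2, 1), which is the call the rule is attached to.
M = [1.0 2.0; 3.0 4.0]
permutedims(M) == permutedims(M, (2, 1))
```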

@oxinabox
Member

@willtebbutt should give final approval.
(I am unsubscribing, ping me if I am needed)

Member

@willtebbutt willtebbutt left a comment

> I added a test. I think the one-arg permutedims(::Diagonal) will never call the rule, thus can't be tested. (Whereas permutedims(::Matrix) calls permutedims(::Matrix, (2,1)) and then hits the rule.)

I agree with this, and the test in the 2-arg case looks reasonable. I'm happy for this to be merged (subject to patch bump and CI passing ofc).

@mcabbott mcabbott merged commit cb816d0 into JuliaDiff:main Jan 12, 2022
@mcabbott mcabbott deleted the permutedims branch January 12, 2022 18:45
@devmotion
Member

It seems this PR broke DistributionsAD tests, and hence the integration tests in Distributions: https://github.com/JuliaStats/Distributions.jl/runs/4797321095?check_suite_focus=true

@mcabbott
Member Author

mcabbott commented Jan 13, 2022

Sorry about the break. That's bizarre, how does it get them backwards? Or does it just have completely the wrong arguments, PermutedDimsArray(data::Vector{Float64}, perm::Matrix{Float64})? Does (::Type{TuringDenseMvNormal}, ::Vector{Float64}, ::Matrix{Float64}) look like a sensible type?

And here, do you see anything obviously wrong? Looking at it now, I believe it should probably be function rrule(::Type{<:PermutedDimsArray}, x::AbstractArray, perm), but locally I'm not sure this makes a difference?

Edit: Maybe I see:

julia> ff(::typeof(PermutedDimsArray), x) = x^2;

julia> ff(PermutedDimsArray, 3)
9

julia> ff(sin, 3)
ERROR: MethodError: no method matching ff(::typeof(sin), ::Int64)
Closest candidates are:
  ff(::UnionAll, ::Any) at REPL[23]:1
Stacktrace:
 [1] top-level scope
   @ REPL[25]:1

julia> methods(ff)
# 1 method for generic function "ff":
[1] ff(::UnionAll, x) in Main at REPL[23]:1

That's not right. And maybe it's triggered only in weird edge cases, type unstable code in AD or something, which is why this package's own tests (and Zygote's) didn't see it?
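
For comparison, a sketch of how the `::Type{<:PermutedDimsArray}` signature suggested above would dispatch. `loose` and `strict` are hypothetical stand-ins for the two possible rule signatures, not functions from any package:

```julia
# `loose` mirrors the merged signature, `strict` the proposed one.
loose(::typeof(PermutedDimsArray), x) = x^2    # typeof(PermutedDimsArray) is UnionAll
strict(::Type{<:PermutedDimsArray}, x) = x^2   # restricted to PermutedDimsArray types

typeof(PermutedDimsArray) === UnionAll         # why `loose` is too broad

# Any other UnionAll, such as Vector, also matches the loose method.
applicable(loose, Vector, 3)

# The strict method accepts PermutedDimsArray but rejects unrelated types.
applicable(strict, PermutedDimsArray, 3)
applicable(strict, Vector, 3)
```

If the constructor in the failing call is itself a `UnionAll`, which would be the case for a parametric type, this would account for a rule written for `PermutedDimsArray` being picked up for it.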

Development

Successfully merging this pull request may close these issues.

Port rules for permutedims