Rule for permutedims
#559
@mcabbott could you add a test for / specialise on
function rrule(::typeof(PermutedDimsArray), x::AbstractArray, perm)
    pr = ProjectTo(x)
    permutedims_back_3(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent())
    return PermutedDimsArray(x, perm), permutedims_back_3
end
BTW, the logic of doing permutedims in the reverse pass is this. Calling PermutedDimsArray in the forward pass means you think the next operation will be OK with a lazy wrapper, something like broadcasting, not like reshape + BLAS. But the previous operation might not be; we don't know that its gradient is going to be happy with the lazy container. So I think it's safest to copy.
These could all have InplaceableThunks, with permutedims!, but perhaps we kick that can down the road until they are useful.
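To make the copy-in-reverse argument concrete, here is a minimal Base-only sketch (no ChainRulesCore, so no `ProjectTo` or thunks). The array `x`, the permutation `perm`, and the upstream gradient `dy` are made-up values for illustration, not taken from the PR's tests.

```julia
# Hypothetical input and permutation, chosen only for illustration.
x = reshape(collect(1.0:24.0), 2, 3, 4)
perm = (3, 1, 2)

# Forward pass: a lazy wrapper that shares memory with x (no copy).
y = PermutedDimsArray(x, perm)

# Hypothetical upstream gradient with the same shape as y.
dy = reshape(collect(1.0:24.0), size(y))

# Reverse pass: eagerly permute back with the inverse permutation,
# so the previous rule receives a plain Array, not a lazy container.
dx = permutedims(dy, invperm(perm))

println(typeof(y) <: PermutedDimsArray)  # lazy in the forward pass
println(dx isa Array)                    # materialised in the reverse pass
println(size(dx) == size(x))             # gradient matches the input's shape
```

Because permuting dimensions is a linear map whose adjoint is the inverse permutation, `sum(y .* dy)` and `sum(x .* dx)` should agree, which is a quick sanity check on the pullback.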
@willtebbutt should give final approval.
I added a test. I think the one-arg permutedims(::Diagonal) will never call the rule, thus can't be tested. (Whereas permutedims(::Matrix) calls permutedims(::Matrix, (2,1)) and then hits the rule.)
I agree with this, and the test in the 2-arg case looks reasonable. I'm happy for this to be merged (subject to patch bump and CI passing ofc).
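The dispatch point above can be checked with a short sketch. The matrices `A` and `D` are made-up examples; the `which` calls only print whichever methods the installed Julia version selects, and no particular output is assumed for the `Diagonal` case.

```julia
using LinearAlgebra

# Hypothetical inputs for illustration.
A = [1.0 2.0 3.0; 4.0 5.0 6.0]
D = Diagonal([1.0, 2.0])

# For a plain Matrix, the one-arg form gives the same result as the
# two-arg form with perm = (2, 1), which is the call a two-arg rule sees.
println(permutedims(A) == permutedims(A, (2, 1)))

# Print which method each one-arg call dispatches to. If Diagonal has its
# own one-arg method, that call never reaches the generic two-arg method.
println(which(permutedims, (typeof(A),)))
println(which(permutedims, (typeof(D),)))
```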
It seems this PR broke DistributionsAD tests, and hence the integration tests in Distributions: https://github.com/JuliaStats/Distributions.jl/runs/4797321095?check_suite_focus=true
Sorry about the break. That's bizarre, how does it get them backwards? Or does it just have completely the wrong arguments? And here, do you see anything obviously wrong? Looking at it now, I believe it should probably be Edit: Maybe I see:
That's not right. And maybe it's triggered only in weird edge cases, type-unstable code in AD or something, which is why this package's own tests (and Zygote's) didn't see it?
Closes #523