diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl
index 63ce70ae0..2cd03ce9e 100644
--- a/src/rulesets/Base/array.jl
+++ b/src/rulesets/Base/array.jl
@@ -65,9 +65,36 @@ function rrule(::typeof(reshape), A::AbstractArray, dims::Union{Colon,Int}...)
     return reshape(A, dims...), reshape_pullback
 end
 
+#####
+##### `permutedims`
+#####
+
+# Vector method: the pullback applies `permutedims` to the cotangent, then `ProjectTo(x)`.
+function rrule(::typeof(permutedims), x::AbstractVector)
+    project = ProjectTo(x)
+    permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy))))
+    return permutedims(x), permutedims_pullback_1
+end
+
+# Explicit `perm`: the pullback undoes it with `invperm(perm)`; `perm` gets `NoTangent()`.
+function rrule(::typeof(permutedims), x::AbstractArray, perm)
+    pr = ProjectTo(x)  # projection restores e.g. transpose([1,2,3])
+    permutedims_back_2(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent())
+    return permutedims(x, perm), permutedims_back_2
+end
+
+# Dispatch on `Type{<:PermutedDimsArray}`, not `typeof(PermutedDimsArray)`: the latter is
+# `UnionAll`, so such a method would also match calls to every other `UnionAll` type.
+function rrule(::Type{<:PermutedDimsArray}, x::AbstractArray, perm)
+    pr = ProjectTo(x)
+    permutedims_back_3(dy) = (NoTangent(), pr(permutedims(unthunk(dy), invperm(perm))), NoTangent())
+    return PermutedDimsArray(x, perm), permutedims_back_3
+end
+
 #####
 ##### `repeat`
 #####
+
 function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs)))
 
     project_Xs = ProjectTo(xs)
diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl
index ff3e93931..96be78198 100644
--- a/test/rulesets/Base/array.jl
+++ b/test/rulesets/Base/array.jl
@@ -42,6 +42,22 @@ end
     test_rrule(reshape, rand(4, 5), 2, :)
 end
 
+@testset "permutedims + PermutedDimsArray" begin
+    test_rrule(permutedims, rand(5))
+
+    test_rrule(permutedims, rand(3, 4), (2, 1))
+    test_rrule(permutedims, Diagonal(rand(5)), (2, 1))
+    # Note BTW that permutedims(Diagonal(rand(5))) does not use the rule at all
+
+    @test invperm((3, 1, 2)) != (3, 1, 2)
+    test_rrule(permutedims, rand(3, 4, 5), (3, 1, 2); check_inferred=VERSION>=v"1.1")
+
+    @test_skip test_rrule(PermutedDimsArray, rand(3, 4, 5), (3, 1, 2))
+    x = rand(2, 3, 4)
+    dy = rand(4, 2, 3)
+    @test rrule(permutedims, x, (3, 1, 2))[2](dy)[2] == rrule(PermutedDimsArray, x, (3, 1, 2))[2](dy)[2]
+end
+
 @testset "repeat" begin
     test_rrule(repeat, rand(4, ))