Skip to content

Commit

Permalink
Improve cat rules (#660)
Browse files Browse the repository at this point in the history
* use allowscalar in cat rules

* use require_one_based_indexing

* restore InplaceableThunk
  • Loading branch information
mcabbott authored Aug 15, 2022
1 parent 77ef0eb commit 63cc4e0
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "1.44.0"
+version = "1.44.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
2 changes: 1 addition & 1 deletion src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad
using ChainRulesCore
using Compat
using Distributed
-using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle
+using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle, @allowscalar
using IrrationalConstants: logtwo, logten
using LinearAlgebra
using LinearAlgebra.BLAS
Expand Down
43 changes: 19 additions & 24 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ _catsize(x::AbstractArray) = size(x)

function rrule(::typeof(hcat), Xs...)
Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray
Base.require_one_based_indexing(Y)
ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability
sizes = map(_catsize, Xs) # this avoids closing over Xs
project_Xs = map(ProjectTo, Xs)
Expand All @@ -233,15 +234,10 @@ function rrule(::typeof(hcat), Xs...)
d > ndimsX ? 1 : (:)
end
end
-dX = if ndimsX > 0
-# Here InplaceableThunk breaks @inferred, removed for now
-# InplaceableThunk(dX -> dX .+= view(dY, ind...), @thunk(dY[ind...]))
-dY[ind...]
-else
-# This is a hack to perhaps avoid GPU scalar indexing
-sum(view(dY, ind...))
-end
-return project(dX)
+InplaceableThunk(
+dX -> dX .+= view(dY, ind...),
+@thunk project(@allowscalar dY[ind...])
+)
end
return (NoTangent(), dXs...)
end
Expand All @@ -253,6 +249,8 @@ function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(hcat), As::AbstractVecto
end

function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat})
Y = reduce(hcat, As)
Base.require_one_based_indexing(Y)
widths = map(A -> size(A,2), As)
function reduce_hcat_pullback_2(dY)
hi = Ref(0)
Expand All @@ -263,7 +261,7 @@ function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVe
end
return (NoTangent(), NoTangent(), dAs)
end
-return reduce(hcat, As), reduce_hcat_pullback_2
+return Y, reduce_hcat_pullback_2
end

function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVector})
Expand All @@ -286,6 +284,7 @@ end

function rrule(::typeof(vcat), Xs...)
Y = vcat(Xs...)
Base.require_one_based_indexing(Y)
ndimsY = Val(ndims(Y))
sizes = map(_catsize, Xs)
project_Xs = map(ProjectTo, Xs)
Expand All @@ -303,13 +302,10 @@ function rrule(::typeof(vcat), Xs...)
d > ndimsX ? 1 : (:)
end
end
-dX = if ndimsX > 0
-# InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...))
-dY[ind...]
-else
-sum(view(dY, ind...))
-end
-return project(dX)
+InplaceableThunk(
+dX -> dX .+= view(dY, ind...),
+@thunk project(@allowscalar dY[ind...])
+)
end
return (NoTangent(), dXs...)
end
Expand All @@ -322,6 +318,7 @@ end

function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat})
Y = reduce(vcat, As)
Base.require_one_based_indexing(Y)
ndimsY = Val(ndims(Y))
heights = map(A -> size(A,1), As)
function reduce_vcat_pullback(dY)
Expand Down Expand Up @@ -349,6 +346,7 @@ end

function rrule(::typeof(cat), Xs...; dims)
Y = cat(Xs...; dims=dims)
Base.require_one_based_indexing(Y)
cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims)
ndimsY = Val(ndims(Y))
sizes = map(_catsize, Xs)
Expand All @@ -368,13 +366,10 @@ function rrule(::typeof(cat), Xs...; dims)
for d in cdims
prev[d] += get(sizeX, d, 1)
end
-dX = if ndimsX > 0
-# InplaceableThunk(@thunk(dY[index...]), dX -> dX .+= view(dY, index...))
-dY[index...]
-else
-sum(view(dY, index...))
-end
-return project(dX)
+InplaceableThunk(
+dX -> dX .+= view(dY, index...),
+@thunk project(@allowscalar dY[index...])
+)
end
return (NoTangent(), dXs...)
end
Expand Down

2 comments on commit 63cc4e0

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/66279

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.44.1 -m "<description of version>" 63cc4e0387817747f7ed04e9ffb2321e9717b7f4
git push origin v1.44.1

Please sign in to comment.