Skip to content

Commit

Permalink
Indexing (#655)
Browse files Browse the repository at this point in the history
* move and rename _zerolike_writeat, NFC

* simplify, use it for getindex, tests

* add unsafe_getindex too

* tidy, make weird types work via _setindex_zero

* fix view & its zero-arrays

* test unsafe_getindex

* handle indexing of GPU arrays

* suggested changes

* restore some gpu tests

* avoid the error

* in fact, mystery errors persist
  • Loading branch information
mcabbott authored Aug 12, 2022
1 parent 4c3a869 commit 77ef0eb
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 92 deletions.
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.43.2"
version = "1.44.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -16,6 +17,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[compat]
Adapt = "3.4.0"
ChainRulesCore = "1.15.3"
ChainRulesTestUtils = "1.5"
Compat = "3.42.0, 4"
Expand All @@ -30,7 +32,6 @@ StructArrays = "0.6.11"
julia = "1.6"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Expand All @@ -40,4 +41,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Adapt", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
3 changes: 2 additions & 1 deletion src/ChainRules.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module ChainRules

using Adapt: adapt
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
using ChainRulesCore
using Compat
using Distributed
using GPUArraysCore: AbstractGPUArrayStyle
using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle
using IrrationalConstants: logtwo, logten
using LinearAlgebra
using LinearAlgebra.BLAS
Expand Down
52 changes: 1 addition & 51 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,64 +515,14 @@ for findm in (:findmin, :findmax)

@eval function rrule(::typeof($findm), x::AbstractArray; dims=:)
y, ind = $findm(x; dims=dims)
project = ProjectTo(x)
# This pullback is a lot like the one for getindex. Ideally they would probably be combined?
function $findm_pullback((dy, _)) # this accepts e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing)
dy isa AbstractZero && return (NoTangent(), NoTangent())
x_thunk = @thunk project(_zerolike_writeat(x, unthunk(dy), dims, ind))
x_ithunk = InplaceableThunk(x_thunk) do dx
if dims isa Colon
view(dx, ind) .= view(dx, ind) .+ Ref(unthunk(dy))
else
view(dx, ind) .= view(dx, ind) .+ unthunk(dy) # this could be .+=, but not on Julia 1.0
end
dx
end
return (NoTangent(), x_ithunk)
return (NoTangent(), thunked_∇getindex(x, dy, ind),)
end
return (y, ind), $findm_pullback
end
end

# This function is roughly `setindex!(zero(x), dy, inds...)`:

function _zerolike_writeat(x::AbstractArray{<:Number}, dy, dims, inds...)
_zero_fill = eltype(dy) == Any ? 0 : zero(eltype(dy))

# It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
# allow `eltype(dy)`, nor does it work for many structured matrices.
dx = fill!(similar(x, eltype(dy), axes(x)), _zero_fill)
view(dx, inds...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray
dx
end
function _zerolike_writeat(x::AbstractArray, dy, dims, inds...)
# Since we have `x`, we can also handle arrays of arrays.
dx = map(zero, x)
if dims isa Colon
view(dx, inds...) .= Ref(dy)
else
view(dx, inds...) .= dy
end
dx
end

# Allow for second derivatives, by writing rules for `_zerolike_writeat`;
# these rules are the reason it takes a `dims` argument.

function frule((_, _, dẏ), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dẏ, dims, inds...)
end

function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...)
z = _zerolike_writeat(x, dy, dims, inds...)
function _zerolike_writeat_pullback(dz)
dx = sum(view(unthunk(dz), inds...); dims=dims)
nots = map(_ -> NoTangent(), inds)
return (NoTangent(), NoTangent(), dx, NoTangent(), nots...)
end
return z, _zerolike_writeat_pullback
end

# These rules for `maximum` pick the same subgradient as `findmax`:

function frule((_, ẋ), ::typeof(maximum), x; dims=:)
Expand Down
20 changes: 20 additions & 0 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,23 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu
map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...)
return y, map_pullback
end

#####
##### `task_local_storage`
#####

# Called by `@allowscalar` from GPUArrays

ChainRules.@non_differentiable task_local_storage(key::Any)
ChainRules.@non_differentiable task_local_storage(key::Any, value::Any)

function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage), body::Function, key, value)
y, back = task_local_storage(key, value) do
rrule_via_ad(config, body)
end
function task_local_storage_pullback(dy)
dbody = only(back(dy))
return (NoTangent(), dbody, NoTangent(), NoTangent())
end
return y, task_local_storage_pullback
end
147 changes: 126 additions & 21 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,38 +52,111 @@ function rrule(::typeof(getindex), x::Tuple, ::Colon)
return x, getindex_back_4
end


#####
##### getindex
##### getindex(::AbstractArray)
#####

function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...)
return x[inds...], ẋ[inds...]
end

function rrule(::typeof(getindex), x::Array{<:Number}, inds...)
# removes any logical indexing, CartesianIndex etc
# leaving us just with a tuple of Int, Arrays of Int and Ranges of Int
function rrule(::typeof(getindex), x::AbstractArray, inds...)
function getindex_pullback(dy)
nots = map(Returns(NoTangent()), inds)
return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...)
end
return x[inds...], getindex_pullback
end

function thunked_∇getindex(x, dy, inds...)
return InplaceableThunk(
dx -> ∇getindex!(dx, unthunk(dy), Base.to_indices(x, inds)...),
@thunk(∇getindex(x, unthunk(dy), inds...)),
)
end

"""
∇getindex(x, dy, inds...)
For the `rrule` of `y = x[inds...]`, this function is roughly
`setindex(zero(x), dy, inds...)`, returning the array `dx`.
Differentiable. Includes `ProjectTo(x)(dx)`.
"""
function ∇getindex(x::AbstractArray, dy, inds...)
# `to_indices` removes any logical indexing, colons, CartesianIndex etc,
# leaving just Int / AbstractVector of Int
plain_inds = Base.to_indices(x, inds)
y = getindex(x, plain_inds...)
function getindex_pullback(ȳ)
function getindex_add!(Δ)
# this a optimizes away for simple cases
for (ȳ_ii, ii) in zip(ȳ, Iterators.product(plain_inds...))
Δ[ii...] += ȳ_ii
end
return Δ
end
dx = _setindex_zero(x, dy, plain_inds...)
∇getindex!(dx, dy, plain_inds...)
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules
end

"""
_setindex_zero(x, dy, inds...)
= InplaceableThunk(
getindex_add!,
@thunk(getindex_add!(zero(x))),
)
īnds = broadcast(Returns(NoTangent()), inds)
return (NoTangent(), x̄, īnds...)
This returns roughly `dx = zero(x)`, except that this is guaranteed to be mutable via `similar`,
and its element type is wide enough to allow `setindex!(dx, dy, inds...)`, which is exactly what
`∇getindex` does next.
It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
allow `eltype(dy)`, nor does it work for many structured matrices.
"""
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), ZeroTangent())
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), ZeroTangent())
function _setindex_zero(x::AbstractArray, dy, inds::Integer...)
# This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent),
# but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors
T = Union{typeof(dy), ZeroTangent}
return fill!(similar(x, T, axes(x)), ZeroTangent())
end
function _setindex_zero(x::AbstractArray, dy, inds...)
T = Union{eltype(dy), ZeroTangent}
return fill!(similar(x, T, axes(x)), ZeroTangent())
end
ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...)

function ∇getindex!(dx::AbstractArray, dy, inds::Integer...)
view(dx, inds...) .+= Ref(dy)
return dx
end
function ∇getindex!(dx::AbstractArray, dy, inds...)
view(dx, inds...) .+= dy
return dx
end

# Allow for second derivatives, by writing rules for `∇getindex`:

function frule((_, _, dẏ), ::typeof(∇getindex), x, dy, inds...)
return ∇getindex(x, dy, inds...), ∇getindex(x, dẏ, inds...)
end

function rrule(::typeof(∇getindex), x, dy, inds...)
z = ∇getindex(x, dy, inds...)
function ∇getindex_pullback(dz)
d2y = getindex(unthunk(dz), inds...)
nots = map(Returns(NoTangent()), inds)
return (NoTangent(), NoTangent(), ProjectTo(dy)(d2y), nots...)
end
return z, ∇getindex_pullback
end

# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers.
# To avoid this, copy everything back to the CPU.
# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice:

return y, getindex_pullback
function ∇getindex!(dx::AbstractGPUArray, dy, inds::Integer...)
view(dx, inds...) .+= Ref(dy)
return dx
end
function ∇getindex!(dx::AbstractGPUArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...)
view(dx, inds...) .+= dy
return dx
end
function ∇getindex!(dx::AbstractGPUArray, dy, inds...)
dx_cpu = adapt(Array, dx)
view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy)
copyto!(dx, dx_cpu)
return dx
end

#####
Expand Down Expand Up @@ -117,6 +190,23 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...)
return view(x, inds...), view(ẋ, inds...)
end

function rrule(::typeof(view), x::AbstractArray, inds...)
function view_pullback(dy)
nots = map(Returns(NoTangent()), inds)
return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...)
end
return view(x, inds...), view_pullback
end

function rrule(::typeof(view), x::AbstractArray, i::Integer, jkl::Integer...)
# This case returns a zero-dim array, unlike getindex. So we fool ∇getindex:
function view_pullback_0(dy)
nots = map(Returns(NoTangent()), (i, jkl...))
return (NoTangent(), thunked_∇getindex(x, dy, i:i, jkl...), nots...)
end
return view(x, i, jkl...), view_pullback_0
end

#####
##### setindex!
#####
Expand All @@ -125,6 +215,21 @@ function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...)
return setindex!(x, v, inds...), setindex!(ẋ, v̇, inds...)
end

#####
##### unsafe_getindex
#####

# This is called by e.g. `iterate(1:0.1:2)`,
# and fixes https://github.com/FluxML/Zygote.jl/issues/1247
# Only needs to accept AbstractRange, but AbstractVector makes testing easier.

function frule((_, ẋ), ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)
return Base.unsafe_getindex(x, i), getindex(ẋ, i)
end

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)
return rrule_via_ad(cfg, getindex, x, i)
end

#####
##### `eachslice` and friends
Expand Down
12 changes: 4 additions & 8 deletions src/rulesets/Base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,7 @@ function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...)
p = sortperm(collect(eachslice(x; dims=dims)); kw...)
inds = ntuple(d -> d == dims ? p : (:), ndims(x))
function sortslices_pullback(dy)
# No actual need to zero this, and if you didn't, then you could widen eltype
# Also, you could use similar(dy) here not x, same size?
dx = _zerolike_writeat(x, unthunk(dy), (), inds...)
return (NoTangent(), ProjectTo(x)(dx))
return (NoTangent(), ∇getindex(x, unthunk(dy), inds...))
end
return x[inds...], sortslices_pullback
end
Expand Down Expand Up @@ -94,12 +91,11 @@ function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:)
mask .= (mask .== cumsum(mask, dims=1) .== true) # this implements findfirst(mask; dims=1)
keep = map(I -> I[1], findall(mask))
if dims isa Colon
# The function `_zerolike_writeat` allows second derivatives.
# Should perhaps eventually be shared with `getindex`.
dx = reshape(_zerolike_writeat(vec(x), vec(dy), (), keep), axes_x)
# The function `∇getindex` allows second derivatives.
dx = reshape(∇getindex(vec(x), vec(dy), keep), axes_x) ## TODO understand again why vec!
else
inds = ntuple(d -> d==dims ? keep : (:), length(axes_x))
dx = _zerolike_writeat(x, dy, (), inds...)
dx = ∇getindex(x, dy, inds...)
end
return (NoTangent(), ProjectTo(x)(dx))
end
Expand Down
9 changes: 1 addition & 8 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,7 @@ end
@test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2])
test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()))
test_rrule(findmin, rand(3,4), fkwargs=(dims=2,))

# Second derivatives
test_frule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2))
test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2))
@test_skip test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] NoTangent()) # MethodError: no method matching isapprox(::Matrix{Float64}, ::Float64; rtol=1.0e-9, atol=1.0e-9)
y, bk = rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)])
@test y == [0 0; 5 5]
@test bk([1 2; 3 4]) == (NoTangent(), NoTangent(), [3 4], NoTangent(), NoTangent())
test_rrule(findmin, rand(3,4), fkwargs=(dims=(1,2),))
end

@testset "$imum" for imum in [maximum, minimum]
Expand Down
Loading

2 comments on commit 77ef0eb

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/66148

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.44.0 -m "<description of version>" 77ef0eb15fdd207028c059927f2456819d62df8c
git push origin v1.44.0

Please sign in to comment.