Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add dims for chunk and test_zygote #47

Merged
merged 13 commits into from
Feb 10, 2022
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,24 @@ authors = ["Carlo Lucibello <[email protected]> and contributors"]
version = "0.1.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
ChainRulesCore = "1.0"
ShowCases = "0.1"
StatsBase = "0.33"
julia = "1.6"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["SparseArrays", "Test"]
test = ["ChainRulesTestUtils", "SparseArrays", "Test", "Zygote"]
4 changes: 4 additions & 0 deletions src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ using ShowCases: ShowLimit
import StatsBase: sample
using Base: @propagate_inbounds
using Random: AbstractRNG, shuffle!, GLOBAL_RNG
import ChainRulesCore: rrule
using ChainRulesCore: @non_differentiable, unthunk, AbstractZero,
NoTangent, ZeroTangent, ProjectTo


include("observation.jl")
export numobs,
Expand Down
91 changes: 82 additions & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,13 @@ julia> unstack([1 3 5 7; 2 4 6 8], dims=2)
unstack(xs; dims::Int) = [copy(selectdim(xs, dims, i)) for i in 1:size(xs, dims)]

"""
chunk(x, n)
chunk(x, n; [dims])

Split `x` into `n` parts.
Split `x` into `n` parts. The parts contain the same number of elements
except possibly for the last one that can be smaller.

If `x` is an array, `dims` can be used to specify along which dimension to
split (defaults to the last dimension).

# Examples

Expand All @@ -134,14 +138,83 @@ julia> chunk(1:10, 3)
5:8
9:10

julia> chunk(collect(1:10), 3)
3-element Vector{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}:
[1, 2, 3, 4]
[5, 6, 7, 8]
[9, 10]
julia> x = reshape(collect(1:20), (5, 4))
5×4 Matrix{Int64}:
1 6 11 16
2 7 12 17
3 8 13 18
4 9 14 19
5 10 15 20

julia> xs = chunk(x, 2, dims=1)
2-element Vector{SubArray{Int64, 2, Matrix{Int64}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}:
[1 6 11 16; 2 7 12 17; 3 8 13 18]
[4 9 14 19; 5 10 15 20]

julia> xs[1]
3×4 view(::Matrix{Int64}, 1:3, :) with eltype Int64:
1 6 11 16
2 7 12 17
3 8 13 18
```
"""
chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n)))
chunk(x, n::Int) = collect(Iterators.partition(x, ceil(Int, length(x) / n)))

function chunk(x::AbstractArray, n::Int; dims::Int=ndims(x))
idxs = _partition_idxs(x, n, dims)
[selectdim(x, dims, i) for i in idxs]
end

# Zygote errors if not iterating over collected partitions in [selectdim(x, dims, i) for i in ids]
function _partition_idxs(x, n, dims)
bs = ceil(Int, size(x, dims) / n)
collect(Iterators.partition(axes(x, dims), bs))
end

@non_differentiable _partition_idxs(x...)

function rrule(::typeof(chunk), x::AbstractArray, n::Int; dims::Int=ndims(x))
# this is the implementation of chunk
idxs = _partition_idxs(x, n, dims)
y = [selectdim(x, dims, i) for i in idxs]
valdims = Val(dims)
chunk_pullback(dy) = (NoTangent(), ∇chunk(unthunk(dy), x, idxs, valdims), NoTangent())

return y, chunk_pullback
end

# Similar to ∇eachslice https://github.com/JuliaDiff/ChainRules.jl/blob/8108a77a96af5d4b0c460aac393e44f8943f3c5e/src/rulesets/Base/indexing.jl#L77
function ∇chunk(dys, x::AbstractArray, idxs, vd::Val{dim}) where {dim}
i1 = findfirst(dy -> !(dy isa AbstractZero), dys)
if i1 === nothing # all slices are Zero!
return _zero_fill!(similar(x, float(eltype(x))))
end
T = promote_type(eltype(dys[i1]), eltype(x))
# The whole point of this gradient is that we can allocate one `dx` array:
dx = similar(x, T)
for (k, i) in enumerate(idxs)
slice = selectdim(dx, dim, i)
if dys[k] isa AbstractZero
_zero_fill!(slice) # Avoids this: copyto!([1,2,3], ZeroTangent()) == [0,2,3]
else
copyto!(slice, dys[k])
end
end
return ProjectTo(x)(dx)
end

_zero_fill!(dx::AbstractArray{<:Number}) = fill!(dx, zero(eltype(dx)))
_zero_fill!(dx::AbstractArray) = map!(zero, dx, dx)

function rrule(::typeof(∇chunk), dys, x, idxs, vd::Val{dim}) where dim
n = length(dys)
function ∇∇chunk(dz_raw)
dz = unthunk(dz_raw)
cs = chunk(dz, n; dims=dim)
return (NoTangent(), collect(cs), NoTangent(), NoTangent(), NoTangent())
end
return ∇chunk(dys, x, idxs, vd), ∇∇chunk
end

"""
group_counts(x)
Expand Down Expand Up @@ -370,4 +443,4 @@ function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5))
end

ofeltype(x, y) = convert(float(eltype(x)), y)
epseltype(x) = eps(float(eltype(x)))
epseltype(x) = eps(float(eltype(x)))
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ using MLUtils.Datasets
using SparseArrays
using Random, Statistics
using Test
using ChainRulesTestUtils: test_rrule
using Zygote: ZygoteRuleConfig
using ChainRulesCore: rrule_via_ad

showcompact(io, x) = show(IOContext(io, :compact => true), x)

Expand Down Expand Up @@ -35,6 +38,8 @@ MLUtils.getobs(::CustomType, i::AbstractVector) = collect(i)

# --------------------------------------------------------------------

include("test_utils.jl")

# @testset "MLUtils.jl" begin

@testset "batchview" begin; include("batchview.jl"); end
Expand Down
5 changes: 5 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

function test_zygote(f, xs...; kws...)
config = ZygoteRuleConfig()
test_rrule(config, f, xs...; kws..., rrule_f = rrule_via_ad)
end
24 changes: 24 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,30 @@ end
@test cs[1] == [1, 2, 3, 4]
@test cs[2] == [5, 6, 7, 8]
@test cs[3] == [9, 10]

cs = chunk(collect(1:10), 3)
@test length(cs) == 3
@test cs[1] == [1, 2, 3, 4]
@test cs[2] == [5, 6, 7, 8]
@test cs[3] == [9, 10]

x = reshape(collect(1:20), (5, 4))
cs = chunk(x, 2)
@test length(cs) == 2
cs[1] == [1 6; 2 7; 3 8; 4 9; 5 10]
cs[2] == [11 16; 12 17; 13 18; 14 19; 15 20]

# test gradient
test_zygote(chunk, rand(10), 3, check_inferred=false)

# indirect test of second order derivates
n = 2
dims = 2
x = rand(4, 5)
y = chunk(x, 2)
dy = randn!.(collect.(y))
idxs = MLUtils._partition_idxs(x, n, dims)
test_zygote(MLUtils.∇chunk, dy, x, idxs, Val(dims), check_inferred=false)
end

@testset "group_counts" begin
Expand Down