diff --git a/Project.toml b/Project.toml index 048ce79..9455364 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Carlo Lucibello and contributors"] version = "0.1.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" @@ -11,13 +12,16 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] +ChainRulesCore = "1.0" ShowCases = "0.1" StatsBase = "0.33" julia = "1.6" [extras] +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["SparseArrays", "Test"] +test = ["ChainRulesTestUtils", "SparseArrays", "Test", "Zygote"] diff --git a/src/MLUtils.jl b/src/MLUtils.jl index b5805d4..8454231 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -6,6 +6,10 @@ using ShowCases: ShowLimit import StatsBase: sample using Base: @propagate_inbounds using Random: AbstractRNG, shuffle!, GLOBAL_RNG +import ChainRulesCore: rrule +using ChainRulesCore: @non_differentiable, unthunk, AbstractZero, + NoTangent, ZeroTangent, ProjectTo + include("observation.jl") export numobs, diff --git a/src/utils.jl b/src/utils.jl index 9bc6bde..2f5e2ca 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -121,9 +121,13 @@ julia> unstack([1 3 5 7; 2 4 6 8], dims=2) unstack(xs; dims::Int) = [copy(selectdim(xs, dims, i)) for i in 1:size(xs, dims)] """ - chunk(x, n) + chunk(x, n; [dims]) -Split `x` into `n` parts. +Split `x` into `n` parts. The parts contain the same number of elements +except possibly for the last one that can be smaller. + +If `x` is an array, `dims` can be used to specify along which dimension to +split (defaults to the last dimension). # Examples @@ -134,14 +138,79 @@ julia> chunk(1:10, 3) 5:8 9:10 -julia> chunk(collect(1:10), 3) -3-element Vector{SubArray{Int64, 1, Vector{Int64}, Tuple{UnitRange{Int64}}, true}}: - [1, 2, 3, 4] - [5, 6, 7, 8] - [9, 10] +julia> x = reshape(collect(1:20), (5, 4)) +5×4 Matrix{Int64}: + 1 6 11 16 + 2 7 12 17 + 3 8 13 18 + 4 9 14 19 + 5 10 15 20 + +julia> xs = chunk(x, 2, dims=1) +2-element Vector{SubArray{Int64, 2, Matrix{Int64}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}: + [1 6 11 16; 2 7 12 17; 3 8 13 18] + [4 9 14 19; 5 10 15 20] + +julia> xs[1] +3×4 view(::Matrix{Int64}, 1:3, :) with eltype Int64: + 1 6 11 16 + 2 7 12 17 + 3 8 13 18 ``` """ -chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n))) +chunk(x, n::Int) = collect(Iterators.partition(x, ceil(Int, length(x) / n))) + +function chunk(x::AbstractArray, n::Int; dims::Int=ndims(x)) + idxs = _partition_idxs(x, n, dims) + [selectdim(x, dims, i) for i in idxs] +end + +function _partition_idxs(x, n, dims) + bs = ceil(Int, size(x, dims) / n) + Iterators.partition(axes(x, dims), bs) +end + +function rrule(::typeof(chunk), x::AbstractArray, n::Int; dims::Int=ndims(x)) + # this is the implementation of chunk + idxs = _partition_idxs(x, n, dims) + y = [selectdim(x, dims, i) for i in idxs] + valdims = Val(dims) + chunk_pullback(dy) = (NoTangent(), ∇chunk(unthunk(dy), x, idxs, valdims), NoTangent()) + + return y, chunk_pullback +end + +# Similar to ∇eachslice https://github.com/JuliaDiff/ChainRules.jl/blob/8108a77a96af5d4b0c460aac393e44f8943f3c5e/src/rulesets/Base/indexing.jl#L77 +function ∇chunk(dys, x::AbstractArray, idxs, vd::Val{dim}) where {dim} + i1 = findfirst(dy -> !(dy isa AbstractZero), dys) + if i1 === nothing # all slices are Zero! + return _zero_fill!(similar(x, float(eltype(x)))) + end + T = promote_type(eltype(dys[i1]), eltype(x)) + # The whole point of this gradient is that we can allocate one `dx` array: + dx = similar(x, T) + for (k, i) in enumerate(idxs) + slice = selectdim(dx, dim, i) + if dys[k] isa AbstractZero + _zero_fill!(slice) # Avoids this: copyto!([1,2,3], ZeroTangent()) == [0,2,3] + else + copyto!(slice, dys[k]) + end + end + return ProjectTo(x)(dx) +end + +_zero_fill!(dx::AbstractArray{<:Number}) = fill!(dx, zero(eltype(dx))) +_zero_fill!(dx::AbstractArray) = map!(zero, dx, dx) + +function rrule(::typeof(∇chunk), dys, x, idxs, vd::Val{dim}) where dim + n = length(dys) + function ∇∇chunk(dz_raw) + dz = chunk(unthunk(dz_raw), n; dims=dim) + return (NoTangent(), dz, NoTangent(), NoTangent(), NoTangent()) + end + return ∇chunk(dys, x, idxs, vd), ∇∇chunk +end """ group_counts(x) @@ -370,4 +439,4 @@ function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5)) end ofeltype(x, y) = convert(float(eltype(x)), y) -epseltype(x) = eps(float(eltype(x))) \ No newline at end of file +epseltype(x) = eps(float(eltype(x))) diff --git a/test/runtests.jl b/test/runtests.jl index dab09d6..297e777 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,9 @@ using MLUtils.Datasets using SparseArrays using Random, Statistics using Test +using ChainRulesTestUtils: test_rrule +using Zygote: ZygoteRuleConfig +using ChainRulesCore: rrule_via_ad showcompact(io, x) = show(IOContext(io, :compact => true), x) @@ -35,6 +38,8 @@ MLUtils.getobs(::CustomType, i::AbstractVector) = collect(i) # -------------------------------------------------------------------- +include("test_utils.jl") + # @testset "MLUtils.jl" begin @testset "batchview" begin; include("batchview.jl"); end diff --git a/test/test_utils.jl b/test/test_utils.jl new file mode 100644 index 0000000..4db946a --- /dev/null +++ b/test/test_utils.jl @@ -0,0 +1,5 @@ + +function test_zygote(f, xs...; kws...) + config = ZygoteRuleConfig() + test_rrule(config, f, xs...; kws..., rrule_f = rrule_via_ad) +end diff --git a/test/utils.jl b/test/utils.jl index 298b9fc..782d464 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -82,6 +82,30 @@ end @test cs[1] == [1, 2, 3, 4] @test cs[2] == [5, 6, 7, 8] @test cs[3] == [9, 10] + + cs = chunk(collect(1:10), 3) + @test length(cs) == 3 + @test cs[1] == [1, 2, 3, 4] + @test cs[2] == [5, 6, 7, 8] + @test cs[3] == [9, 10] + + x = reshape(collect(1:20), (5, 4)) + cs = chunk(x, 2) + @test length(cs) == 2 + cs[1] == [1 6; 2 7; 3 8; 4 9; 5 10] + cs[2] == [11 16; 12 17; 13 18; 14 19; 15 20] + + # test gradient + test_zygote(chunk, rand(10), 3, check_inferred=false) + + # indirect test of second order derivates + n = 2 + dims = 2 + x = rand(4, 5) + y = chunk(x, 2) + dy = randn!.(collect.(y)) + idxs = MLUtils._partition_idxs(x, n, dims) + test_zygote(MLUtils.∇chunk, dy, x, idxs, Val(dims), check_inferred=false) end @testset "group_counts" begin