Flux.chunk for multi-dimensional arrays #1841

Closed
theabhirath opened this issue Jan 21, 2022 · 2 comments · Fixed by JuliaML/MLUtils.jl#47

Comments

@theabhirath
Member

Currently, the utility function chunk works for 1D arrays but flattens multidimensional arrays:

julia> x = rand(Int8, 5, 3, 2)
5×3×2 Array{Int8, 3}:
[:, :, 1] =
   -4   -92   106
   30    37   102
 -123  -119   -22
  114   -35  -102
   73   -87   104

[:, :, 2] =
 -101   33    84
  -97  101    -1
   66   19    43
    0   89  -101
   72  -75    37

julia> Flux.chunk(x, 3)
3-element Vector{SubArray{Int8, 1, Vector{Int8}, Tuple{UnitRange{Int64}}, true}}:
 [-4, 30, -123, 114, 73, -92, 37, -119, -35, -87]
 [106, 102, -22, -102, 104, -101, -97, 66, 0, 72]
 [33, 101, 19, 89, -75, 84, -1, 43, -101, 37]
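The flattened result above is consistent with partitioning the array's elements in linear (column-major) order rather than along a dimension. A minimal sketch that reproduces the same shape loss with Base only, assuming chunk behaves like Iterators.partition over all elements (this is an assumption for illustration, not taken from Flux's source):

```julia
# Assumption: this mimics the flattening seen above using Base only;
# it is not copied from Flux's implementation.
x = reshape(Int8.(1:30), 5, 3, 2)

# Partition all 30 elements, in column-major order, into pieces of cld(30, 3) = 10.
parts = collect(Iterators.partition(x, cld(length(x), 3)))

length(parts)       # number of pieces
map(size, parts)    # each piece is 1-D; the 5×3×2 shape is gone
```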

I would like to get chunks of the array along a specific dimension without having to fix the shapes manually. Could this function work like torch.chunk(input, chunks, dim), i.e. take an additional argument for the dimension along which to form the chunks, and return arrays whose shapes are unchanged along the other dimensions?

@mcabbott
Member

You can do this with DataLoader, which batches along the last dimension:

julia> dl = Flux.Data.DataLoader(rand(Int8, 5, 3, 17); batchsize=8);

julia> map(size, dl)
3-element Vector{Tuple{Int64, Int64, Int64}}:
 (5, 3, 8)
 (5, 3, 8)
 (5, 3, 1)

In general I wonder if Flux.jl should have fewer such utility functions and leave them to other packages like SplitApplyCombine.jl, although it's not immediately clear that they have exactly what you ask for. It is easy to write yourself, though:

julia> chunk(A, k::Int; dims::Int) = (selectdim(A, dims, i) for i in Iterators.partition(axes(A,dims), cld(size(A,dims), k)));

julia> chunk(x, 3; dims=1);

julia> map(size, ans)
3-element Vector{Tuple{Int64, Int64, Int64}}:
 (2, 3, 2)
 (2, 3, 2)
 (1, 3, 2)
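
A few properties of this one-liner are worth checking before relying on it. The sketch below reuses the definition above on a small deterministic array (the array contents are chosen only for illustration) and shows that the chunks are views into the parent, that concatenating them restores the original, and that the number of chunks can be smaller than k:

```julia
# chunk as defined in the comment above: lazily yields views along `dims`.
chunk(A, k::Int; dims::Int) =
    (selectdim(A, dims, i) for i in Iterators.partition(axes(A, dims), cld(size(A, dims), k)))

x = reshape(collect(1:30), 5, 3, 2)
parts = collect(chunk(x, 3; dims=1))   # materialize the generator

map(size, parts)               # (2, 3, 2), (2, 3, 2), (1, 3, 2)
cat(parts...; dims=1) == x     # concatenating along dims restores x

# Each chunk is a view, so writing to it modifies x.
parts[1][1, 1, 1] = 0
x[1, 1, 1]                     # now 0

# The chunk size is cld(5, 4) = 2, so asking for 4 chunks yields only 3.
length(collect(chunk(x, 4; dims=1)))
```

Because the chunks share memory with x, call copy on a chunk if an independent array is needed.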

@darsnack
Member

darsnack commented Feb 8, 2022

This is a worthwhile addition for use in multi-head attention, so I suggest that we accept a PR so long as Flux has chunk at all. MLUtils.jl will incorporate a lot of these utilities, and we can deprecate them all wholesale when that package is ready.
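
As a rough illustration of the multi-head attention use case, a dimension-aware chunk can split a projected tensor into one slice per head. This is only a sketch: the (features, sequence length, batch) layout, the sizes, and the variable names are hypothetical, and chunk here is the one-liner suggested earlier in this thread, not a Flux or MLUtils API.

```julia
# Dimension-aware chunk, as sketched earlier in this thread.
chunk(A, k::Int; dims::Int) =
    (selectdim(A, dims, i) for i in Iterators.partition(axes(A, dims), cld(size(A, dims), k)))

# Hypothetical layout: (features, sequence length, batch).
d_model, seq_len, batch, nheads = 8, 4, 2, 2
q = rand(Float32, d_model, seq_len, batch)

# One view per head, each holding d_model ÷ nheads features.
heads = collect(chunk(q, nheads; dims=1))
map(size, heads)   # one (4, 4, 2) slice per head
```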
