`Flux.chunk` for multi-dimensional arrays #1841
Comments
You can do this:
In general, I wonder whether Flux.jl should have fewer such utility functions and leave them to other packages like SplitApplyCombine.jl, although it's not immediately clear that they provide exactly what you're asking for. It is easy to write yourself, though:
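A minimal sketch of such a helper, assuming `torch.chunk`-style semantics: each piece holds `cld(size(x, dims), n)` slices, so fewer than `n` pieces can come back. The name `chunk_along` and its keyword `dims` are hypothetical and not part of Flux. It relies on Base's `selectdim`, which keeps the selected dimension when given a range.

```julia
# `chunk_along` is a hypothetical helper for illustration, not a Flux API.
# It splits `x` into at most `n` pieces along dimension `dims`, following
# torch.chunk's rule: each piece holds cld(size(x, dims), n) slices, except
# possibly the last, so fewer than `n` pieces may be returned.
function chunk_along(x::AbstractArray, n::Integer; dims::Integer)
    n > 0 || throw(ArgumentError("number of chunks must be positive, got $n"))
    len = size(x, dims)
    step = max(cld(len, n), 1)  # slices per chunk, rounded up (at least 1)
    # `selectdim` with a range keeps dimension `dims`, so the other sizes are unchanged.
    # The pieces are views into `x`; wrap them in `copy` if independent arrays are needed.
    return [selectdim(x, dims, i:min(i + step - 1, len)) for i in 1:step:len]
end

x = reshape(collect(1:24), 4, 6)
cs = chunk_along(x, 3; dims = 2)   # three 4×2 views
rs = chunk_along(x, 3; dims = 1)   # cld(4, 3) = 2 rows per piece, so only two 2×6 views
```

Returning views keeps the split allocation-free; `cat(cs...; dims = 2)` reassembles the original array.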
This is a worthwhile addition for use in multi-head attention, so I suggest that we accept a PR so long as Flux has …
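To illustrate the multi-head attention use case, here is a sketch that splits a fused Q/K/V projection and then splits the queries into heads, assuming Julia's usual features × sequence layout. The helper `chunk_along` is hypothetical (not a Flux API) and is redefined here so the snippet runs on its own; the sizes `d`, `T` and `nheads` are illustrative.

```julia
# Hypothetical helper (not a Flux API): split `x` into at most `n` pieces of
# cld(size(x, dims), n) slices along `dims`, returning views into `x`.
function chunk_along(x::AbstractArray, n::Integer; dims::Integer)
    len = size(x, dims)
    step = max(cld(len, n), 1)
    return [selectdim(x, dims, i:min(i + step - 1, len)) for i in 1:step:len]
end

d, T, nheads = 8, 5, 2                    # illustrative: model width, sequence length, heads
qkv = randn(Float32, 3d, T)               # fused Q/K/V projection output, features × sequence

q, k, v = chunk_along(qkv, 3; dims = 1)   # each d × T
heads = chunk_along(q, nheads; dims = 1)  # nheads pieces, each (d ÷ nheads) × T
```

In practice, heads are often split with a single `reshape(q, d ÷ nheads, nheads, T)` instead of separate pieces; the chunked form is shown only because it matches the use case mentioned in the comment.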
As of now, the utility function `chunk` works fine for 1D arrays but seems to flatten out multidimensional arrays:

I would like to get chunks of the array along a specific dimension without having to manually fix the shapes. Is it possible for this function to work in a way similar to `torch.chunk(input, chunks, dim)`, i.e. take an additional input for the dimension along which the chunks are to be formed, and return the arrays with their shapes unchanged along the other dimensions?
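To make the flattening concrete, the sketch below assumes a length-based split built on `Base.Iterators.partition`. This is an assumption about how such a `chunk` could be written, not a quote of Flux's source. Because Julia arrays are column-major, a flat chunk can be reshaped back only when splitting along the last dimension; along dimension 1 the wanted pieces are not contiguous in memory.

```julia
x = reshape(collect(1:12), 3, 4)   # 3×4 matrix, stored column-major

# Length-based split (assumed flattening behaviour): two 1-D chunks of 6 elements.
flat = collect(Iterators.partition(x, cld(length(x), 2)))
# flat[1] holds x[1:6], i.e. the first two columns read in column-major order

# Along the last dimension, a reshape recovers the shape-preserving chunk...
last_dim_ok = reshape(flat[1], 3, 2) == x[:, 1:2]

# ...but along dimension 1 no reshape of a flat chunk gives the wanted pieces,
# so they need explicit indexing. This mirrors torch.chunk(x, 2, dim=0):
# cld(3, 2) = 2 rows first, then the remaining row.
rows = [x[1:2, :], x[3:3, :]]
```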