Improve cat rules #660

Merged
merged 3 commits into from
Aug 15, 2022

mcabbott
Member

@mcabbott mcabbott commented Aug 15, 2022

This should allow FluxML/Zygote.jl#1277, by using @allowscalar in front of indexing.

Also inserts Base.require_one_based_indexing as a guard against upgrades to Base or OffsetArrays.
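
As a rough illustration of the guard mentioned above, here is a minimal sketch of how `Base.require_one_based_indexing` can protect code that assumes indices start at 1. The helper name `leading_block` is hypothetical and is not taken from this PR.

```julia
# Hypothetical helper for illustration; not the code from this PR.
function leading_block(A::AbstractVector, n::Integer)
    # Throws an ArgumentError if A's axes do not start at 1 (e.g. an OffsetArray),
    # so the range 1:n below cannot silently pick the wrong elements.
    Base.require_one_based_indexing(A)
    return A[1:n]
end

block = leading_block([10, 20, 30], 2)
```

With an ordinary `Vector` the guard is a no-op; it only matters if an array with offset axes reaches this code path.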

Member

@mzgubic mzgubic left a comment

Thanks for doing this. My (limited) understanding of what is going on here is:

Scalar indexing operations are very slow on the GPU. The GPU array packages are designed so that, rather than being silently slow, such code fails. Scalar indexing must be turned on explicitly for each operation, which is what @allowscalar does.

We are ok with doing that here because it won't happen many times: it will happen N times, where N is the number of array arguments in the Xcat call. Usually we want to avoid doing this M times, where M is the number of elements in an array.

I'm curious: what is the usual approach to dealing with this? Rewriting the code in a more efficient way, so that it does array multiplication rather than scalar indexing? And if that can't be done, as in this case, just eat the cost with @allowscalar?

Approving subject to not being totally off the mark in the above :)
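
The behaviour described in this comment can be sketched as follows. This is a minimal example, assuming the JLArrays and GPUArraysCore packages are installed; `JLArray` is the CPU-backed stand-in for a GPU array that also appears in the REPL session later in this thread. The variable names are illustrative only.

```julia
using JLArrays                                   # provides `jl` and JLArray
using GPUArraysCore: GPUArraysCore, @allowscalar

# Assumption for this sketch: treat scalar indexing as an error, not a warning.
GPUArraysCore.allowscalar(false)

dY = jl([0.1, 0.2, 0.3])

# One deliberate scalar read, explicitly permitted for this expression only.
dx = @allowscalar dY[3]

# The same read without the annotation is expected to throw.
threw = try
    dY[3]
    false
catch
    true
end
```

The annotation covers a single expression, so an accidental element-by-element loop elsewhere would still be caught.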

sum(view(dY, ind...))
end
dX = @allowscalar dY[ind...]
# Here InplaceableThunk breaks @inferred, removed for now
Member

Is this still relevant? I guess it broke inference because of the if statement above?

Member Author

Indeed, have now restored InplaceableThunk

@mcabbott
Member Author

> it will happen N times, where N is the number of array arguments in the Xcat call. Usually we want to avoid doing this M times, where M is the number of elements in an array.

Yes. In my rough understanding, any cat gradient will involve N trips to the GPU. Hopefully N is small. The error is to warn us about code which accidentally makes M trips.
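
To make the N-versus-M distinction concrete, here is a hedged sketch of a `vcat` gradient for vector arguments that performs one range-indexing operation per argument, rather than one read per element. The function `vcat_pullback_sketch` and its `lens` argument are hypothetical and much simpler than the actual rule in this PR.

```julia
# Hypothetical sketch; not the implementation merged in this PR.
# dY   : gradient with respect to the output of vcat(xs...)
# lens : length of each vector argument, in order
function vcat_pullback_sketch(dY::AbstractVector, lens::AbstractVector{<:Integer})
    Base.require_one_based_indexing(dY)
    stops = cumsum(lens)             # last index belonging to each argument
    starts = stops .- lens .+ 1      # first index belonging to each argument
    # N slices for N arguments; each slice is a single array-level operation.
    return [dY[a:b] for (a, b) in zip(starts, stops)]
end

grads = vcat_pullback_sketch([10, 20, 30, 40, 50], [2, 3])
```

Here `grads` holds one gradient per argument, so the number of indexing calls grows with the number of arguments and not with the number of elements.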

The scalar case for cat is pretty weird, as you have to get lucky on the forward pass. IDK if this is something anyone really does, or just something Chris once wanted that got included in Zygote's tests. (Exactly one test, of course.)

julia> using JLArrays

julia> vcat(jl([1,2]),3)
3-element JLArray{Int64, 1}:
 1
 2
 3

julia> vcat(0,jl([1,2]),3)
┌ Warning: Performing scalar indexing on task Task (runnable) @0x000000010ccb0010.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/ZBmfM/src/GPUArraysCore.jl:90
4-element Vector{Int64}:
 0
 1
 2
 3

@mcabbott mcabbott merged commit 63cc4e0 into JuliaDiff:main Aug 15, 2022
@mcabbott mcabbott deleted the cat2 branch August 15, 2022 16:02