Improve cat rules #660
Thanks for doing this. My (limited) understanding of what is going on here is:
Scalar indexing operations are very slow on the GPU. The design decision of the GPU packages is that, rather than being silently slow, such code breaks. Scalar indexing must be turned on explicitly for each operation, which is what @allowscalar does.
We are OK with doing that here because it won't happen many times: it will happen N times, where N is the number of array arguments in the Xcat call. What we usually want to avoid is doing this M times, where M is the number of elements in an array.
I'm curious: what is the usual approach for dealing with this? Rewriting the code in a more efficient way, using array operations (such as array multiplication) rather than scalar indexing? And if that can't be done, as in this case, just eat the cost with @allowscalar?
Approving subject to not being totally off the mark in the above :)
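To make the N-versus-M point concrete, here is a minimal sketch, not the code from this PR. It assumes the JLArrays and GPUArraysCore packages are installed, and the cotangent `dY` and its layout are made up for illustration. Scalar reads are annotated one at a time, while the array-valued slice uses range indexing and needs no annotation.

```julia
# Sketch only: assumes the JLArrays and GPUArraysCore packages are installed.
using JLArrays: jl
using GPUArraysCore: @allowscalar, allowscalar

# Make scalar indexing an error on this task, as it is outside the REPL.
allowscalar(false)

# Made-up cotangent of vcat(a, v, b), with scalars a, b and a 2-element array v.
dY = jl([10.0, 20.0, 30.0, 40.0])

# Scalar arguments: one explicitly permitted scalar read each.
da = @allowscalar dY[1]
db = @allowscalar dY[4]

# Array argument: range indexing is a whole-array operation, no annotation needed.
dv = dY[2:3]

# An unannotated scalar read is rejected instead of being silently slow.
unannotated_read_threw = try
    dY[1]
    false
catch
    true
end
```

The annotation is scoped to a single expression, so a loop that accidentally indexes element by element elsewhere still fails loudly.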
src/rulesets/Base/array.jl
sum(view(dY, ind...))
end
dX = @allowscalar dY[ind...]
# Here InplaceableThunk breaks @inferred, removed for now
Is this still relevant? I guess it broke inference because of the if statement above?
Indeed, have now restored InplaceableThunk
Yes. In my rough understanding, any [...] The scalar case for [...]
julia> using JLArrays
julia> vcat(jl([1,2]),3)
3-element JLArray{Int64, 1}:
1
2
3
julia> vcat(0,jl([1,2]),3)
┌ Warning: Performing scalar indexing on task Task (runnable) @0x000000010ccb0010.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/ZBmfM/src/GPUArraysCore.jl:90
4-element Vector{Int64}:
0
1
2
3
This should allow FluxML/Zygote.jl#1277, by using @allowscalar in front of indexing.
Also inserts Base.require_one_based_indexing as a guard against upgrades to Base or OffsetArrays.
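As a rough illustration of what such a guard buys, here is a sketch that is not taken from the PR. `ZeroBased` and `vcat_range` are hypothetical names: the first stands in for an OffsetArrays-style type, the second for index arithmetic that silently assumes indices start at 1. `Base.require_one_based_indexing` is an unexported Base function that throws an `ArgumentError` when any argument has offset axes.

```julia
# Hypothetical zero-based vector, standing in for an OffsetArrays-style type.
struct ZeroBased{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(z::ZeroBased) = size(z.data)
Base.axes(z::ZeroBased) = (Base.IdentityUnitRange(0:length(z.data) - 1),)
Base.getindex(z::ZeroBased, i::Int) = z.data[i + 1]

# Hypothetical helper: the range that argument `k` occupies in vcat(xs...),
# computed under the assumption that every array is indexed from 1.
function vcat_range(xs, k)
    Base.require_one_based_indexing(xs...)  # fail early rather than return wrong indices
    start = 1 + sum(length, xs[1:k-1]; init = 0)
    return start:(start + length(xs[k]) - 1)
end

ok_range = vcat_range(([1, 2], [3, 4, 5]), 2)

guard_fired = try
    vcat_range((ZeroBased([1, 2]), [3]), 1)
    false
catch err
    err isa ArgumentError
end
```

With the guard in place, an input with offset axes produces an immediate error instead of index ranges that no longer line up with the output.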