RFC: use GPUArrays, also for testing #9
src/array.jl (outdated):

```diff
@@ -90,7 +90,7 @@ MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatri

 Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

-Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}()
+Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, T}}) where {N, T <: AbstractGPUArray} = Base.BroadcastStyle(T)
```
Is this correct? The number of dimensions for `T` would be less than `N`.
After some confusion: the `N` here is what's called `M` or `N+1` elsewhere in this file, i.e. the 4th type parameter, not the 3rd. So this line does change the result, although maybe it's not tested?
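For concreteness, a small sketch of the parameter layout being discussed, assuming the `OneHotArray` definition earlier in this file (five type parameters: the 4th is the ndims of the whole array, the 5th is the index container); the exact style type and its parameters vary across CUDA.jl versions:

```julia
using CUDA, OneHotArrays

x   = cu(onehotbatch([1, 2, 3], 1:10))   # a 10×3 OneHotMatrix whose indices live in a CuVector
idx = OneHotArrays._indices(x)           # that CuVector of indices

ndims(x)                                 # 2 — the `N` in the signature above (the 4th parameter)
ndims(idx)                               # 1 — the index container has one dimension fewer

Base.BroadcastStyle(typeof(idx))         # a style tagged with the indices' ndims, e.g. CuArrayStyle{1}(),
                                         # whereas the old line hard-coded CuArrayStyle{N}() = CuArrayStyle{2}()
```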
Is there a way around that?
One hack I can think of is relying on all `ArrayStyle`s to be like `CuArrayStyle{N}` and have their `N` first. It seems pretty unlikely that any `AbstractGPUArray` is going to deviate from that pattern.
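A hypothetical sketch of that hack (the details are illustrative, not from this PR), under the stated assumption that every backend's style is parameterized as `SomeStyle{N}` with `N` first:

```julia
using GPUArrays                  # for AbstractGPUArray
using OneHotArrays: OneHotArray

function Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, T}}) where {N, T <: AbstractGPUArray}
    S = typeof(Base.BroadcastStyle(T))   # e.g. CuArrayStyle{N-1} for a CuArray of indices
    Base.typename(S).wrapper{N}()        # rebuild the same style, but with the one-hot array's ndims
end
```

This breaks as soon as a backend's style carries extra parameters or puts `N` elsewhere, which is why it reads as a hack.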
That, or some hack to swap out the `ndims` of the index array type to `N`. Then forwarding to `BroadcastStyle` would just do the right thing. I'm not sure which is more stable.
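A hypothetical sketch of this second option (`with_ndims` is a made-up helper, not part of the package), assuming the index array type has its element type and ndims as its first two parameters, as `Array`, `CuArray`, and `JLArray` do:

```julia
using GPUArrays
using OneHotArrays: OneHotArray

# Rewrite an array type so that its ndims is N, keeping element type and backend.
with_ndims(::Type{A}, ::Val{N}) where {A <: AbstractArray, N} =
    Base.typename(A).wrapper{eltype(A), N}

Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, T}}) where {N, T <: AbstractGPUArray} =
    Base.BroadcastStyle(with_ndims(T, Val(N)))   # let the backend answer, now with the right ndims
```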
See what you think of ffb6f00
> Last I checked,

OK, I have not run into these limitations yet. It does appear to catch some problems here, e.g. FluxML/Flux.jl#1905. It also causes some problems, as Zygote has special paths for `CuArray`, not `AbstractGPUArray`. But perhaps that should be changed.
LGTM, let's get GPU CI set up next
I don't have a merge button here, BTW.
Looks like the default permissions are pretty strict; you should have one now :)
This PR does two things.
First, it replaces the CUDA dependency with GPUArrays, mostly to save startup time.
Second, it wants to run tests using a fake GPU array when CUDA isn't around. I think this ought to catch many scalar-indexing errors. I'm not entirely sure how official this JLArray code is; we could check. The actual tests are from Flux, minimally adapted.
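Roughly, the fake-GPU testing idea looks like the sketch below, using the JLArray reference backend from GPUArrays (nowadays also available as the registered JLArrays.jl package); the exact loading may differ from the copy this PR carries:

```julia
using Adapt, GPUArrays, OneHotArrays
using JLArrays                    # CPU-backed AbstractGPUArray used for GPUArrays' own tests

GPUArrays.allowscalar(false)      # make scalar indexing an error, as it effectively is on a real GPU

x  = onehotbatch([1, 2, 3], 1:10) # 10×3 OneHotMatrix on the CPU
gx = adapt(JLArray, x)            # same array, now backed by a JLArray of indices

gx[1, 1]                          # throws a scalar-indexing error: exactly the kind of silent
                                  # fallback these tests are meant to surface
```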
Both of these might be worth doing more widely. But this little package (which I only saw today) might be an easy place to start. Thoughts?