
RFC: use GPUArrays, also for testing #9

Merged: 4 commits into FluxML:main, Mar 12, 2022
Conversation

mcabbott (Member)

This PR does two things.

First, it replaces CUDA with GPUArrays, mostly to save startup time:

julia> @time using CUDA
  3.515312 seconds (9.57 M allocations: 654.933 MiB, 2.48% gc time, 16.35% compilation time)
  2.117964 seconds (4.08 M allocations: 328.683 MiB, 24.88% compilation time)  # after using GPUArrays

julia> @time using GPUArrays
  1.323729 seconds (5.49 M allocations: 326.357 MiB, 4.41% compilation time)  # on a fresh start
  0.641342 seconds (1.05 M allocations: 54.070 MiB, 5.79% gc time, 99.95% compilation time)  # after using CUDA

Second, it runs the tests using a fake GPUArray when CUDA isn't around. I think this ought to catch many scalar indexing errors. I'm not entirely sure how official this JLArray code is; we could check. The actual tests are from Flux, minimally adapted.
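
For reference, a minimal sketch of the fake-GPU idea (depending on the GPUArrays version, JLArray lives in GPUArrays itself or in the separate JLArrays package):

# sketch: JLArray behaves like a GPU array but runs on the CPU,
# so allowscalar(false) makes accidental scalar indexing throw in tests
using GPUArrays, JLArrays

GPUArrays.allowscalar(false)

x = JLArray(rand(Float32, 3, 4))
sum(abs2, x)    # fine: goes through the GPU-style mapreduce
# x[1, 2]       # would throw a scalar-indexing error, as on a real GPU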

Both of these might be worth doing more widely. But this little package (which I only saw today) might be an easy place to start. Thoughts?

src/array.jl (outdated diff)
@@ -90,7 +90,7 @@ MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatri

Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

-Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}()
+Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, T}}) where {N, T <: AbstractGPUArray} = Base.BroadcastStyle(T)
Member

Is this correct? The number of dimensions for T would be <N.

Member Author

After some confusion: The N here is what's called M or N+1 elsewhere in this file. The 4th type parameter, not the 3rd. So this line does change the result, although maybe it's not tested?

Is there a way around that?

One hack I can think of is relying on all ArrayStyles to be like CuArrayStyle{N} and have their N first. It seems pretty unlikely that any AbstractGPUArray is going to deviate from that pattern.
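
Sketched out (illustrative only, and assuming every such style takes the dimensionality as its first type parameter, as CuArrayStyle{N} did at the time):

using GPUArrays: AbstractGPUArray

# rebuild the index array's broadcast style with the OneHotArray's N,
# relying on all GPU ArrayStyles looking like SomeStyle{M}
function rebuilt_style(::Type{T}, ::Val{N}) where {T <: AbstractGPUArray, N}
    S = typeof(Base.BroadcastStyle(T))  # e.g. CuArrayStyle{N-1} for the indices
    Base.typename(S).wrapper{N}()       # e.g. CuArrayStyle{N}
end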

Member

That or some hack to swap out the ndims of the index array type to N. Then forwarding to BroadcastStyle would just do the right thing. I'm not sure which is more stable.
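
That alternative, sketched (also illustrative, not necessarily what got committed; it assumes the array type's first two parameters are eltype and ndims):

using GPUArrays: AbstractGPUArray

# swap the index array type's ndims to N, then forward to BroadcastStyle
style_with_ndims(::Type{T}, ::Val{N}) where {T <: AbstractGPUArray, N} =
    Base.BroadcastStyle(Base.typename(T).wrapper{eltype(T), N})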

Member Author

See what you think of ffb6f00

@ToucheSir (Member)

Last I checked, JLArray was not terribly comprehensive and didn't fulfill the testing requirements for the PR I was considering it for. That said, it may be better than nothing.

@mcabbott (Member Author)

OK, I have not run into these limitations yet.

It does appear to catch some problems here, e.g. FluxML/Flux.jl#1905.

It also causes some problems, as Zygote has special paths for CuArray rather than AbstractGPUArray. But perhaps that should be changed.
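
To illustrate with a hypothetical function (not Zygote's actual code): a method that dispatches only on CuArray sends JLArray down the generic fallback, while dispatching on AbstractGPUArray catches both:

using GPUArrays: AbstractGPUArray
using JLArrays: JLArray

gpu_path(x::AbstractGPUArray) = "gpu-friendly path"   # CuArray and JLArray both land here
gpu_path(x::AbstractArray)    = "generic fallback"

gpu_path(JLArray(ones(Float32, 2)))  # "gpu-friendly path"
gpu_path(ones(Float32, 2))           # "generic fallback"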

@ToucheSir (Member) left a comment

LGTM, let's get GPU CI set up next

@mcabbott (Member Author)

I don't have a merge button here, BTW.

@ToucheSir merged commit 621c3aa into FluxML:main on Mar 12, 2022
@ToucheSir (Member)

Looks like the default permissions are pretty strict; you should have a merge button now :)
