RFC: use GPUArrays, also for testing #9

Merged 4 commits on Mar 12, 2022
10 changes: 8 additions & 2 deletions Project.toml
@@ -4,8 +4,8 @@ version = "0.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -14,11 +14,17 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Adapt = "3.0"
CUDA = "3.8"
ChainRulesCore = "1.13"
GPUArrays = "8.2.1"
MLUtils = "0.2"
NNlib = "0.8"
Zygote = "0.6.35"
julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test"]
test = ["Test", "CUDA", "Random", "Zygote"]
2 changes: 1 addition & 1 deletion src/OneHotArrays.jl
@@ -2,7 +2,7 @@ module OneHotArrays

using Adapt
using ChainRulesCore
-using CUDA
+using GPUArrays
using LinearAlgebra
using MLUtils
using NNlib
8 changes: 4 additions & 4 deletions src/array.jl
@@ -61,13 +61,13 @@ function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::
end

# copy CuArray versions back before trying to print them:
-Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} =
+Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray}) where {T, L, N, var"N+1"} =
Base.print_array(io, adapt(Array, X))
-Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} =
+Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray}}) where {T, L, N, var"N+1"} =
Base.print_array(io, adapt(Array, X))

_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
-_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
+_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:AbstractGPUArray}) where N = AbstractGPUArray{Bool, N}

function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L
if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
@@ -90,7 +90,7 @@ MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatri

Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

-Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}()
+Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, T}}) where {N, T <: AbstractGPUArray} = Base.BroadcastStyle(T)
Member:
Is this correct? The number of dimensions for T would be <N.

Member Author:
After some confusion: The N here is what's called M or N+1 elsewhere in this file. The 4th type parameter, not the 3rd. So this line does change the result, although maybe it's not tested?

Is there a way around that?

One hack I can think of is relying on all ArrayStyles to be like CuArrayStyle{N} and have their N first. It seems pretty unlikely that any AbstractGPUArray is going to deviate from that pattern.

Member:
That or some hack to swap out the ndims of the index array type to N. Then forwarding to BroadcastStyle would just do the right thing. I'm not sure which is more stable.

Member Author:
See what you think of ffb6f00
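
For orientation, here is a minimal sketch of the "swap out the ndims of the index array type" idea discussed above (not necessarily what ffb6f00 does). `basetype` is a hypothetical helper that recovers the unparameterised array wrapper so it can be re-applied at the OneHotArray's own rank:

# Hedged sketch only: rebuild the index array's type at the OneHotArray's rank M,
# then forward to BroadcastStyle so the resulting ArrayStyle carries the right N.
basetype(::Type{T}) where {T} = Base.typename(T).wrapper

Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, M, T}}) where {M, T <: AbstractGPUArray} =
    Base.BroadcastStyle(basetype(T){Bool, M})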


Base.map(f, x::OneHotLike) = Base.broadcast(f, x)

41 changes: 41 additions & 0 deletions test/gpu.jl
@@ -0,0 +1,41 @@

# Tests from Flux, probably not the optimal testset organisation!

@testset "CUDA" begin
x = randn(5, 5)
cx = cu(x)
@test cx isa CuArray

@test_broken onecold(cu([1.0, 2.0, 3.0])) == 3 # scalar indexing error?

x = onehotbatch([1, 2, 3], 1:3)
cx = cu(x)
@test cx isa OneHotMatrix && cx.indices isa CuArray
@test (cx .+ 1) isa CuArray

xs = rand(5, 5)
ys = onehotbatch(1:5,1:5)
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
end

@testset "onehot gpu" begin
y = onehotbatch(ones(3), 1:2) |> cu;
@test (repr("text/plain", y); true)

gA = rand(3, 2) |> cu;
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
end

@testset "onecold gpu" begin
y = onehotbatch(ones(3), 1:10) |> cu;
l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
@test onecold(y) isa CuArray
@test y[3,:] isa CuArray
@test onecold(y, l) == ['a', 'a', 'a']
end

@testset "onehot forward map to broadcast" begin
oa = OneHotArray(rand(1:10, 5, 5), 10) |> cu
@test all(map(identity, oa) .== oa)
@test all(map(x -> 2 * x, oa) .== 2 .* oa)
end
25 changes: 25 additions & 0 deletions test/runtests.jl
@@ -12,3 +12,28 @@ end
@testset "Linear Algebra" begin
include("linalg.jl")
end

using Zygote
import CUDA
if CUDA.functional()
  using CUDA # exports CuArray, etc
  @info "starting CUDA tests"
else
  @info "CUDA not functional, testing via GPUArrays"
  using GPUArrays
  GPUArrays.allowscalar(false)

  # GPUArrays provides a fake GPU array, for testing
  jl_file = normpath(joinpath(pathof(GPUArrays), "..", "..", "test", "jlarray.jl"))
  using Random # loaded within jl_file
  include(jl_file)
  using .JLArrays
  cu = jl
  CuArray{T,N} = JLArray{T,N}
end

@test cu(rand(3)) .+ 1 isa CuArray

@testset "GPUArrays" begin
include("gpu.jl")
end
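
As a rough usage sketch (not part of the PR), this is what the JLArray fallback above provides: `cu` and `CuArray` are the aliases bound in the `else` branch of runtests.jl, and scalar indexing fails just as it would on a real GPU once `allowscalar(false)` is set.

# Hypothetical REPL session under the non-CUDA branch above.
a = cu(rand(3))    # a JLArray standing in for a CuArray, still on the CPU
b = a .+ 1         # broadcasting stays on the fake device array type
collect(b)         # copy back to a plain Array for comparison
# a[1]             # would error: scalar indexing is disallowed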