Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: use GPUArrays, also for testing #9

Merged
merged 4 commits into from
Mar 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ version = "0.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Expand All @@ -14,11 +14,17 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Adapt = "3.0"
CUDA = "3.8"
ChainRulesCore = "1.13"
GPUArrays = "8.2.1"
MLUtils = "0.2"
NNlib = "0.8"
Zygote = "0.6.35"
julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test"]
test = ["Test", "CUDA", "Random", "Zygote"]
2 changes: 1 addition & 1 deletion src/OneHotArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module OneHotArrays

using Adapt
using ChainRulesCore
using CUDA
using GPUArrays
using LinearAlgebra
using MLUtils
using NNlib
Expand Down
24 changes: 17 additions & 7 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,17 @@ function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::
end

# copy CuArray versions back before trying to print them:
Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} =
Base.print_array(io, adapt(Array, X))
Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} =
Base.print_array(io, adapt(Array, X))
for fun in (:show, :print_array) # print_array is used by 3-arg show
@eval begin
Base.$fun(io::IO, X::OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray}) where {T, L, N, var"N+1"} =
Base.$fun(io, adapt(Array, X))
Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, <:Any, <:AbstractGPUArray}}) where {T, L, N} =
Base.$fun(io, adapt(Array, X))
end
end

_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, var"N+1", <:Union{Integer, AbstractArray}}) where {var"N+1"} = Array{Bool, var"N+1"}
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, var"N+1", <:AbstractGPUArray}) where {var"N+1"} = AbstractGPUArray{Bool, var"N+1"}

function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L
if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
Expand All @@ -90,7 +94,13 @@ MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatri

Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}()
function Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray}
# We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
S = Base.BroadcastStyle(T)
# S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter, which
# isn't guaranteed, but there are not so many GPU broadcasting styles in the wild. (Far fewer than there are array wrappers.)
(typeof(S).name.wrapper){var"N+1"}()
end

Base.map(f, x::OneHotLike) = Base.broadcast(f, x)

Expand Down
4 changes: 2 additions & 2 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ nonzero elements.
If one of the inputs in `xs` is not found in `labels`, that column is `onehot(default, labels)`
if `default` is given, else an error.

If `xs` has more dimensions, `M = ndims(xs) > 1`, then the result is an
`AbstractArray{Bool, M+1}` which is one-hot along the first dimension,
If `xs` has more dimensions, `N = ndims(xs) > 1`, then the result is an
`AbstractArray{Bool, N+1}` which is one-hot along the first dimension,
i.e. `result[:, k...] == onehot(xs[k...], labels)`.

Note that `xs` can be any iterable, such as a string. And that using a tuple
Expand Down
51 changes: 51 additions & 0 deletions test/gpu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@

# Tests from Flux, probably not the optimal testset organisation!

@testset "CUDA" begin
x = randn(5, 5)
cx = cu(x)
@test cx isa CuArray

@test_broken onecold(cu([1.0, 2.0, 3.0])) == 3 # scalar indexing error?

x = onehotbatch([1, 2, 3], 1:3)
cx = cu(x)
@test cx isa OneHotMatrix && cx.indices isa CuArray
@test (cx .+ 1) isa CuArray

xs = rand(5, 5)
ys = onehotbatch(1:5,1:5)
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
end

@testset "onehot gpu" begin
y = onehotbatch(ones(3), 1:2) |> cu;
@test (repr("text/plain", y); true)

gA = rand(3, 2) |> cu;
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
end

@testset "onecold gpu" begin
y = onehotbatch(ones(3), 1:10) |> cu;
l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
@test onecold(y) isa CuArray
@test y[3,:] isa CuArray
@test onecold(y, l) == ['a', 'a', 'a']
end

@testset "onehot forward map to broadcast" begin
oa = OneHotArray(rand(1:10, 5, 5), 10) |> cu
@test all(map(identity, oa) .== oa)
@test all(map(x -> 2 * x, oa) .== 2 .* oa)
end

@testset "show gpu" begin
x = onehotbatch([1, 2, 3], 1:3)
cx = cu(x)
# 3-arg show
@test contains(repr("text/plain", cx), "1 ⋅ ⋅")
@test contains(repr("text/plain", cx), string(typeof(cx.indices)))
# 2-arg show, https://github.com/FluxML/Flux.jl/issues/1905
@test repr(cx) == "Bool[1 0 0; 0 1 0; 0 0 1]"
end
25 changes: 25 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,28 @@ end
@testset "Linear Algebra" begin
include("linalg.jl")
end

using Zygote
import CUDA
if CUDA.functional()
using CUDA # exports CuArray, etc
@info "starting CUDA tests"
else
@info "CUDA not functional, testing via GPUArrays"
using GPUArrays
GPUArrays.allowscalar(false)

# GPUArrays provides a fake GPU array, for testing
jl_file = normpath(joinpath(pathof(GPUArrays), "..", "..", "test", "jlarray.jl"))
using Random # loaded within jl_file
include(jl_file)
using .JLArrays
cu = jl
CuArray{T,N} = JLArray{T,N}
end

@test cu(rand(3)) .+ 1 isa CuArray

@testset "GPUArrays" begin
include("gpu.jl")
end