Merge pull request #9 from mcabbott/gpuarrays
RFC: use GPUArrays, also for testing
ToucheSir authored Mar 12, 2022
2 parents 376f6ed + cea26a4 commit 621c3aa
Showing 6 changed files with 104 additions and 12 deletions.
10 changes: 8 additions & 2 deletions Project.toml
@@ -4,8 +4,8 @@ version = "0.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -14,11 +14,17 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Adapt = "3.0"
CUDA = "3.8"
ChainRulesCore = "1.13"
GPUArrays = "8.2.1"
MLUtils = "0.2"
NNlib = "0.8"
Zygote = "0.6.35"
julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test"]
test = ["Test", "CUDA", "Random", "Zygote"]
2 changes: 1 addition & 1 deletion src/OneHotArrays.jl
@@ -2,7 +2,7 @@ module OneHotArrays

using Adapt
using ChainRulesCore
-using CUDA
+using GPUArrays
using LinearAlgebra
using MLUtils
using NNlib
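This one-line swap is the heart of the PR: methods that used to dispatch on `CuArray` now dispatch on `GPUArrays.AbstractGPUArray`, the supertype shared by `CuArray`, other GPU backends, and the `JLArray` test type. A minimal dispatch sketch (not from this PR, the helper name is hypothetical):

```julia
using GPUArrays  # exports AbstractGPUArray

# One method now covers every GPU backend, where a CuArray-specific
# method previously covered CUDA alone.
backend_name(::AbstractArray) = "CPU"
backend_name(x::AbstractGPUArray) = string(typeof(x).name.name)  # e.g. "CuArray", "JLArray"
```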
24 changes: 17 additions & 7 deletions src/array.jl
@@ -61,13 +61,17 @@ function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::
end

# copy CuArray versions back before trying to print them:
-Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} =
-  Base.print_array(io, adapt(Array, X))
-Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} =
-  Base.print_array(io, adapt(Array, X))
+for fun in (:show, :print_array)  # print_array is used by 3-arg show
+  @eval begin
+    Base.$fun(io::IO, X::OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray}) where {T, L, N, var"N+1"} =
+      Base.$fun(io, adapt(Array, X))
+    Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, <:Any, <:AbstractGPUArray}}) where {T, L, N} =
+      Base.$fun(io, adapt(Array, X))
+  end
+end

-_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
-_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
+_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, var"N+1", <:Union{Integer, AbstractArray}}) where {var"N+1"} = Array{Bool, var"N+1"}
+_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, var"N+1", <:AbstractGPUArray}) where {var"N+1"} = AbstractGPUArray{Bool, var"N+1"}

function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L
if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
@@ -90,7 +94,13 @@ MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatri

Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

-Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}()
+function Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray}
+  # We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
+  S = Base.BroadcastStyle(T)
+  # S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter,
+  # which isn't guaranteed, but there are not so many GPU broadcasting styles in the wild. (Far fewer than there are array wrappers.)
+  (typeof(S).name.wrapper){var"N+1"}()
+end

Base.map(f, x::OneHotLike) = Base.broadcast(f, x)

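The old method could hard-code `CUDA.CuArrayStyle{N}()`; the generic version must recover the backend's broadcast style from the wrapped index array and rebuild it one dimension higher, hence the `typeof(S).name.wrapper` trick. Here is the rewrapping step in isolation, sketched with `JLArray` as a stand-in backend (assuming the standalone JLArrays package; the PR itself includes the same type from GPUArrays' test directory):

```julia
using JLArrays  # assumption: registered JLArrays.jl, providing the fake GPU array

ind = JLArray(rand(1:10, 5))          # like the index field of a OneHotMatrix
S = Base.BroadcastStyle(typeof(ind))  # backend style with ndims = 1
S2 = (typeof(S).name.wrapper){2}()    # same style constructor, rebuilt with ndims = 2
```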
4 changes: 2 additions & 2 deletions src/onehot.jl
@@ -57,8 +57,8 @@ nonzero elements.
If one of the inputs in `xs` is not found in `labels`, that column is `onehot(default, labels)`
if `default` is given, else an error.
-If `xs` has more dimensions, `M = ndims(xs) > 1`, then the result is an
-`AbstractArray{Bool, M+1}` which is one-hot along the first dimension,
+If `xs` has more dimensions, `N = ndims(xs) > 1`, then the result is an
+`AbstractArray{Bool, N+1}` which is one-hot along the first dimension,
i.e. `result[:, k...] == onehot(xs[k...], labels)`.
Note that `xs` can be any iterable, such as a string. And that using a tuple
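The docstring fix renames `M` to `N` so the variable matches the `N+1` notation used elsewhere in the package. For illustration, a usage sketch of the rank bump it describes (not part of the diff):

```julia
using OneHotArrays

xs = [1 3; 2 1]           # ndims(xs) == 2
y = onehotbatch(xs, 1:3)  # 3×2×2, i.e. AbstractArray{Bool, 3}, one-hot along dim 1
@assert y[:, 2, 1] == onehot(xs[2, 1], 1:3)  # result[:, k...] == onehot(xs[k...], labels)
```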
51 changes: 51 additions & 0 deletions test/gpu.jl
@@ -0,0 +1,51 @@

# Tests from Flux, probably not the optimal testset organisation!

@testset "CUDA" begin
x = randn(5, 5)
cx = cu(x)
@test cx isa CuArray

@test_broken onecold(cu([1.0, 2.0, 3.0])) == 3 # scalar indexing error?

x = onehotbatch([1, 2, 3], 1:3)
cx = cu(x)
@test cx isa OneHotMatrix && cx.indices isa CuArray
@test (cx .+ 1) isa CuArray

xs = rand(5, 5)
ys = onehotbatch(1:5,1:5)
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
end

@testset "onehot gpu" begin
y = onehotbatch(ones(3), 1:2) |> cu;
@test (repr("text/plain", y); true)

gA = rand(3, 2) |> cu;
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
end

@testset "onecold gpu" begin
y = onehotbatch(ones(3), 1:10) |> cu;
l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
@test onecold(y) isa CuArray
@test y[3,:] isa CuArray
@test onecold(y, l) == ['a', 'a', 'a']
end

@testset "onehot forward map to broadcast" begin
oa = OneHotArray(rand(1:10, 5, 5), 10) |> cu
@test all(map(identity, oa) .== oa)
@test all(map(x -> 2 * x, oa) .== 2 .* oa)
end

@testset "show gpu" begin
x = onehotbatch([1, 2, 3], 1:3)
cx = cu(x)
# 3-arg show
@test contains(repr("text/plain", cx), "1 ⋅ ⋅")
@test contains(repr("text/plain", cx), string(typeof(cx.indices)))
# 2-arg show, https://github.com/FluxML/Flux.jl/issues/1905
@test repr(cx) == "Bool[1 0 0; 0 1 0; 0 0 1]"
end
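The `@test_broken` lines above flag operations that fall back to element-by-element (scalar) indexing, which the test setup forbids because it is pathologically slow on a real GPU. A sketch of that guard, again assuming the standalone JLArrays package:

```julia
using GPUArrays, JLArrays   # assumption: JLArrays stands in for a real GPU backend
GPUArrays.allowscalar(false)     # as in runtests.jl: any scalar fallback throws

x = JLArray([1.0, 2.0, 3.0])
# x[2]                           # would error: scalar getindex is disallowed
GPUArrays.@allowscalar x[2]      # 2.0, explicit opt-in for a single expression
```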
25 changes: 25 additions & 0 deletions test/runtests.jl
@@ -12,3 +12,28 @@ end
@testset "Linear Algebra" begin
include("linalg.jl")
end

using Zygote
import CUDA
if CUDA.functional()
using CUDA # exports CuArray, etc
@info "starting CUDA tests"
else
@info "CUDA not functional, testing via GPUArrays"
using GPUArrays
GPUArrays.allowscalar(false)

# GPUArrays provides a fake GPU array, for testing
jl_file = normpath(joinpath(pathof(GPUArrays), "..", "..", "test", "jlarray.jl"))
using Random # loaded within jl_file
include(jl_file)
using .JLArrays
cu = jl
CuArray{T,N} = JLArray{T,N}
end

@test cu(rand(3)) .+ 1 isa CuArray

@testset "GPUArrays" begin
include("gpu.jl")
end
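Aliasing `cu = jl` and `CuArray = JLArray` lets `gpu.jl` run unchanged against either backend. The same pattern sketched with the standalone JLArrays package (an assumption: at the time of this PR, the `JLArray` type lived only in GPUArrays' test directory):

```julia
import CUDA
if CUDA.functional()
    using CUDA                 # real GPU path: cu, CuArray
    const gpu = cu
    const GPUArrayType = CuArray
else
    using JLArrays             # fake GPU array that runs on the CPU
    const gpu = jl
    const GPUArrayType = JLArray
end

@assert gpu(rand(3)) isa GPUArrayType
```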
