Skip to content

Commit

Permalink
BroadcastStyle{N+1}
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Mar 10, 2022
1 parent ac20bd3 commit ffb6f00
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,13 @@ MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatri

Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, T}}) where {N, T <: AbstractGPUArray} = Base.BroadcastStyle(T)
function Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray}
# We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
S = Base.BroadcastStyle(T)
# S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter, which
# isn't guaranteed, but there are not so many GPU broadcasting styles in the wild. (Far fewer than there are array wrappers.)
(typeof(S).name.wrapper){var"N+1"}()
end

Base.map(f, x::OneHotLike) = Base.broadcast(f, x)

Expand Down

0 comments on commit ffb6f00

Please sign in to comment.