Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require GPUArrays instead of CUDA #1182

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,43 +251,44 @@ end
return y, bc_fwd_back
end

@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
@init @require GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" begin

const CuArrayStyle = CUDA.AbstractGPUArrayStyle
# const CuArrayStyle = AbstractGPUArrayStyle

if isdefined(CUDA, :cufunc) # CUDA < 3.0
# if isdefined(CUDA, :cufunc) # CUDA < 3.0

@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
broadcast_forward(CUDA.cufunc(f), args...)
# @eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
# broadcast_forward(CUDA.cufunc(f), args...)

else # CUDA >= 3.0 -- don't need cufunc(f).
# Ordinary broadcasting calls broadcast_forward anyway when certain it's safe,
# so perhaps this can be deleted? Possible edge case here:
# https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
# else # CUDA >= 3.0 -- don't need cufunc(f).
# # Ordinary broadcasting calls broadcast_forward anyway when certain it's safe,
# # so perhaps this can be deleted? Possible edge case here:
# # https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415

@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
broadcast_forward(f, args...)
Comment on lines +263 to -269
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure it's safe to delete these.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If not, changing CuArrayStyle -> AbstractGPUArrayStyle seems pretty safe.

# @eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
# broadcast_forward(f, args...)

end
# end

@adjoint (::Type{T})(xs::Array) where {T <: CUDA.CuArray} =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good opportunity to change these over to rrule(::ZygoteRuleConfig, ...) instead? IME they're much easier to debug than @adjoint definitions.

@adjoint (::Type{T})(xs::Array) where {T <: AbstractGPUArray} =
T(xs), Δ -> (convert(Array, Δ), )

@adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
@adjoint function sum(xs::AbstractGPUArray; dims = :)
placeholder = similar(xs)
sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
end

# Make sure sum(f, ::CuArray) uses broadcast through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
@adjoint function sum(f, xs::CUDA.AbstractGPUArray; kws...)
@adjoint function sum(f, xs::AbstractGPUArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end

@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.AbstractGPUArray}
@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:AbstractGPUArray}
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end

@eval pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = CUDA.@allowscalar Δ[sz]
@eval pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = GPUArrays.@allowscalar Δ[sz]

end