Commit: reorganize gpu movement code
CarloLucibello committed Mar 30, 2024
1 parent 61d2b7e commit 27a5f68
Showing 11 changed files with 63 additions and 46 deletions.
2 changes: 1 addition & 1 deletion ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
@@ -24,6 +24,7 @@ function (device::Flux.FluxAMDGPUDevice)(x)
     return Flux.gpu(Flux.FluxAMDGPUAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer
   end
 end
+
 Flux._get_device_name(::Flux.FluxAMDGPUDevice) = "AMDGPU"
 Flux._isavailable(::Flux.FluxAMDGPUDevice) = true
 Flux._isfunctional(::Flux.FluxAMDGPUDevice) = AMDGPU.functional()
@@ -55,7 +56,6 @@ include("conv.jl")
 
 function __init__()
     Flux.AMDGPU_LOADED[] = true
-    Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["AMDGPU"]] = AMDGPU.functional() ? Flux.FluxAMDGPUDevice(AMDGPU.device()) : Flux.FluxAMDGPUDevice(nothing)
 end
 
 # TODO
18 changes: 12 additions & 6 deletions ext/FluxAMDGPUExt/functor.jl
@@ -105,10 +105,16 @@ function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMDGPU_CONV)
                 Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups)
 end
 
-function Flux.get_device(::Val{:AMDGPU}, id::Int) # id should start from 0
-    old_id = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ids start from 0
-    AMDGPU.device!(AMDGPU.devices()[id + 1]) # adding 1 because ids start from 0
-    device = Flux.FluxAMDGPUDevice(AMDGPU.device())
-    AMDGPU.device!(AMDGPU.devices()[old_id + 1])
-    return device
+function Flux.get_device(::Val{:AMDGPU}, id::Int = -1)
+    if id < 0 # return current device
+        # return AMDGPU.functional() ? Flux.FluxAMDGPUDevice(AMDGPU.device()) : Flux.FluxAMDGPUDevice(nothing)
+        return Flux.FluxAMDGPUDevice(AMDGPU.device())
+    else
+        # id should start from 0
+        old_id = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ids start from 0
+        AMDGPU.device!(AMDGPU.devices()[id + 1]) # adding 1 because ids start from 0
+        device = Flux.FluxAMDGPUDevice(AMDGPU.device())
+        AMDGPU.device!(AMDGPU.devices()[old_id + 1])
+        return device
+    end
 end
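
With this reorganization, the backend `get_device` methods gain a default `id = -1` that wraps the currently active device, while an explicit non-negative `id` selects a device and restores the previously active one afterwards. A minimal usage sketch (not part of the commit), assuming AMDGPU.jl is loaded and functional and that at least two devices are visible:

```julia
using Flux, AMDGPU

dev  = Flux.get_device("AMDGPU")     # default id = -1: wraps the currently active device
dev1 = Flux.get_device("AMDGPU", 1)  # wraps device 1; the active device is restored afterwards

m = dev(Dense(3 => 2))               # device objects are callable and move models
```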
3 changes: 0 additions & 3 deletions ext/FluxCUDAExt/FluxCUDAExt.jl
@@ -48,9 +48,6 @@ include("utils.jl")
 function __init__()
     Flux.CUDA_LOADED[] = true
 
-    ## add device to available devices
-    Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["CUDA"]] = CUDA.functional() ? Flux.FluxCUDADevice(CUDA.device()) : Flux.FluxCUDADevice(nothing)
-
     try
         Base.require(Main, :cuDNN)
     catch
7 changes: 6 additions & 1 deletion ext/FluxCUDAExt/functor.jl
@@ -56,10 +56,15 @@ function _cuda(id::Union{Nothing, Int}, x)
     fmap(x -> Adapt.adapt(FluxCUDAAdaptor(id), x), x; exclude=Flux._isleaf)
 end
 
-function Flux.get_device(::Val{:CUDA}, id::Int)
+function Flux.get_device(::Val{:CUDA}, id::Int = -1)
+    if id < 0 # return current device
+        # return CUDA.functional() ? Flux.FluxCUDADevice(CUDA.device()) : Flux.FluxCUDADevice(nothing)
+        return Flux.FluxCUDADevice(CUDA.device())
+    else
     old_id = CUDA.device().handle
     CUDA.device!(id)
     device = Flux.FluxCUDADevice(CUDA.device())
     CUDA.device!(old_id)
     return device
+    end
 end
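
The CUDA method follows the same pattern: a negative or omitted `id` wraps the current device, and an explicit `id` switches to that device only long enough to capture a handle before restoring the old one. A sketch under the assumption that CUDA.jl is functional with two visible devices:

```julia
using Flux, CUDA

CUDA.device!(0)                     # make device 0 active
dev1 = Flux.get_device("CUDA", 1)   # wraps device 1 via a temporary switch
@assert CUDA.device().handle == 0   # the previously active device was restored
```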
1 change: 0 additions & 1 deletion ext/FluxMetalExt/FluxMetalExt.jl
@@ -35,7 +35,6 @@ include("functor.jl")
 
 function __init__()
     Flux.METAL_LOADED[] = true
-    Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]] = Metal.functional() ? Flux.FluxMetalDevice(Metal.current_device()) : Flux.FluxMetalDevice(nothing)
 end
 
 end
5 changes: 2 additions & 3 deletions ext/FluxMetalExt/functor.jl
@@ -33,8 +33,7 @@ function _metal(x)
     fmap(x -> Adapt.adapt(FluxMetalAdaptor(), x), x; exclude=_isleaf)
 end
 
-function Flux.get_device(::Val{:Metal}, id::Int)
+function Flux.get_device(::Val{:Metal}, id::Int = 0)
     @assert id == 0 "Metal backend only supports one device at the moment"
-    return Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]]
+    return Metal.functional() ? Flux.FluxMetalDevice(Metal.current_device()) : Flux.FluxMetalDevice(nothing)
 end
-
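
Metal exposes a single device, so the method now constructs the device object on demand instead of reading the removed `Flux.DEVICES` registry. A sketch, assuming Metal.jl on Apple silicon and that a callable-device method exists for Metal as it does for AMDGPU above:

```julia
using Flux, Metal

dev = Flux.get_device("Metal")  # id defaults to 0; any other id fails the @assert
x = dev(rand(Float32, 4))       # move data through the device object
```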

56 changes: 40 additions & 16 deletions src/functor.jl
@@ -193,6 +193,16 @@ const GPU_BACKENDS = ["CUDA", "AMDGPU", "Metal", "CPU"]
 const GPU_BACKEND_ORDER = Dict(collect(zip(GPU_BACKENDS, 1:length(GPU_BACKENDS))))
 const GPU_BACKEND = @load_preference("gpu_backend", "CUDA")
 
+"""
+    gpu_backend!(backend::String)
+
+Sets the preferred GPU backend to `backend`.
+The backend must be one of `"CUDA"`, `"AMDGPU"`, `"Metal"`, or `"CPU"`.
+The preference is saved to the preferences file and is used in subsequent Julia sessions
+when calling [`Flux.get_device`](@ref) or [`Flux.gpu`](@ref).
+
+See also [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl) for more information.
+"""
 function gpu_backend!(backend::String)
     if backend == GPU_BACKEND
         @info """
@@ -230,7 +240,10 @@ See also [`f32`](@ref) and [`f16`](@ref) to change element type only.
 See the [CUDA.jl docs](https://juliagpu.github.io/CUDA.jl/stable/usage/multigpu/)
 to help identify the current device.
 
-# Example
+See also [`get_device`](@ref) for more fine-grained control over device selection.
+
+# Examples
+
 ```julia-repl
 julia> m = Dense(rand(2, 3)) # constructed with Float64 weight matrix
 Dense(3 => 2) # 8 parameters
@@ -258,7 +271,7 @@ function gpu(x)
     elseif GPU_BACKEND == "CPU"
         cpu(x)
     else
-        error("""
+        error(lazy"""
         Unsupported GPU backend: $GPU_BACKEND.
         Supported backends are: $GPU_BACKENDS.
         """)
@@ -517,26 +530,26 @@ _get_device_name(::FluxCPUDevice) = "CPU"
 A type representing `device` objects for the `"CUDA"` backend for Flux.
 """
-Base.@kwdef struct FluxCUDADevice <: AbstractDevice
-    deviceID
+Base.@kwdef struct FluxCUDADevice{D} <: AbstractDevice
+    deviceID::D
 end
 
 """
     FluxAMDGPUDevice <: AbstractDevice
 
 A type representing `device` objects for the `"AMDGPU"` backend for Flux.
 """
-Base.@kwdef struct FluxAMDGPUDevice <: AbstractDevice
-    deviceID
+Base.@kwdef struct FluxAMDGPUDevice{D} <: AbstractDevice
+    deviceID::D
 end
 
 """
     FluxMetalDevice <: AbstractDevice
 
 A type representing `device` objects for the `"Metal"` backend for Flux.
 """
-Base.@kwdef struct FluxMetalDevice <: AbstractDevice
-    deviceID
+Base.@kwdef struct FluxMetalDevice{D} <: AbstractDevice
+    deviceID::D
 end
 
 ## get device
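
Parametrizing the structs as `FluxCUDADevice{D}` (and likewise for AMDGPU and Metal) makes the wrapped `deviceID` concretely typed, so the "no usable device" case is visible in the type itself as `FluxCUDADevice{Nothing}` and can be handled by dispatch instead of runtime checks. A hypothetical helper, not part of this commit, illustrating the idea:

```julia
# Hypothetical, for illustration only: branch on the deviceID type parameter.
is_present(::Flux.FluxCUDADevice{Nothing}) = false  # constructed when CUDA is unusable
is_present(::Flux.FluxCUDADevice) = true            # any real device handle
```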
@@ -660,10 +673,11 @@ end
 """
     Flux.get_device(backend::String, [idx])::Flux.AbstractDevice
 
-Get a device object for a backend specified by the string `backend` and `idx`. The currently supported values
-of `backend` are `"CUDA"`, `"AMDGPU"` and `"CPU"`.
+Get a device object for a backend specified by the string `backend` and `idx`.
+The currently supported values of `backend` are `"CUDA"`, `"AMDGPU"`, `"Metal"`, and `"CPU"`.
 `idx` must be an integer value between `0` and the number of available devices.
-If not given, will be automatically chosen.
+If `idx` is negative or not provided, the currently active device for the backend is returned.
 
 # Examples
@@ -693,21 +707,31 @@ julia> cpu_device = Flux.get_device("CPU")
 ```
 """
-function get_device(backend::String; idx=-1)
-    if backend == "CPU"
-        return FluxCPUDevice()
-    elseif idx >= 0
+function get_device(backend::String, idx=-1)
+    if idx >= 0
         return get_device(Val(Symbol(backend)), idx)
     else
         return get_device(Val(Symbol(backend)))
     end
 end
 
+get_device(::Val{:CPU}, idx = 0) = FluxCPUDevice()
+
 # Fallback
-function get_device(::Val{D}, idx) where D
+function get_device(::Val{D}, idx = 0) where D
     if D ∈ (:CUDA, :AMDGPU, :Metal)
         error("Unavailable backend: $(D). Try importing the corresponding package with `using $D`.")
     else
         error("Unsupported backend: $(D). Supported backends are $(GPU_BACKENDS).")
     end
 end
+
+function has_device(backend::String, idx=-1)
+    if idx >= 0
+        return has_device(Val(Symbol(backend)), idx)
+    else
+        return has_device(Val(Symbol(backend)))
+    end
+end
+
+has_device(::Val{:CPU}, idx = 0) = true
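
The string entry point now funnels every backend through `Val` dispatch, so an unloaded backend fails with a targeted message from the fallback above rather than a `MethodError`. For example, in a session where CUDA.jl has not been imported (output paraphrased from the error strings in the code):

```julia-repl
julia> Flux.get_device("CPU")
Flux.FluxCPUDevice()

julia> Flux.get_device("CUDA", 0)
ERROR: Unavailable backend: CUDA. Try importing the corresponding package with `using CUDA`.
```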
2 changes: 1 addition & 1 deletion test/ext_amdgpu/get_devices.jl
@@ -1,4 +1,4 @@
-amdgpu_device = Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["AMDGPU"]]
+amdgpu_device = Flux.get_device("AMDGPU")
 
 # should pass, whether or not AMDGPU is functional
 @test typeof(amdgpu_device) <: Flux.FluxAMDGPUDevice
2 changes: 1 addition & 1 deletion test/ext_cuda/get_devices.jl
@@ -1,4 +1,4 @@
-cuda_device = Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["CUDA"]]
+cuda_device = get_device("CUDA")
 
 # should pass, whether or not CUDA is functional
 @test typeof(cuda_device) <: Flux.FluxCUDADevice
8 changes: 0 additions & 8 deletions test/ext_metal/get_devices.jl
@@ -1,11 +1,3 @@
-@testset "Flux.DEVICES" begin
-    metal_device = Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]]
-
-    # should pass, whether or not Metal is functional
-    @test typeof(metal_device) <: Flux.FluxMetalDevice
-
-    @test typeof(metal_device.deviceID) <: Metal.MTLDevice
-end
 
 @testset "get_device()" begin
     metal_device = Flux.get_device()
5 changes: 0 additions & 5 deletions test/functors.jl
@@ -3,11 +3,6 @@ if !(Flux.CUDA_LOADED[] || Flux.AMDGPU_LOADED[] || Flux.METAL_LOADED[])
     @test x === gpu(x)
 end
 
-@test typeof(Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["CUDA"]]) <: Nothing
-@test typeof(Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["AMDGPU"]]) <: Nothing
-@test typeof(Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]]) <: Nothing
-@test typeof(Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["CPU"]]) <: Flux.FluxCPUDevice
-
 dev = Flux.get_device()
 @test typeof(dev) <: Flux.FluxCPUDevice
 @test dev(x) == x
