Simplify setting up and uniform API
avik-pal committed Feb 25, 2024
1 parent 731a606 commit cad7299
Showing 11 changed files with 257 additions and 22 deletions.
6 changes: 5 additions & 1 deletion Project.toml
@@ -34,18 +34,22 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
LuxCUDANCCLExt = ["CUDA", "NCCL"]
LuxCUDAMPINCCLExt = ["CUDA", "MPI", "NCCL"]
LuxChainRulesExt = "ChainRules"
LuxComponentArraysExt = "ComponentArrays"
LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"]
LuxFluxTransformExt = "Flux"
LuxLuxAMDGPUExt = "LuxAMDGPU"
LuxLuxCUDAExt = "LuxCUDA"
LuxMPIExt = "MPI"
LuxReverseDiffExt = "ReverseDiff"
LuxTrackerExt = "Tracker"
LuxZygoteExt = "Zygote"
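For orientation, a minimal sketch of which package loads are expected to activate the new extensions declared above (this is an inference from the `[extensions]` entries, not part of the commit; `CUDA` is assumed to be listed among the weak dependencies):

```julia
# Hypothetical session illustrating the extension triggers declared above.
using Lux          # base package; no distributed backend yet
using MPI          # loads LuxMPIExt -> MPI backend becomes available
using CUDA, NCCL   # together with MPI, loads LuxCUDAMPINCCLExt -> NCCL backend
using LuxCUDA      # loads LuxLuxCUDAExt -> CUDA device-selection helpers
```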
42 changes: 42 additions & 0 deletions ext/LuxCUDAMPINCCLExt.jl
@@ -0,0 +1,42 @@
module LuxCUDAMPINCCLExt

import Lux: MPIBackend, NCCLBackend, DistributedUtils
import MPI
import NCCL
import Setfield: @set!

function DistributedUtils.__initialize(
::Val{:NCCL}; cuda_devices=nothing, amdgpu_devices=nothing)
DistributedUtils.NCCL_Initialized[] = true
DistributedUtils.__initialize(Val(:MPI); cuda_devices, amdgpu_devices)
return
end

function DistributedUtils.get_distributed_backend(::Val{:NCCL})
unique_id = NCCL.UniqueID() # Generate on all ranks to know the type
mpi_backend = DistributedUtils.get_distributed_backend(Val(:MPI))
buf = [unique_id.internal...]
DistributedUtils.bcast!(mpi_backend, buf; root=0)
@set! unique_id.internal = Tuple(buf)

nranks = DistributedUtils.total_workers(mpi_backend)
rank = DistributedUtils.local_rank(mpi_backend)

return NCCLBackend(NCCL.Communicator(nranks, rank; unique_id))
end

DistributedUtils.local_rank(backend::NCCLBackend) = NCCL.rank(backend.comm)

DistributedUtils.total_workers(backend::NCCLBackend) = NCCL.size(backend.comm)

function DistributedUtils.bcast!(backend::NCCLBackend, sendrecvbuf; root=0)
NCCL.Broadcast!(sendrecvbuf, backend.comm; root)
return sendrecvbuf
end

function DistributedUtils.bcast!(backend::NCCLBackend, sendbuf, recvbuf; root=0)
NCCL.Broadcast!(sendbuf, recvbuf, backend.comm; root)
return recvbuf
end

end
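A hedged usage sketch of the NCCL path defined in this extension (assumes the process is launched under MPI, e.g. via `mpiexecjl`, and that functional CUDA devices are available):

```julia
using Lux, MPI, NCCL, CUDA

# Initializes MPI underneath and marks the NCCL backend as initialized.
DistributedUtils.initialize(Val(:NCCL))
backend = DistributedUtils.get_distributed_backend(Val(:NCCL))

rank   = DistributedUtils.local_rank(backend)
nranks = DistributedUtils.total_workers(backend)

# Broadcast rank 0's buffer to every worker.
buf = CUDA.fill(Float32(rank), 4)
DistributedUtils.bcast!(backend, buf; root=0)
```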
11 changes: 0 additions & 11 deletions ext/LuxCUDANCCLExt.jl

This file was deleted.

12 changes: 11 additions & 1 deletion ext/LuxLuxAMDGPUExt.jl
@@ -1,6 +1,16 @@
module LuxLuxAMDGPUExt

using Lux, LuxAMDGPU
import Lux
import LuxAMDGPU: AMDGPU

Lux.__is_extension_loaded(::Val{:LuxAMDGPU}) = Val(true)

Lux.__set_device!(::Val{:AMDGPU}, id::Int) = AMDGPU.functional() && AMDGPU.device!(id)
function Lux.__set_device!(::Val{:AMDGPU}, ::Nothing, rank::Int)
AMDGPU.functional() || return
AMDGPU.device!(rank % length(AMDGPU.devices()))
return
end

# Flux modifies Conv weights while mapping to AMD GPU
function Lux._maybe_flip_conv_weight(x::AMDGPU.AnyROCArray)
15 changes: 15 additions & 0 deletions ext/LuxLuxCUDAExt.jl
@@ -0,0 +1,15 @@
module LuxLuxCUDAExt

import Lux
import LuxCUDA: CUDA

Lux.__is_extension_loaded(::Val{:LuxCUDA}) = Val(true)

Lux.__set_device!(::Val{:CUDA}, id::Int) = CUDA.functional() && CUDA.device!(id)
function Lux.__set_device!(::Val{:CUDA}, ::Nothing, rank::Int)
CUDA.functional() || return
CUDA.device!(rank % length(CUDA.devices()))
return
end

end
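Both the CUDA and AMDGPU methods above fall back to the same round-robin rank-to-device mapping when no explicit device list is given; a small illustration of that modulo logic (the device count is hypothetical):

```julia
ndevices = 4                       # hypothetical number of visible GPUs
for rank in 0:5
    println("rank $rank -> device ", rank % ndevices)
end
# rank 0 -> device 0, rank 1 -> device 1, ..., rank 4 -> device 0, rank 5 -> device 1
```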
49 changes: 49 additions & 0 deletions ext/LuxMPIExt.jl
@@ -0,0 +1,49 @@
module LuxMPIExt

import Lux: MPIBackend, NCCLBackend, DistributedUtils, __is_extension_loaded, __set_device!,
__unwrap_val
import MPI

function DistributedUtils.__initialize(
::Val{:MPI}; cuda_devices=nothing, amdgpu_devices=nothing)
!MPI.Initialized() && MPI.Init()
DistributedUtils.MPI_Initialized[] = true

local_rank = MPI.Comm_rank(MPI.COMM_WORLD)

if cuda_devices !== missing && __unwrap_val(__is_extension_loaded(Val(:LuxCUDA)))
if cuda_devices === nothing
__set_device!(Val(:CUDA), nothing, local_rank)
else
__set_device!(Val(:CUDA), cuda_devices[local_rank])
end
end

if amdgpu_devices !== missing && __unwrap_val(__is_extension_loaded(Val(:LuxAMDGPU)))
if amdgpu_devices === nothing
__set_device!(Val(:AMDGPU), nothing, local_rank)
else
__set_device!(Val(:AMDGPU), amdgpu_devices[local_rank])
end
end

return
end

DistributedUtils.get_distributed_backend(::Val{:MPI}) = MPIBackend(MPI.COMM_WORLD)

DistributedUtils.local_rank(backend::MPIBackend) = MPI.Comm_rank(backend.comm)

DistributedUtils.total_workers(backend::MPIBackend) = MPI.Comm_size(backend.comm)

function DistributedUtils.bcast!(backend::MPIBackend, sendrecvbuf; root=0)
MPI.Bcast!(sendrecvbuf, backend.comm; root)
return sendrecvbuf
end

function DistributedUtils.bcast!(backend::MPIBackend, sendbuf, recvbuf; root=0)
MPI.Bcast!(sendbuf, recvbuf, backend.comm; root)
return recvbuf
end

end
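A hedged sketch of the plain-MPI path this extension enables (run under an MPI launcher such as `mpiexecjl -n 4 julia script.jl`; the device keywords are left at their defaults):

```julia
using Lux, MPI

DistributedUtils.initialize(Val(:MPI))
backend = DistributedUtils.get_distributed_backend(Val(:MPI))

rank   = DistributedUtils.local_rank(backend)     # 0-based MPI rank
nranks = DistributedUtils.total_workers(backend)

# In-place broadcast: after this call every worker holds rank 0's values.
buf = fill(Float64(rank), 8)
DistributedUtils.bcast!(backend, buf; root=0)
```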
5 changes: 4 additions & 1 deletion src/Lux.jl
@@ -24,6 +24,8 @@ const CRC = ChainRulesCore

const NAME_TYPE = Union{Nothing, String, Symbol}

@inline __is_extension_loaded(x) = Val(false)

# Utilities
include("utils.jl")

@@ -50,6 +52,7 @@ include("helpers/stateful.jl")

# Distributed Training
include("distributed/backend.jl")
include("distributed/public_api.jl")

# Deprecations
include("deprecated.jl")
@@ -76,6 +79,6 @@ export f16, f32, f64

export transform, FluxLayer

export NCCLBackend
export MPIBackend, NCCLBackend, DistributedUtils

end
4 changes: 2 additions & 2 deletions src/contrib/compact.jl
@@ -52,7 +52,7 @@ Here is a linear model with bias and activation:
```julia
d_in = 5
d_out = 7
d = @compact(W=randn(d_out, d_in), b=zeros(d_out),act=relu) do x
d = @compact(W=randn(d_out, d_in), b=zeros(d_out), act=relu) do x
y = W * x
return act.(y .+ b)
end
@@ -72,7 +72,7 @@ n_out = 1
nlayers = 3
model = @compact(w1=Dense(n_in, 128),
w2=[Dense(128, 128) for i in 1:nlayers], w3=Dense(128, n_out),act=relu) do x
w2=[Dense(128, 128) for i in 1:nlayers], w3=Dense(128, n_out), act=relu) do x
embed = act(w1(x))
for w in w2
embed = act(w(embed))
25 changes: 19 additions & 6 deletions src/distributed/backend.jl
@@ -1,10 +1,23 @@
abstract type AbstractLuxDistributedTrainingBackend end
abstract type AbstractLuxDistributedBackend end

struct NCCLBackend <: AbstractLuxDistributedTrainingBackend
function NCCLBackend()
if Base.get_extension(@__MODULE__, :LuxCUDANCCLExt) === nothing
error("`NCCLBackend` requires `CUDA.jl` and `NCCL.jl` to be loaded")
struct NCCLBackend{C} <: AbstractLuxDistributedBackend
comm::C

function NCCLBackend(comm=nothing)
if Base.get_extension(@__MODULE__, :LuxCUDAMPINCCLExt) === nothing
error("`NCCLBackend` requires `CUDA.jl`, `MPI.jl` and `NCCL.jl` to be loaded.")
end
return new{typeof(comm)}(comm)
end
end

struct MPIBackend{C} <: AbstractLuxDistributedBackend
comm::C

function MPIBackend(comm=nothing)
if Base.get_extension(@__MODULE__, :LuxMPIExt) === nothing
error("`MPIBackend` requires `MPI.jl` to be loaded.")
end
return new()
return new{typeof(comm)}(comm)
end
end
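Since both constructors are guarded by `Base.get_extension`, creating a backend without its trigger packages loaded fails eagerly; a hedged sketch of that behavior:

```julia
using Lux

# Without MPI.jl loaded, the guarded constructor errors immediately:
try
    MPIBackend()
catch err
    @warn "MPIBackend requires MPI.jl to be loaded" err
end

# After `using MPI`, `DistributedUtils.get_distributed_backend(Val(:MPI))`
# constructs `MPIBackend(MPI.COMM_WORLD)` successfully.
```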
108 changes: 108 additions & 0 deletions src/distributed/public_api.jl
@@ -0,0 +1,108 @@
module DistributedUtils

import ChainRulesCore as CRC
import Functors: fmap
import ..Lux: AbstractLuxDistributedBackend, MPIBackend, NCCLBackend
import Optimisers: Leaf
import Setfield: @set!

const NCCL_Initialized = Ref(false)
const MPI_Initialized = Ref(false)

"""
initialized(backend::Val)
Check if the given backend is initialized.
"""
initialized(::Val{:MPI}) = MPI_Initialized[]
initialized(::Val{:NCCL}) = NCCL_Initialized[]

function initialize(backend::Val; cuda_devices=nothing, amdgpu_devices=nothing)
initialized(backend) && return
__initialize(backend; cuda_devices, amdgpu_devices)
return
end

function __initialize end

"""
get_distributed_backend(backend::Val)
Get the distributed backend for the given backend type. Possible values are:
- `Val(:MPI)`: MPI backend for distributed training. Requires `MPI.jl` to be installed.
- `Val(:NCCL)`: NCCL backend for CUDA distributed training. Requires `CUDA.jl`,
`MPI.jl`, and `NCCL.jl` to be installed.
"""
function get_distributed_backend end

CRC.@non_differentiable get_distributed_backend(::Any...)

"""
local_rank(backend::AbstractLuxDistributedBackend)
Get the local rank for the given backend.
"""
function local_rank end

CRC.@non_differentiable local_rank(::Any...)

"""
total_workers(backend::AbstractLuxDistributedBackend)
Get the total number of workers for the given backend.
"""
function total_workers end

CRC.@non_differentiable total_workers(::Any...)

function bcast! end

CRC.@non_differentiable bcast!(::Any...)

function allreduce! end

CRC.@non_differentiable allreduce!(::Any...)

function reduce! end

CRC.@non_differentiable reduce!(::Any...)

# synchronize!!
"""
synchronize!!(backend::AbstractLuxDistributedBackend, ps; root::Int=0)
Synchronize the given structure `ps` using the given backend. The value at `root` will be
broadcasted to all other workers.
"""
function synchronize!!(backend::AbstractLuxDistributedBackend, ps::Tuple; root::Int=0)
length(ps) == 0 && return ps
return map(x -> synchronize!!(backend, x; root), ps)
end

function synchronize!!(backend::AbstractLuxDistributedBackend,
ps::NamedTuple{fields}; root::Int=0) where {fields}
length(ps) == 0 && return ps
return NamedTuple{fields}(map(x -> synchronize!!(backend, x; root), values(ps)))
end

function synchronize!!(
backend::AbstractLuxDistributedBackend, ps::AbstractArray{T}; root::Int=0) where {T}
if isbitstype(T)
bcast!(backend, ps; root)
return ps
end
return map(x -> synchronize!!(backend, x; root), ps)
end

function synchronize!!(backend::AbstractLuxDistributedBackend, ps::Leaf; root::Int=0)
@set! ps.state = synchronize!!(backend, ps.state; root)
return ps
end

function synchronize!!(backend::AbstractLuxDistributedBackend, ps::T; root::Int=0) where {T}
isbitstype(T) && return bcast!(backend, [ps]; root)[]
return ps # If we don't know how to synchronize, just return the value (e.g. Symbol, String, etc.)
end

end
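To tie the pieces together, a hedged end-to-end sketch of `synchronize!!` broadcasting model parameters and `Optimisers` state from rank 0 (the model and optimizer choices here are illustrative, not part of the commit):

```julia
using Lux, MPI, Optimisers, Random

DistributedUtils.initialize(Val(:MPI))
backend = DistributedUtils.get_distributed_backend(Val(:MPI))

model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model)
opt_state = Optimisers.setup(Optimisers.Adam(0.001f0), ps)

# Recursively broadcasts every array (and each `Optimisers.Leaf` state) from `root`.
ps        = DistributedUtils.synchronize!!(backend, ps; root=0)
opt_state = DistributedUtils.synchronize!!(backend, opt_state; root=0)
```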
2 changes: 2 additions & 0 deletions src/utils.jl
@@ -216,3 +216,5 @@ end
@inline _pairs(x) = pairs(x)

@inline __value(x) = x

function __set_device! end
