-
Notifications
You must be signed in to change notification settings - Fork 64
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
257 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Extension module gluing Lux's `DistributedUtils` interface to the NCCL
# collective-communication library. Loaded only when CUDA.jl, MPI.jl and
# NCCL.jl are all available (see the `NCCLBackend` constructor guard in Lux).
module LuxCUDAMPINCCLExt

import Lux: MPIBackend, NCCLBackend, DistributedUtils
import MPI
import NCCL
import Setfield: @set!

# Mark the NCCL path as initialized, then delegate device selection and MPI
# startup to the `:MPI` initialization path (NCCL uses MPI for rank bookkeeping).
function DistributedUtils.__initialize(
        ::Val{:NCCL}; cuda_devices=nothing, amdgpu_devices=nothing)
    DistributedUtils.NCCL_Initialized[] = true
    DistributedUtils.__initialize(Val(:MPI); cuda_devices, amdgpu_devices)
    return
end

# Build an `NCCLBackend` whose communicator is shared across all MPI ranks:
# rank 0's NCCL unique id is broadcast over MPI so every rank joins the same
# NCCL communicator.
function DistributedUtils.get_distributed_backend(::Val{:NCCL})
    unique_id = NCCL.UniqueID()  # Generate on all ranks to know the type
    mpi_backend = DistributedUtils.get_distributed_backend(Val(:MPI))
    # NCCL's unique id is an immutable tuple; copy into a mutable buffer so the
    # MPI broadcast can overwrite it in place, then write the tuple back.
    buf = [unique_id.internal...]
    DistributedUtils.bcast!(mpi_backend, buf; root=0)
    @set! unique_id.internal = Tuple(buf)

    nranks = DistributedUtils.total_workers(mpi_backend)
    rank = DistributedUtils.local_rank(mpi_backend)

    return NCCLBackend(NCCL.Communicator(nranks, rank; unique_id))
end

# Rank / world-size queries forwarded to the wrapped NCCL communicator.
DistributedUtils.local_rank(backend::NCCLBackend) = NCCL.rank(backend.comm)

DistributedUtils.total_workers(backend::NCCLBackend) = NCCL.size(backend.comm)

# In-place broadcast: `sendrecvbuf` is the source on `root` and the
# destination on every rank.
function DistributedUtils.bcast!(backend::NCCLBackend, sendrecvbuf; root=0)
    NCCL.Broadcast!(sendrecvbuf, backend.comm; root)
    return sendrecvbuf
end

# Out-of-place broadcast from `sendbuf` (on `root`) into `recvbuf` on all ranks.
function DistributedUtils.bcast!(backend::NCCLBackend, sendbuf, recvbuf; root=0)
    NCCL.Broadcast!(sendbuf, recvbuf, backend.comm; root)
    return recvbuf
end

end
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Extension module loaded when LuxCUDA.jl is available; advertises CUDA support
# to Lux and implements CUDA device selection for distributed training.
module LuxLuxCUDAExt

import Lux
import LuxCUDA: CUDA

# Type-level flag telling Lux that the LuxCUDA extension is loaded
# (queried via `__unwrap_val(__is_extension_loaded(Val(:LuxCUDA)))`).
Lux.__is_extension_loaded(::Val{:LuxCUDA}) = Val(true)

# Bind this process to an explicit CUDA device id; no-op when CUDA is not functional.
Lux.__set_device!(::Val{:CUDA}, id::Int) = CUDA.functional() && CUDA.device!(id)

# No explicit device list: assign devices round-robin by process rank so ranks
# on the same node spread across the node's visible GPUs.
function Lux.__set_device!(::Val{:CUDA}, ::Nothing, rank::Int)
    CUDA.functional() || return
    CUDA.device!(rank % length(CUDA.devices()))
    return
end

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Extension module implementing the MPI backend for Lux's `DistributedUtils`.
# Loaded when MPI.jl is available.
module LuxMPIExt

import Lux: MPIBackend, NCCLBackend, DistributedUtils, __is_extension_loaded, __set_device!,
            __unwrap_val
import MPI

# Initialize MPI (if not already initialized by the user) and bind this rank
# to a GPU device when the corresponding GPU extension is loaded.
#
# `cuda_devices` / `amdgpu_devices` semantics:
#   - `missing`  : skip GPU device selection entirely for that vendor
#   - `nothing`  : pick a device automatically from the local rank (round-robin)
#   - a collection: use `devices[local_rank]` as the explicit device id
function DistributedUtils.__initialize(
        ::Val{:MPI}; cuda_devices=nothing, amdgpu_devices=nothing)
    !MPI.Initialized() && MPI.Init()
    DistributedUtils.MPI_Initialized[] = true

    local_rank = MPI.Comm_rank(MPI.COMM_WORLD)

    if cuda_devices !== missing && __unwrap_val(__is_extension_loaded(Val(:LuxCUDA)))
        if cuda_devices === nothing
            __set_device!(Val(:CUDA), nothing, local_rank)
        else
            __set_device!(Val(:CUDA), cuda_devices[local_rank])
        end
    end

    if amdgpu_devices !== missing && __unwrap_val(__is_extension_loaded(Val(:LuxAMDGPU)))
        if amdgpu_devices === nothing
            __set_device!(Val(:AMDGPU), nothing, local_rank)
        else
            __set_device!(Val(:AMDGPU), amdgpu_devices[local_rank])
        end
    end

    return
end

# The MPI backend always wraps the world communicator.
DistributedUtils.get_distributed_backend(::Val{:MPI}) = MPIBackend(MPI.COMM_WORLD)

# Rank / world-size queries forwarded to the wrapped MPI communicator.
DistributedUtils.local_rank(backend::MPIBackend) = MPI.Comm_rank(backend.comm)

DistributedUtils.total_workers(backend::MPIBackend) = MPI.Comm_size(backend.comm)

# In-place broadcast: `sendrecvbuf` is the source on `root` and the
# destination on every rank.
function DistributedUtils.bcast!(backend::MPIBackend, sendrecvbuf; root=0)
    MPI.Bcast!(sendrecvbuf, backend.comm; root)
    return sendrecvbuf
end

# Out-of-place broadcast from `sendbuf` (on `root`) into `recvbuf` on all ranks.
function DistributedUtils.bcast!(backend::MPIBackend, sendbuf, recvbuf; root=0)
    MPI.Bcast!(sendbuf, recvbuf, backend.comm; root)
    return recvbuf
end

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,23 @@ | ||
# Root type for all distributed-training communication backends.
abstract type AbstractLuxDistributedBackend end

"""
    NCCLBackend(comm=nothing)

Distributed backend driven by NCCL. Constructing it requires the
`LuxCUDAMPINCCLExt` extension to be loaded (i.e. `CUDA.jl`, `MPI.jl` and
`NCCL.jl` must all be imported); otherwise an `ErrorException` is thrown.

`comm` is the NCCL communicator (type-parameterized so the struct stays
concrete regardless of which communicator type is supplied).
"""
struct NCCLBackend{C} <: AbstractLuxDistributedBackend
    comm::C

    function NCCLBackend(comm=nothing)
        # Guard: fail loudly at construction time instead of with a cryptic
        # MethodError later when a collective operation is attempted.
        if Base.get_extension(@__MODULE__, :LuxCUDAMPINCCLExt) === nothing
            error("`NCCLBackend` requires `CUDA.jl`, `MPI.jl` and `NCCL.jl` to be loaded.")
        end
        return new{typeof(comm)}(comm)
    end
end

"""
    MPIBackend(comm=nothing)

Distributed backend driven by MPI. Constructing it requires the `LuxMPIExt`
extension to be loaded (i.e. `MPI.jl` must be imported); otherwise an
`ErrorException` is thrown.

`comm` is the MPI communicator (typically `MPI.COMM_WORLD`).
"""
struct MPIBackend{C} <: AbstractLuxDistributedBackend
    comm::C

    function MPIBackend(comm=nothing)
        if Base.get_extension(@__MODULE__, :LuxMPIExt) === nothing
            error("`MPIBackend` requires `MPI.jl` to be loaded.")
        end
        return new{typeof(comm)}(comm)
    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Generic (backend-agnostic) distributed-training utilities. The interface
# functions declared here (`__initialize`, `get_distributed_backend`, `bcast!`,
# `allreduce!`, `reduce!`, ...) get their methods from the package extensions
# (LuxMPIExt, LuxCUDAMPINCCLExt).
module DistributedUtils

import ChainRulesCore as CRC
import Functors: fmap
import ..Lux: AbstractLuxDistributedBackend, MPIBackend, NCCLBackend
import Optimisers: Leaf
import Setfield: @set!

# Process-local flags recording which backends have been initialized.
const NCCL_Initialized = Ref(false)
const MPI_Initialized = Ref(false)

"""
    initialized(backend::Val)

Check if the given backend is initialized.
"""
initialized(::Val{:MPI}) = MPI_Initialized[]
initialized(::Val{:NCCL}) = NCCL_Initialized[]

"""
    initialize(backend::Val; cuda_devices=nothing, amdgpu_devices=nothing)

Initialize the given backend (idempotent: a no-op when already initialized).
Device keyword arguments are forwarded to the backend-specific
`__initialize` implementation supplied by the extensions.
"""
function initialize(backend::Val; cuda_devices=nothing, amdgpu_devices=nothing)
    initialized(backend) && return
    __initialize(backend; cuda_devices, amdgpu_devices)
    return
end

# Interface stub: methods added by the backend extensions.
function __initialize end

"""
    get_distributed_backend(backend::Val)

Get the distributed backend for the given backend type. Possible values are:
- `Val(:MPI)`: MPI backend for distributed training. Requires `MPI.jl` to be installed.
- `Val(:NCCL)`: NCCL backend for CUDA distributed training. Requires `CUDA.jl`,
  `MPI.jl`, and `NCCL.jl` to be installed.
"""
function get_distributed_backend end

CRC.@non_differentiable get_distributed_backend(::Any...)

"""
    local_rank(backend::AbstractLuxDistributedBackend)

Get the local rank for the given backend.
"""
function local_rank end

CRC.@non_differentiable local_rank(::Any...)

"""
    total_workers(backend::AbstractLuxDistributedBackend)

Get the total number of workers for the given backend.
"""
function total_workers end

CRC.@non_differentiable total_workers(::Any...)

# Broadcast interface stub; implementations live in the backend extensions.
function bcast! end

CRC.@non_differentiable bcast!(::Any...)

# All-reduce interface stub. NOTE(review): no implementation is visible in
# this file — presumably supplied by the backend extensions; confirm.
function allreduce! end

CRC.@non_differentiable allreduce!(::Any...)

# Reduce interface stub. NOTE(review): no implementation is visible in
# this file — presumably supplied by the backend extensions; confirm.
function reduce! end

CRC.@non_differentiable reduce!(::Any...)

# syncronize!
"""
    synchronize!!(backend::AbstractLuxDistributedBackend, ps; root::Int=0)

Synchronize the given structure `ps` using the given backend. The value at `root` will be
broadcasted to all other workers.
"""
function synchronize!!(backend::AbstractLuxDistributedBackend, ps::Tuple; root::Int=0)
    length(ps) == 0 && return ps
    return map(x -> synchronize!!(backend, x; root), ps)
end

# Recurse into NamedTuples, rebuilding with the same field names.
function synchronize!!(backend::AbstractLuxDistributedBackend,
        ps::NamedTuple{fields}; root::Int=0) where {fields}
    length(ps) == 0 && return ps
    return NamedTuple{fields}(map(x -> synchronize!!(backend, x; root), values(ps)))
end

# Arrays of bits types can be broadcast in place as a single buffer;
# arrays of composite elements are synchronized element-wise.
function synchronize!!(
        backend::AbstractLuxDistributedBackend, ps::AbstractArray{T}; root::Int=0) where {T}
    if isbitstype(T)
        bcast!(backend, ps; root)
        return ps
    end
    return map(x -> synchronize!!(backend, x; root), ps)
end

# Optimiser state lives in `Leaf.state`; synchronize it and rebuild the leaf.
function synchronize!!(backend::AbstractLuxDistributedBackend, ps::Leaf; root::Int=0)
    @set! ps.state = synchronize!!(backend, ps.state; root)
    return ps
end

# Fallback: wrap a bits-type scalar in a 1-element array to broadcast it.
function synchronize!!(backend::AbstractLuxDistributedBackend, ps::T; root::Int=0) where {T}
    isbitstype(T) && return bcast!(backend, [ps]; root)[]
    return ps # If we don't know how to synchronize, just return the value. For ex, Symbol, String, etc.
end

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -216,3 +216,5 @@ end | |
@inline _pairs(x) = pairs(x)

# Identity fallback. NOTE(review): presumably specialized elsewhere to unwrap
# container/tracked values — confirm against other methods of `__value`.
@inline __value(x) = x

# Interface stub for binding a process to a GPU device; methods are added by
# the GPU package extensions (e.g. `Lux.__set_device!(::Val{:CUDA}, ...)` in
# the LuxCUDA extension).
function __set_device! end