From cad72991aa13544d143228d7bdb3838ca1430a2d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 25 Feb 2024 01:20:17 -0500 Subject: [PATCH] Simplify setting up and uniform API --- Project.toml | 6 +- ext/LuxCUDAMPINCCLExt.jl | 42 +++++++++++++ ext/LuxCUDANCCLExt.jl | 11 ---- ext/LuxLuxAMDGPUExt.jl | 12 +++- ext/LuxLuxCUDAExt.jl | 15 +++++ ext/LuxMPIExt.jl | 49 +++++++++++++++ src/Lux.jl | 5 +- src/contrib/compact.jl | 4 +- src/distributed/backend.jl | 25 ++++++-- src/distributed/public_api.jl | 108 ++++++++++++++++++++++++++++++++++ src/utils.jl | 2 + 11 files changed, 257 insertions(+), 22 deletions(-) create mode 100644 ext/LuxCUDAMPINCCLExt.jl delete mode 100644 ext/LuxCUDANCCLExt.jl create mode 100644 ext/LuxLuxCUDAExt.jl create mode 100644 ext/LuxMPIExt.jl create mode 100644 src/distributed/public_api.jl diff --git a/Project.toml b/Project.toml index 16c06d0758..7cd99ed16f 100644 --- a/Project.toml +++ b/Project.toml @@ -34,18 +34,22 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -LuxCUDANCCLExt = ["CUDA", "NCCL"] +LuxCUDAMPINCCLExt = ["CUDA", "MPI", "NCCL"] LuxChainRulesExt = "ChainRules" LuxComponentArraysExt = "ComponentArrays" LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"] LuxFluxTransformExt = "Flux" LuxLuxAMDGPUExt = "LuxAMDGPU" +LuxLuxCUDAExt = "LuxCUDA" +LuxMPIExt = "MPI" LuxReverseDiffExt = "ReverseDiff" LuxTrackerExt = "Tracker" LuxZygoteExt = "Zygote" diff --git a/ext/LuxCUDAMPINCCLExt.jl b/ext/LuxCUDAMPINCCLExt.jl new file mode 100644 index 0000000000..47e89c4a9a --- /dev/null +++ b/ext/LuxCUDAMPINCCLExt.jl @@ -0,0 +1,42 @@ +module LuxCUDAMPINCCLExt + +import Lux: MPIBackend, NCCLBackend, DistributedUtils +import MPI +import NCCL +import Setfield: @set! + +function DistributedUtils.__initialize( + ::Val{:NCCL}; cuda_devices=nothing, amdgpu_devices=nothing) + DistributedUtils.NCCL_Initialized[] = true + DistributedUtils.__initialize(Val(:MPI); cuda_devices, amdgpu_devices) + return +end + +function DistributedUtils.get_distributed_backend(::Val{:NCCL}) + unique_id = NCCL.UniqueID() # Generate on all ranks to know the type + mpi_backend = DistributedUtils.get_distributed_backend(Val(:MPI)) + buf = [unique_id.internal...] + DistributedUtils.bcast!(mpi_backend, buf; root=0) + @set! 
unique_id.internal = Tuple(buf) + + nranks = DistributedUtils.total_workers(mpi_backend) + rank = DistributedUtils.local_rank(mpi_backend) + + return NCCLBackend(NCCL.Communicator(nranks, rank; unique_id)) +end + +DistributedUtils.local_rank(backend::NCCLBackend) = NCCL.rank(backend.comm) + +DistributedUtils.total_workers(backend::NCCLBackend) = NCCL.size(backend.comm) + +function DistributedUtils.bcast!(backend::NCCLBackend, sendrecvbuf; root=0) + NCCL.Broadcast!(sendrecvbuf, backend.comm; root) + return sendrecvbuf +end + +function DistributedUtils.bcast!(backend::NCCLBackend, sendbuf, recvbuf; root=0) + NCCL.Broadcast!(sendbuf, recvbuf, backend.comm; root) + return recvbuf +end + +end diff --git a/ext/LuxCUDANCCLExt.jl b/ext/LuxCUDANCCLExt.jl deleted file mode 100644 index 7010a1bf32..0000000000 --- a/ext/LuxCUDANCCLExt.jl +++ /dev/null @@ -1,11 +0,0 @@ -module LuxCUDANCCLExt - -using CUDA, Lux, NCCL - -# FIXME: Remove before merging -function __init__() - @info "CUDA information:\n" * sprint(io -> CUDA.versioninfo(io)) - @info "NCCL version: $(NCCL.version())" -end - -end diff --git a/ext/LuxLuxAMDGPUExt.jl b/ext/LuxLuxAMDGPUExt.jl index 684772fcc5..a673c5113c 100644 --- a/ext/LuxLuxAMDGPUExt.jl +++ b/ext/LuxLuxAMDGPUExt.jl @@ -1,6 +1,16 @@ module LuxLuxAMDGPUExt -using Lux, LuxAMDGPU +import Lux +import LuxAMDGPU: AMDGPU + +Lux.__is_extension_loaded(::Val{:LuxAMDGPU}) = Val(true) + +Lux.__set_device!(::Val{:AMDGPU}, id::Int) = AMDGPU.functional() && AMDGPU.device!(id) +function Lux.__set_device!(::Val{:AMDGPU}, ::Nothing, rank::Int) + AMDGPU.functional() || return + AMDGPU.device!(rank % length(AMDGPU.devices())) + return +end # Flux modifies Conv weights while mapping to AMD GPU function Lux._maybe_flip_conv_weight(x::AMDGPU.AnyROCArray) diff --git a/ext/LuxLuxCUDAExt.jl b/ext/LuxLuxCUDAExt.jl new file mode 100644 index 0000000000..81d9c88ddd --- /dev/null +++ b/ext/LuxLuxCUDAExt.jl @@ -0,0 +1,15 @@ +module LuxLuxCUDAExt + +import Lux +import LuxCUDA: CUDA + +Lux.__is_extension_loaded(::Val{:LuxCUDA}) = Val(true) + +Lux.__set_device!(::Val{:CUDA}, id::Int) = CUDA.functional() && CUDA.device!(id) +function Lux.__set_device!(::Val{:CUDA}, ::Nothing, rank::Int) + CUDA.functional() || return + CUDA.device!(rank % length(CUDA.devices())) + return +end + +end diff --git a/ext/LuxMPIExt.jl b/ext/LuxMPIExt.jl new file mode 100644 index 0000000000..1474534a30 --- /dev/null +++ b/ext/LuxMPIExt.jl @@ -0,0 +1,49 @@ +module LuxMPIExt + +import Lux: MPIBackend, NCCLBackend, DistributedUtils, __is_extension_loaded, __set_device!, + __unwrap_val +import MPI + +function DistributedUtils.__initialize( + ::Val{:MPI}; cuda_devices=nothing, amdgpu_devices=nothing) + !MPI.Initialized() && MPI.Init() + DistributedUtils.MPI_Initialized[] = true + + local_rank = MPI.Comm_rank(MPI.COMM_WORLD) + + if cuda_devices !== missing && __unwrap_val(__is_extension_loaded(Val(:LuxCUDA))) + if cuda_devices === nothing + __set_device!(Val(:CUDA), nothing, local_rank) + else + __set_device!(Val(:CUDA), cuda_devices[local_rank]) + end + end + + if amdgpu_devices !== missing && __unwrap_val(__is_extension_loaded(Val(:LuxAMDGPU))) + if amdgpu_devices === nothing + __set_device!(Val(:AMDGPU), nothing, local_rank) + else + __set_device!(Val(:AMDGPU), amdgpu_devices[local_rank]) + end + end + + return +end + +DistributedUtils.get_distributed_backend(::Val{:MPI}) = MPIBackend(MPI.COMM_WORLD) + +DistributedUtils.local_rank(backend::MPIBackend) = MPI.Comm_rank(backend.comm) + +DistributedUtils.total_workers(backend::MPIBackend) = 
MPI.Comm_size(backend.comm) + +function DistributedUtils.bcast!(backend::MPIBackend, sendrecvbuf; root=0) + MPI.Bcast!(sendrecvbuf, backend.comm; root) + return sendrecvbuf +end + +function DistributedUtils.bcast!(backend::MPIBackend, sendbuf, recvbuf; root=0) + MPI.Bcast!(sendbuf, recvbuf, backend.comm; root) + return recvbuf +end + +end diff --git a/src/Lux.jl b/src/Lux.jl index 0874a74fe0..5a6d8968d0 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -24,6 +24,8 @@ const CRC = ChainRulesCore const NAME_TYPE = Union{Nothing, String, Symbol} +@inline __is_extension_loaded(x) = Val(false) + # Utilities include("utils.jl") @@ -50,6 +52,7 @@ include("helpers/stateful.jl") # Distributed Training include("distributed/backend.jl") +include("distributed/public_api.jl") # Deprecations include("deprecated.jl") @@ -76,6 +79,6 @@ export f16, f32, f64 export transform, FluxLayer -export NCCLBackend +export MPIBackend, NCCLBackend, DistributedUtils end diff --git a/src/contrib/compact.jl b/src/contrib/compact.jl index 6523a86db5..586c5547c6 100644 --- a/src/contrib/compact.jl +++ b/src/contrib/compact.jl @@ -52,7 +52,7 @@ Here is a linear model with bias and activation: ```julia d_in = 5 d_out = 7 -d = @compact(W=randn(d_out, d_in), b=zeros(d_out),act=relu) do x +d = @compact(W=randn(d_out, d_in), b=zeros(d_out), act=relu) do x y = W * x return act.(y .+ b) end @@ -72,7 +72,7 @@ n_out = 1 nlayers = 3 model = @compact(w1=Dense(n_in, 128), - w2=[Dense(128, 128) for i in 1:nlayers], w3=Dense(128, n_out),act=relu) do x + w2=[Dense(128, 128) for i in 1:nlayers], w3=Dense(128, n_out), act=relu) do x embed = act(w1(x)) for w in w2 embed = act(w(embed)) diff --git a/src/distributed/backend.jl b/src/distributed/backend.jl index 883d8f8e08..0377f33fac 100644 --- a/src/distributed/backend.jl +++ b/src/distributed/backend.jl @@ -1,10 +1,23 @@ -abstract type AbstractLuxDistributedTrainingBackend end +abstract type AbstractLuxDistributedBackend end -struct NCCLBackend <: AbstractLuxDistributedTrainingBackend - function NCCLBackend() - if Base.get_extension(@__MODULE__, :LuxCUDANCCLExt) === nothing - error("`NCCLBackend` requires `CUDA.jl` and `NCCL.jl` to be loaded") +struct NCCLBackend{C} <: AbstractLuxDistributedBackend + comm::C + + function NCCLBackend(comm=nothing) + if Base.get_extension(@__MODULE__, :LuxCUDAMPINCCLExt) === nothing + error("`NCCLBackend` requires `CUDA.jl`, `MPI.jl` and `NCCL.jl` to be loaded.") + end + return new{typeof(comm)}(comm) + end +end + +struct MPIBackend{C} <: AbstractLuxDistributedBackend + comm::C + + function MPIBackend(comm=nothing) + if Base.get_extension(@__MODULE__, :LuxMPIExt) === nothing + error("`MPIBackend` requires `MPI.jl` to be loaded.") end - return new() + return new{typeof(comm)}(comm) end end diff --git a/src/distributed/public_api.jl b/src/distributed/public_api.jl new file mode 100644 index 0000000000..a1fea196f5 --- /dev/null +++ b/src/distributed/public_api.jl @@ -0,0 +1,108 @@ +module DistributedUtils + +import ChainRulesCore as CRC +import Functors: fmap +import ..Lux: AbstractLuxDistributedBackend, MPIBackend, NCCLBackend +import Optimisers: Leaf +import Setfield: @set! + +const NCCL_Initialized = Ref(false) +const MPI_Initialized = Ref(false) + +""" + initialized(backend::Val) + +Check if the given backend is initialized. 
+""" +initialized(::Val{:MPI}) = MPI_Initialized[] +initialized(::Val{:NCCL}) = NCCL_Initialized[] + +function initialize(backend::Val; cuda_devices=nothing, amdgpu_devices=nothing) + initialized(backend) && return + __initialize(backend; cuda_devices, amdgpu_devices) + return +end + +function __initialize end + +""" + get_distributed_backend(backend::Val) + +Get the distributed backend for the given backend type. Possible values are: + + - `Val(:MPI)`: MPI backend for distributed training. Requires `MPI.jl` to be installed. + - `Val(:NCCL)`: NCCL backend for CUDA distributed training. Requires `CUDA.jl`, + `MPI.jl`, and `NCCL.jl` to be installed. +""" +function get_distributed_backend end + +CRC.@non_differentiable get_distributed_backend(::Any...) + +""" + local_rank(backend::AbstractLuxDistributedBackend) + +Get the local rank for the given backend. +""" +function local_rank end + +CRC.@non_differentiable local_rank(::Any...) + +""" + total_workers(backend::AbstractLuxDistributedBackend) + +Get the total number of workers for the given backend. +""" +function total_workers end + +CRC.@non_differentiable total_workers(::Any...) + +function bcast! end + +CRC.@non_differentiable bcast!(::Any...) + +function allreduce! end + +CRC.@non_differentiable allreduce!(::Any...) + +function reduce! end + +CRC.@non_differentiable reduce!(::Any...) + +# syncronize! +""" + synchronize!!(backend::AbstractLuxDistributedBackend, ps; root::Int=0) + +Synchronize the given structure `ps` using the given backend. The value at `root` will be +broadcasted to all other workers. +""" +function synchronize!!(backend::AbstractLuxDistributedBackend, ps::Tuple; root::Int=0) + length(ps) == 0 && return ps + return map(x -> synchronize!!(backend, x; root), ps) +end + +function synchronize!!(backend::AbstractLuxDistributedBackend, + ps::NamedTuple{fields}; root::Int=0) where {fields} + length(ps) == 0 && return ps + return NamedTuple{fields}(map(x -> synchronize!!(backend, x; root), values(ps))) +end + +function synchronize!!( + backend::AbstractLuxDistributedBackend, ps::AbstractArray{T}; root::Int=0) where {T} + if isbitstype(T) + bcast!(backend, ps; root) + return ps + end + return map(x -> synchronize!!(backend, x; root), ps) +end + +function synchronize!!(backend::AbstractLuxDistributedBackend, ps::Leaf; root::Int=0) + @set! ps.state = synchronize!!(backend, ps.state; root) + return ps +end + +function synchronize!!(backend::AbstractLuxDistributedBackend, ps::T; root::Int=0) where {T} + isbitstype(T) && return bcast!(backend, [ps]; root)[] + return ps # If we don't know how to synchronize, just return the value. For ex, Symbol, String, etc. +end + +end diff --git a/src/utils.jl b/src/utils.jl index 5189d987e3..8ab2af75d9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -216,3 +216,5 @@ end @inline _pairs(x) = pairs(x) @inline __value(x) = x + +function __set_device! end
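
Below is a minimal, hypothetical sketch of how the `DistributedUtils` API introduced in this patch could be used with the MPI backend. The model, RNG, and layer sizes are illustrative and not part of the diff; only the `DistributedUtils` calls are taken from the new `public_api.jl` and `LuxMPIExt.jl` code above.

```julia
using Lux, MPI, Random  # loading MPI.jl activates the LuxMPIExt extension

# One-time setup (idempotent thanks to `initialized`).
DistributedUtils.initialize(Val(:MPI))
backend = DistributedUtils.get_distributed_backend(Val(:MPI))

rank = DistributedUtils.local_rank(backend)          # 0-based rank
nworkers = DistributedUtils.total_workers(backend)

# Hypothetical model; any Lux parameter/state tree works the same way.
model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model)

# Broadcast rank 0's parameters and states so all workers start identically.
ps = DistributedUtils.synchronize!!(backend, ps; root=0)
st = DistributedUtils.synchronize!!(backend, st; root=0)
```

For GPU runs the NCCL path would presumably be used instead: load `CUDA.jl`, `MPI.jl`, and `NCCL.jl` (activating `LuxCUDAMPINCCLExt`) and call `DistributedUtils.initialize(Val(:NCCL); cuda_devices=...)` followed by `DistributedUtils.get_distributed_backend(Val(:NCCL))`; per-rank device assignment goes through `__set_device!`, which requires the `LuxCUDA.jl` (or `LuxAMDGPU.jl`) extension to be loaded as well.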
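
Since `synchronize!!` has a method for `Optimisers.Leaf`, an optimiser state tree can presumably be synchronized the same way as the parameters. A small sketch, assuming `Optimisers.jl` is in the environment and `backend`/`ps` come from the previous snippet:

```julia
using Optimisers

opt_state = Optimisers.setup(Adam(0.001f0), ps)
# Walks the NamedTuple/Tuple tree and broadcasts each Leaf's state from rank 0.
opt_state = DistributedUtils.synchronize!!(backend, opt_state; root=0)
```

Note that `bcast!` is the only collective wired up here; `allreduce!` and `reduce!` are declared in `public_api.jl` but their backend methods are not part of this diff.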