Skip to content

Commit

Permalink
Add documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 6, 2024
1 parent 6293e2a commit 5bd71be
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 10 deletions.
6 changes: 4 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ pages = [
"manual/freezing_model_parameters.md",
"manual/gpu_management.md",
"manual/migrate_from_flux.md",
"manual/weight_initializers.md"
"manual/weight_initializers.md",
"manual/distributed_utils.md"
],
"API Reference" => [
"Lux" => [
"api/Lux/layers.md",
"api/Lux/utilities.md",
"api/Lux/contrib.md",
"api/Lux/switching_frameworks.md"
"api/Lux/switching_frameworks.md",
"api/Lux/distributed_utils.md",
],
"Accelerator Support" => [
"api/Accelerator_Support/LuxAMDGPU.md",
Expand Down
9 changes: 6 additions & 3 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ export default defineConfig({
{ text: 'Built-In Layers', link: '/api/Lux/layers' },
{ text: 'Utilities', link: '/api/Lux/utilities' },
{ text: 'Experimental', link: '/api/Lux/contrib' },
{ text: 'InterOp', link: '/api/Lux/switching_frameworks' }
{ text: 'InterOp', link: '/api/Lux/switching_frameworks' },
{ text: 'DistributedUtils', link: '/api/Lux/distributed_utils' }
]
},
{
Expand Down Expand Up @@ -146,7 +147,8 @@ export default defineConfig({
{ text: 'Freezing Model Parameters', link: '/manual/freezing_model_parameters' },
{ text: 'GPU Management', link: '/manual/gpu_management' },
{ text: 'Migrating from Flux to Lux', link: '/manual/migrate_from_flux' },
{ text: 'Initializing Weights', link: '/manual/weight_initializers' }]
{ text: 'Initializing Weights', link: '/manual/weight_initializers' },
{ text: 'Distributed Data Parallel Training', link: '/manual/distributed_utils' },]
},
"/api/": {
text: 'API Reference', collapsed: false, items: [
Expand All @@ -155,7 +157,8 @@ export default defineConfig({
{ text: 'Built-In Layers', link: '/api/Lux/layers' },
{ text: 'Utilities', link: '/api/Lux/utilities' },
{ text: 'Experimental Features', link: '/api/Lux/contrib' },
{ text: 'Switching between Deep Learning Frameworks', link: '/api/Lux/switching_frameworks' }]
{ text: 'Switching between Deep Learning Frameworks', link: '/api/Lux/switching_frameworks' },
{ text: 'DistributedUtils', link: '/api/Lux/distributed_utils' }]
},
{
text: 'Accelerator Support', collapsed: false, items: [
Expand Down
58 changes: 58 additions & 0 deletions docs/src/api/Lux/distributed_utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Distributed Utils

!!! note

These functionalities are available via the `Lux.DistributedUtils` module.

```@meta
CurrentModule = Lux
```

## Index

```@index
Pages = ["distributed_utils.md"]
```

## Backends

```@docs
MPIBackend
NCCLBackend
```

## Initialization

```@docs
DistributedUtils.initialize
DistributedUtils.initialized
DistributedUtils.get_distributed_backend
```

## Helper Functions

```@docs
DistributedUtils.local_rank
DistributedUtils.total_workers
```

## Communication Primitives

```@docs
DistributedUtils.allreduce!
DistributedUtils.bcast!
DistributedUtils.reduce!
DistributedUtils.synchronize!!
```

## Optimizers.jl Integration

```@docs
DistributedUtils.DistributedOptimizer
```

## MLUtils.jl Integration

```@docs
DistributedUtils.DistributedDataLoader
```
15 changes: 15 additions & 0 deletions docs/src/manual/distributed_utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Distributed Data Parallel Training

!!! tip

For a fully functional example, see the
[ImageNet Training Example](https://github.com/LuxDL/Lux.jl/tree/main/examples/ImageNet)

DDP Training using `Lux.DistributedUtils` is a spiritual successor to
[FluxMPI.jl](https://github.com/avik-pal/FluxMPI.jl), but has some key differences.

## Backends Supported

## Guide to Integrating DistributedUtils into your code

## Main Differences from `FluxMPI.jl`
6 changes: 3 additions & 3 deletions ext/LuxMPINCCLExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ using Setfield: @set!

function DistributedUtils.__initialize(
::Type{NCCLBackend}; cuda_devices=nothing, amdgpu_devices=missing)
DistributedUtils.NCCL_Initialized[] = true
@assert amdgpu_devices===missing "`AMDGPU` is not supported by `NCCL`."
DistributedUtils.__initialize(Val(:MPI); cuda_devices, amdgpu_devices)
DistributedUtils.__initialize(MPIBackend; cuda_devices, amdgpu_devices)
DistributedUtils.NCCL_Initialized[] = true
return
end

function DistributedUtils.__get_distributed_backend(::Type{NCCLBackend})
unique_id = NCCL.UniqueID() # Generate on all ranks to know the type
mpi_backend = DistributedUtils.__get_distributed_backend(Val(:MPI))
mpi_backend = DistributedUtils.__get_distributed_backend(MPIBackend)
buf = [unique_id.internal...]
DistributedUtils.bcast!(mpi_backend, buf; root=0)
@set! unique_id.internal = Tuple(buf)
Expand Down
4 changes: 2 additions & 2 deletions src/distributed/backend.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ abstract type AbstractLuxDistributedBackend end
MPIBackend(comm = nothing)
Create an MPI backend for distributed training. Users should not use this function directly.
Instead use [`DistributedUtils.get_distributed_backend(Val(:NCCL))`](@ref).
Instead use [`DistributedUtils.get_distributed_backend(MPIBackend)`](@ref).
"""
struct MPIBackend{C} <: AbstractLuxDistributedBackend
comm::C
Expand All @@ -21,7 +21,7 @@ end
NCCLBackend(comm = nothing, mpi_backend = nothing)
Create an NCCL backend for distributed training. Users should not use this function
directly. Instead use [`DistributedUtils.get_distributed_backend(Val(:NCCL))`](@ref).
directly. Instead use [`DistributedUtils.get_distributed_backend(NCCLBackend)`](@ref).
"""
struct NCCLBackend{C, M <: Union{Nothing, MPIBackend}} <: AbstractLuxDistributedBackend
comm::C
Expand Down
8 changes: 8 additions & 0 deletions src/distributed/public_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ end
`data` must be compatible with `MLUtils` interface. The returned container is compatible
with `MLUtils` interface and is used to partition the dataset across the available
processes.
!!! danger
`MLUtils.jl` must be installed and loaded before using this.
"""
@concrete struct DistributedDataContainer
data
Expand Down Expand Up @@ -250,6 +254,10 @@ averages the gradients across the processes using Allreduce.
## Arguments
- `optimizer`: An Optimizer compatible with the Optimisers.jl package
!!! danger
`Optimisers.jl` must be installed and loaded before using this.
"""
function DistributedOptimizer(backend::AbstractLuxDistributedBackend, opt)
mod = Base.get_extension(@__MODULE__, :LuxOptimisersExt)
Expand Down

0 comments on commit 5bd71be

Please sign in to comment.