Inbuilt-Distributed Setup #500

Merged · 14 commits · Apr 7, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -31,3 +31,4 @@ docs/src/tutorials/advanced
*.log

bench/benchmark_results.json
*.cov
20 changes: 18 additions & 2 deletions Project.toml
@@ -20,17 +20,22 @@ LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
@@ -43,6 +48,9 @@ LuxComponentArraysExt = "ComponentArrays"
LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"]
LuxFluxExt = "Flux"
LuxLuxAMDGPUExt = "LuxAMDGPU"
LuxMLUtilsExt = "MLUtils"
LuxMPIExt = "MPI"
LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
LuxOptimisersExt = "Optimisers"
LuxReverseDiffExt = "ReverseDiff"
LuxSimpleChainsExt = "SimpleChains"
@@ -54,6 +62,7 @@ ADTypes = "0.2"
Adapt = "4"
Aqua = "0.8.4"
ArrayInterface = "7.8"
CUDA = "5.2"
ChainRules = "1.62"
ChainRulesCore = "1.21"
ComponentArrays = "0.15.11"
@@ -69,14 +78,18 @@ Logging = "1.10"
LuxAMDGPU = "0.2.2"
LuxCUDA = "0.3.2"
LuxCore = "0.1.12"
LuxDeviceUtils = "0.1.16"
LuxDeviceUtils = "0.1.19"
LuxLib = "0.3.11"
LuxTestUtils = "0.1.15"
MLUtils = "0.4.3"
MPI = "0.20.19"
MacroTools = "0.5.13"
Markdown = "1.10"
NCCL = "0.1.1"
Optimisers = "0.3"
Pkg = "1.10"
PrecompileTools = "1.2"
Preferences = "1.4.3"
Random = "1.10"
ReTestItems = "1.23.1"
Reexport = "1"
@@ -107,6 +120,9 @@ LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -122,4 +138,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "Adapt", "Aqua", "ChainRules", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "Flux", "Functors", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxCore", "LuxLib", "LuxTestUtils", "Optimisers", "Random", "ReTestItems", "Reexport", "ReverseDiff", "Setfield", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
test = ["ADTypes", "Adapt", "Aqua", "ChainRules", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "Flux", "Functors", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxCore", "LuxLib", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "Random", "ReTestItems", "Reexport", "ReverseDiff", "Setfield", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -14,5 +14,6 @@ LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6 changes: 4 additions & 2 deletions docs/make.jl
@@ -38,14 +38,16 @@ pages = [
"manual/freezing_model_parameters.md",
"manual/gpu_management.md",
"manual/migrate_from_flux.md",
"manual/weight_initializers.md"
"manual/weight_initializers.md",
"manual/distributed_utils.md"
],
"API Reference" => [
"Lux" => [
"api/Lux/layers.md",
"api/Lux/utilities.md",
"api/Lux/contrib.md",
"api/Lux/switching_frameworks.md"
"api/Lux/switching_frameworks.md",
"api/Lux/distributed_utils.md",
],
"Accelerator Support" => [
"api/Accelerator_Support/LuxAMDGPU.md",
9 changes: 6 additions & 3 deletions docs/src/.vitepress/config.mts
@@ -73,7 +73,8 @@ export default defineConfig({
{ text: 'Built-In Layers', link: '/api/Lux/layers' },
{ text: 'Utilities', link: '/api/Lux/utilities' },
{ text: 'Experimental', link: '/api/Lux/contrib' },
{ text: 'InterOp', link: '/api/Lux/switching_frameworks' }
{ text: 'InterOp', link: '/api/Lux/switching_frameworks' },
{ text: 'DistributedUtils', link: '/api/Lux/distributed_utils' }
]
},
{
@@ -146,7 +147,8 @@ export default defineConfig({
{ text: 'Freezing Model Parameters', link: '/manual/freezing_model_parameters' },
{ text: 'GPU Management', link: '/manual/gpu_management' },
{ text: 'Migrating from Flux to Lux', link: '/manual/migrate_from_flux' },
{ text: 'Initializing Weights', link: '/manual/weight_initializers' }]
{ text: 'Initializing Weights', link: '/manual/weight_initializers' },
{ text: 'Distributed Data Parallel Training', link: '/manual/distributed_utils' },]
},
"/api/": {
text: 'API Reference', collapsed: false, items: [
@@ -155,7 +157,8 @@ export default defineConfig({
{ text: 'Built-In Layers', link: '/api/Lux/layers' },
{ text: 'Utilities', link: '/api/Lux/utilities' },
{ text: 'Experimental Features', link: '/api/Lux/contrib' },
{ text: 'Switching between Deep Learning Frameworks', link: '/api/Lux/switching_frameworks' }]
{ text: 'Switching between Deep Learning Frameworks', link: '/api/Lux/switching_frameworks' },
{ text: 'DistributedUtils', link: '/api/Lux/distributed_utils' }]
},
{
text: 'Accelerator Support', collapsed: false, items: [
58 changes: 58 additions & 0 deletions docs/src/api/Lux/distributed_utils.md
@@ -0,0 +1,58 @@
# Distributed Utils

!!! note

These functionalities are available via the `Lux.DistributedUtils` module.

```@meta
CurrentModule = Lux
```

## Index

```@index
Pages = ["distributed_utils.md"]
```

## [Backends](@id communication-backends)

```@docs
MPIBackend
NCCLBackend
```

## Initialization

```@docs
DistributedUtils.initialize
DistributedUtils.initialized
DistributedUtils.get_distributed_backend
```

## Helper Functions

```@docs
DistributedUtils.local_rank
DistributedUtils.total_workers
```

## Communication Primitives

```@docs
DistributedUtils.allreduce!
DistributedUtils.bcast!
DistributedUtils.reduce!
DistributedUtils.synchronize!!
```

## Optimisers.jl Integration

```@docs
DistributedUtils.DistributedOptimizer
```

## MLUtils.jl Integration

```@docs
DistributedUtils.DistributedDataContainer
```
3 changes: 2 additions & 1 deletion docs/src/introduction/index.md
@@ -75,6 +75,7 @@ st_opt, ps = Optimisers.update(st_opt, ps, gs)
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, LuxAMDGPU, Metal # Optional packages for GPU support
import Lux.Experimental: @compact
using Printf # For pretty printing
```

We will define a custom MLP using the `@compact` macro. The macro takes in a list of
@@ -117,7 +118,7 @@ for epoch in 1:1000
return sum(abs2, y .- y_data), st_
end
gs = only(pb((one(loss), nothing)))
epoch % 100 == 1 && println("Epoch: $(epoch) | Loss: $(loss)")
epoch % 100 == 1 && @printf "Epoch: %04d \t Loss: %10.9g\n" epoch loss
Optimisers.update!(st_opt, ps, gs)
end
```
111 changes: 111 additions & 0 deletions docs/src/manual/distributed_utils.md
@@ -0,0 +1,111 @@
# Distributed Data Parallel Training

!!! tip

For a fully functional example, see the
[ImageNet Training Example](https://github.com/LuxDL/Lux.jl/tree/main/examples/ImageNet).

DDP Training using `Lux.DistributedUtils` is a spiritual successor to
[FluxMPI.jl](https://github.com/avik-pal/FluxMPI.jl), but has some key differences.

## Guide to Integrating DistributedUtils into your code

1. Initialize the respective backend with [`DistributedUtils.initialize`](@ref) by passing
in a backend type. It is important that you pass in the type, i.e. `NCCLBackend`, and not
an instance like `NCCLBackend()`.

```julia
DistributedUtils.initialize(NCCLBackend)
```

2. Obtain the backend via [`DistributedUtils.get_distributed_backend`](@ref) by passing in
the type of the backend (the same note from the previous point applies here as well).

```julia
backend = DistributedUtils.get_distributed_backend(NCCLBackend)
```

It is important that you use this function instead of directly constructing the backend,
since there are certain internal states that need to be synchronized.

3. Next, synchronize the parameters and states of the model. This is done by calling
[`DistributedUtils.synchronize!!`](@ref) with the backend and the respective input.

```julia
ps = DistributedUtils.synchronize!!(backend, ps)
st = DistributedUtils.synchronize!!(backend, st)
```

4. To split the data uniformly across the processes, use
[`DistributedUtils.DistributedDataContainer`](@ref). Alternatively, one can manually
split the data. For the provided container to work,
[`MLUtils.jl`](https://github.com/JuliaML/MLUtils.jl) must be installed and loaded.

```julia
data = DistributedUtils.DistributedDataContainer(backend, data)
```

5. Wrap the optimizer in [`DistributedUtils.DistributedOptimizer`](@ref) to ensure that the
optimizer is correctly synchronized across all processes before parameter updates. After
initializing the state of the optimizer, synchronize the state across all processes.

```julia
opt = DistributedUtils.DistributedOptimizer(backend, opt)
opt_state = Optimisers.setup(opt, ps)
opt_state = DistributedUtils.synchronize!!(backend, opt_state)
```

6. Finally, change all logging and serialization code to trigger only when
`local_rank(backend) == 0`. This ensures that only the master process logs and serializes
the model, as sketched below.
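
A minimal sketch of such rank-gated logging and checkpointing (the use of `JLD2.jldsave`
and the checkpoint file name are illustrative assumptions, not part of the
`DistributedUtils` API):

```julia
using JLD2  # assumed here purely for checkpoint serialization

if DistributedUtils.local_rank(backend) == 0
    @info "Finished epoch" epoch loss
    # Hypothetical checkpoint file name; adapt to your own setup
    jldsave("checkpoint.jld2"; ps, st, opt_state)
end
```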

## [GPU-Aware MPI](@id gpu-aware-mpi)

If you are using a custom MPI build that supports CUDA or ROCM, you can use the following
preferences with [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl):

1. `LuxDistributedMPICUDAAware` - Set this to `true` if your MPI build is CUDA aware.
2. `LuxDistributedMPIROCMAware` - Set this to `true` if your MPI build is ROCM aware.

By default, both of these values are set to `false`.
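
For example, a minimal sketch of enabling the CUDA-aware flag from the REPL via
Preferences.jl. The UUID below is Lux's package UUID from its `Project.toml`; restarting
Julia afterwards is typically required for the preference to take effect:

```julia
using Preferences, UUIDs

lux_uuid = UUID("b2108857-7c20-44ae-9111-449ecde12c47")  # Lux.jl's UUID
set_preferences!(lux_uuid, "LuxDistributedMPICUDAAware" => true)
```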

## Migration Guide from `FluxMPI.jl`

Let's compare the changes we need to make relative to the
[FluxMPI.jl integration guide](https://avik-pal.github.io/FluxMPI.jl/dev/guide/).

1. `FluxMPI.Init` is now [`DistributedUtils.initialize`](@ref).
2. `FluxMPI.synchronize!(x)` needs to be changed to
`x_new = DistributedUtils.synchronize!!(backend, x)`.
3. [`DistributedUtils.DistributedDataContainer`](@ref),
[`DistributedUtils.local_rank`](@ref), and
[`DistributedUtils.DistributedOptimizer`](@ref) need `backend` as the first input.

And that's pretty much it!
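
As a rough before/after sketch (here `ps` is a placeholder for the model's parameters, and
the FluxMPI lines are shown only as comments for comparison):

```julia
# Before, with FluxMPI.jl:
#   FluxMPI.Init()
#   ps = FluxMPI.synchronize!(ps)

# After, with DistributedUtils:
DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)
ps = DistributedUtils.synchronize!!(backend, ps)
```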

### Removed Functionality

1. `FluxMPI.allreduce_gradients` no longer exists. Previously, this was needed when CUDA
communication was flaky; with `NCCL.jl` this is no longer the case.
2. `FluxMPIFluxModel` has been removed. `DistributedUtils` no longer works with `Flux`.

### Key Differences

1. `FluxMPI.synchronize!` is now `DistributedUtils.synchronize!!` to highlight the fact
that some of the inputs are not updated in-place.
2. All of the functions now require a [communication backend](@ref communication-backends)
as input.
3. We don't automatically determine whether the MPI implementation is CUDA- or ROCM-aware.
See [GPU-aware MPI](@ref gpu-aware-mpi) for more information.
4. Older [`Lux.gpu`](@ref) implementations used to "just work" with `FluxMPI.jl`. We expect
[`gpu_device`](@ref) to continue working as expected; however, we recommend calling
[`gpu_device`](@ref) after [`DistributedUtils.initialize`](@ref) to avoid any
mismatch between the device set via `DistributedUtils` and the device stored in
`LuxCUDADevice` or `LuxAMDGPUDevice`, as shown in the sketch below.
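
For instance, a minimal sketch of the recommended ordering on a CUDA node (assuming
`LuxCUDA` is installed and loaded):

```julia
using Lux, LuxCUDA

DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)

# Query the device only after initialization so it matches the one
# assigned to this process by DistributedUtils
device = gpu_device()
```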

## Known Shortcomings

1. Currently, we don't run tests with CUDA- or ROCM-aware MPI, so use those features at your
own risk. We are working on adding tests for these features.
2. AMDGPU support is mostly experimental and causes deadlocks in certain situations; this is
being investigated. If you have a minimal reproducer, please open an issue.
25 changes: 14 additions & 11 deletions examples/ImageNet/Project.toml
@@ -4,22 +4,23 @@ Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
JLSO = "9da8a3cd-07a3-59c0-a743-3fdc52c30d11"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JpegTurbo = "b835a17e-a41a-41e7-81f0-2f016b05efe0"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleConfig = "f2d95530-262a-480f-aff0-1c0431e662a7"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -30,21 +31,23 @@ Augmentor = "0.6"
Boltz = "0.1, 0.2, 0.3"
Configurations = "0.17"
FLoops = "0.2"
Flux = "0.14"
FluxMPI = "0.6, 0.7"
Formatting = "0.4"
FileIO = "1.16"
Format = "1.3"
Functors = "0.2, 0.3, 0.4"
Images = "0.24, 0.25, 0.26"
JLSO = "2"
Images = "0.26"
JLD2 = "0.4.46"
JpegTurbo = "0.1"
Lux = "0.4, 0.5"
LuxAMDGPU = "0.1, 0.2"
LuxCUDA = "0.2, 0.3"
MLUtils = "0.2.10, 0.3, 0.4"
MPI = "0.20.19"
Metalhead = "0.9"
NCCL = "0.1.1"
OneHotArrays = "0.1, 0.2"
Optimisers = "0.2, 0.3"
Setfield = "0.8.2, 1"
ParameterSchedulers = "0.4"
Setfield = "1"
SimpleConfig = "0.1"
Statistics = "1"
Zygote = "0.6"