Commit d5f4b27
Add minimal FSDP example (#23)
1 parent 6caaab1 commit d5f4b27

File tree

5 files changed: +623 −0 lines changed

README.md

+6 lines

@@ -88,3 +88,9 @@ This repository contains a variety of Determined examples that are not actively

| Example | Dataset | Framework |
|:------------------------------------------------------------------------:|:-------:|:----------:|
| [asha\_search\_method](custom_search_method/asha_search_method) | MNIST | PyTorch |

## Fully Sharded Data Parallel

| Example | Framework |
|:------------------------------------------------------------------------:|:----------:|
| [minimal\_fsdp](fsdp/minimal_fsdp) | PyTorch |

fsdp/minimal_fsdp/README.md

+39 lines

@@ -0,0 +1,39 @@

# FSDP + Core API for LLM Training

This example shows how to use Fully Sharded Data Parallel ([FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)) with Determined and the Core API. It trains a (relatively) simple transformer model, adapted from [GPT-fast](https://github.com/pytorch-labs/gpt-fast), on fake data.
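For orientation, here is a minimal sketch of the FSDP wrapping pattern this kind of script builds on; the stand-in model and tensor shapes are illustrative, not the example's actual code.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes a launch under determined.launch.torch_distributed (or torchrun),
# which sets RANK, LOCAL_RANK, and WORLD_SIZE for each worker.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Stand-in model; the actual example uses a GPT-fast-style transformer.
model = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512)).cuda()
model = FSDP(model)  # shards parameters and gradients across ranks

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 512, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()   # gradients are reduce-scattered across ranks
optimizer.step()
```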
5+
6+
## Files
7+
* **fsdp.py**: Training setup and loop, including checkpointing, reporting, and profiling.
8+
* **model.py**: Model architecture.
9+
* **config.yaml**: Experiment configuration file.
## Configuration

Settings can be changed in the `hyperparameters` section of `config.yaml`; a sketch of reading them in code follows the list below.
### Hyperparameters

* `batch_size`: Per-device batch size. The global batch size will be `batch_size * slots_per_trial`.
* `lr`: Learning rate.
* `d_model`, `max_seq_len`, `n_heads`, `n_layers`, `vocab_size`: Model architecture parameters. Check the code for more details.
* `report_rate`: Number of training steps to take between metric reports.
* `checkpoint_rate`: Number of training steps to take between checkpoint saves.
* `amp_dtype`: Whether to use torch automatic mixed precision, and which dtype to use. Options are `'auto'`, `'bfloat16'`, `'float16'`, and `null`.
* `validation_batches`: Number of batches to use when calculating validation metrics.
* `core_api_profiler`: Set to true to enable the Core API profiler. Results are visible in the Web UI.
* `torch_profiler`: Set to true to enable the `torch` profiler. Results are visible in TensorBoard, which can be launched through the Web UI.
### Other configuration

Users might want to change `resources.slots_per_trial`, `workspace`, `project`, and `searcher.max_length` in `config.yaml`.
## Data

This example uses a synthetically generated random dataset for simplicity.
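As an illustration (not the repository's exact code), random token batches for language-model training can be generated like this, using the `vocab_size` and `max_seq_len` hyperparameters above:

```python
import torch

def fake_batch(batch_size: int, max_seq_len: int, vocab_size: int, device: str = "cpu"):
    # Random token ids; inputs and next-token targets are shifted by one.
    tokens = torch.randint(0, vocab_size, (batch_size, max_seq_len + 1), device=device)
    return tokens[:, :-1], tokens[:, 1:]

inputs, targets = fake_batch(batch_size=1, max_seq_len=2048, vocab_size=32000)
```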
## To Run

If you have not yet installed Determined, installation instructions can be found at https://docs.determined.ai/latest/index.html

Change any desired configuration variables as outlined in the **Configuration** section, then run the following command: `det -m <master-host:port> experiment create config.yaml .`
## Results

With default settings run for 100 steps, training loss should decrease from ~10.5 to ~8.5, while validation loss remains constant; this is because the validation data is a separate random dataset.

fsdp/minimal_fsdp/config.yaml

+26 lines

@@ -0,0 +1,26 @@

name: fsdp example
entrypoint: python3 -m determined.launch.torch_distributed -- python3 fsdp.py
searcher:
  name: single
  metric: loss
  max_length: 100
resources:
  slots_per_trial: 2
environment:
  image:
    gpu: determinedai/environments:cuda-11.8-pytorch-2.0-gpu-mpi-0.31.1
hyperparameters:
  batch_size: 1
  lr: 1e-4
  d_model: 512
  max_seq_len: 2048
  n_heads: 8
  n_layers: 4
  vocab_size: 32000
  report_rate: 10
  checkpoint_rate: 50
  amp_dtype: float16
  validation_batches: 10
  core_api_profiler: false
  torch_profiler: false
max_restarts: 0
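The config above sets `amp_dtype: float16`. A sketch of what such a setting typically implies in PyTorch mixed-precision training; this is an illustration of the standard autocast/GradScaler pattern, not the example's exact code.

```python
import torch
import torch.nn as nn

amp_dtype = torch.float16  # hypothetical translation of `amp_dtype: float16`
model = nn.Linear(512, 512).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# GradScaler is needed for float16 (not bfloat16) to avoid gradient underflow.
scaler = torch.cuda.amp.GradScaler(enabled=amp_dtype is torch.float16)

x = torch.randn(8, 512, device="cuda")
with torch.autocast(device_type="cuda", dtype=amp_dtype):
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```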
