|
| 1 | +# FSDP + Core API for LLM Training |
| 2 | + |
| 3 | +This example shows how to use Fully Sharded Data Parallel [(FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) with Determined and Core API. (Relatively) simple transformer model adapted from [GPT-fast |
| 4 | +](https://github.com/pytorch-labs/gpt-fast) training on fake data. |
| 5 | + |
| 6 | +## Files |
| 7 | +* **fsdp.py**: Training setup and loop, including checkpointing, reporting, and profiling. |
| 8 | +* **model.py**: Model architecture. |
| 9 | +* **config.yaml**: Experiment configuration file. |
| 10 | + |
| 11 | +## Configuration |
| 12 | +Settings can be changed in `config.yaml` `hyperparameters` section. |
| 13 | + |
| 14 | +### Hyperparameters |
| 15 | +* `batch_size`: Per-device batch size. Global batch size will be `batch_size * slots_per_trial`. |
| 16 | +* `lr`: Learning rate. |
| 17 | +* `d_model`, `max_seq_len`, `n_heads`, `n_layers`, `vocab_size`: Model architecture parameters. Check code for more details. |
| 18 | +* `report_rate`: Number of training steps to take between metric reports. |
| 19 | +* `checkpoint_rate`: Number of training steps to take between checkpoint saves. |
| 20 | +* `amp_dtype`: Whether to use torch automatic mixed-precision, and which dtype to use. Options are `'auto'`, `'bfloat16'`, `'float16'`, and `null`. |
| 21 | +* `validation_batches`: Number of batches to use when calculating validation metrics. |
| 22 | +* `core_api_profiler`: Set to true to enable Core API profiler. Results visible in Web UI. |
| 23 | +* `torch_profiler`: Set to true to enable `torch` profiler. Results visible in Tensorboard, which can be launched through the Web UI. |
| 24 | + |
| 25 | +### Other configuration |
| 26 | +Users might want to change `resources.slots_per_trial`, `workspace`, `project`, and `searcher.max_length` in `config.yaml`. |
| 27 | + |
| 28 | +## Data |
| 29 | +This example uses a synthetically generated random dataset for simplicity. |
| 30 | + |
| 31 | +## To Run |
| 32 | +If you have not yet installed Determined, installation instructions can be found at https://docs.determined.ai/latest/index.html |
| 33 | + |
| 34 | +Change any desired configuration variables as outlined in the **Configuration** section, then run the following command: `det -m <master-host:port> experiment create |
| 35 | +config.yaml .`. |
| 36 | + |
| 37 | + |
| 38 | +## Results |
| 39 | +Training loss should decrease from ~10.5 to ~8.5 with default settings run for 100 steps, while validation loss remains constant. This is due to validation data being a separate random dataset. |
0 commit comments