Full DPO Distributed #2275
Merged 39 commits on Feb 7, 2025
Commits
b96255b
full dpo configs, distributed recipe, and integration tests
sam-pi Jan 17, 2025
761b718
disable dropout, ref model setup, minor doc update
sam-pi Jan 23, 2025
753e822
Merge remote-tracking branch 'upstream/main' into HEAD
SalmanMohammadi Jan 28, 2025
0f90093
updating full recipe
SalmanMohammadi Jan 28, 2025
ebed89c
updating recipe
SalmanMohammadi Jan 30, 2025
aff595f
removing 70B full dpo config until multi-node support is available
sam-pi Jan 30, 2025
431f269
minor update to avoid _ref_model self reference
sam-pi Jan 30, 2025
c63e9e8
clean up rank zero logs and ref_checkpointer
sam-pi Jan 30, 2025
ebf288a
remove unnecessary save/load test and update to 2 GPUs
sam-pi Jan 30, 2025
ba12bb4
fix: Metrics weren't running and synced across devices
bogdansalyp Jan 30, 2025
6139096
fix: Fixed tokens_per_second_per_gpu
bogdansalyp Jan 31, 2025
2a4ca92
fix: Fixed torch.distributed naming
bogdansalyp Jan 31, 2025
7f94b07
fix: tokens_per_second_per_gpu fixed for full dpo
bogdansalyp Jan 31, 2025
1a673df
fix: Added running metrics to full_dpo_distributed
bogdansalyp Jan 31, 2025
16821c4
Merge pull request #2 from bogdansalyp/fix/running_metrics_and_sync_l…
sam-pi Feb 1, 2025
d052271
fix: num_tokens all_reduce crash in DPO recipes
bogdansalyp Feb 3, 2025
f9fedc4
Merge pull request #3 from bogdansalyp/fix/num_tokens-tensor-issue
sam-pi Feb 3, 2025
a0ac5aa
delete ref logits and improved default full dpo config
sam-pi Feb 3, 2025
5af4bed
remove 70B full DPO for now
sam-pi Feb 4, 2025
7dd27bb
Update recipes/full_dpo_distributed.py
sam-pi Feb 4, 2025
86e7639
Update recipes/full_dpo_distributed.py
sam-pi Feb 4, 2025
fb228c6
minor docs update and config comment
sam-pi Feb 4, 2025
85ce53d
fix linting
sam-pi Feb 5, 2025
ad7de59
explicitly specify 2 GPUs for full DPO test
sam-pi Feb 6, 2025
4e67853
reduce dpo test VRAM usage
sam-pi Feb 6, 2025
4667adf
Update recipes/configs/llama3_1/8B_full_dpo.yaml
SalmanMohammadi Feb 6, 2025
f2e9f47
Update docs/source/recipes/dpo.rst
SalmanMohammadi Feb 6, 2025
99a87e8
use get_lr, remove ac_mode, add clip_grad_norm
sam-pi Feb 6, 2025
8e46c93
average running loss across ranks
sam-pi Feb 6, 2025
716efca
adding opt in bwd
SalmanMohammadi Feb 6, 2025
eef1b01
fixing test
SalmanMohammadi Feb 6, 2025
46c59ec
removing grad clip
SalmanMohammadi Feb 6, 2025
6781894
Updating test
SalmanMohammadi Feb 6, 2025
4fd18bf
fixing test... round 2
SalmanMohammadi Feb 7, 2025
9c71083
fixing typo
SalmanMohammadi Feb 7, 2025
b293ce3
updating distributed test to correctly resume from checkpoint
SalmanMohammadi Feb 7, 2025
893398e
updating tests and recipe
SalmanMohammadi Feb 7, 2025
abbdf11
revert optimizer_in_bwd update for now
sam-pi Feb 7, 2025
d463e70
activations handling only for policy forward pass
sam-pi Feb 7, 2025
9 changes: 6 additions & 3 deletions docs/source/recipes/dpo.rst
@@ -13,7 +13,7 @@ To see the best results when using this recipe, it may be helpful to first fine-
 on-distribution for the domain you're interested in. To do this, check out our other fine-tuning recipes in the :ref:`recipe overview <recipes_overview_label>` which
 support a variety of SFT paradigms.

-After supervised fine-tuning, here is an example of DPO with Llama 3.1 8B:
+After supervised fine-tuning, here is an example of using either LoRA-based finetuning, or full-finetuning Llama 3.1 8B with DPO:

 .. note::

@@ -27,12 +27,15 @@ After supervised fine-tuning, here is an example of DPO with Llama 3.1 8B:
     --ignore-patterns "original/consolidated.00.pth"
     --HF_TOKEN <HF_TOKEN>

-  # run on a single device
+  # run lora dpo on a single device
   tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device

-  # run on two gpus
+  # run lora dpo on two gpus
   tune run --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo

+  # run full dpo on four gpus
+  tune run --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo
+
 It's easy to get started with this recipe with your dataset of choice, including custom local datasets,
 and datasets from Hugging Face. Check out our primer on :ref:`preference datasets <preference_dataset_usage_label>` to
 see how to do this.
98 changes: 98 additions & 0 deletions recipes/configs/llama3_1/8B_full_dpo.yaml
@@ -0,0 +1,98 @@
# Config for multi-device full DPO alignment in full_dpo_distributed.py
# using a Llama3.1 8B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#

output_dir: /tmp/torchtune/llama3_1_8B/full_dpo # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 1024 # higher increases memory

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# The ref_checkpointer should always point to the original weights.
ref_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.05
  lr: 2e-5
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 20

loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.05
  label_smoothing: 0

# Training
epochs: 1
max_steps_per_epoch: 1000
gradient_accumulation_steps: 8 # Use to increase effective batch size
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
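
The loss section above configures torchtune.rlhf.loss.DPOLoss with beta: 0.05 and label_smoothing: 0. As a minimal sketch of the standard DPO objective those two hyperparameters control (an illustration of the math, not torchtune's implementation), the computation looks roughly like this:

import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,    # log-probs of chosen responses under the policy
    policy_rejected_logps: torch.Tensor,  # log-probs of rejected responses under the policy
    ref_chosen_logps: torch.Tensor,       # same quantities under the frozen reference model
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.05,
    label_smoothing: float = 0.0,
) -> torch.Tensor:
    """Sketch of the standard DPO objective (Rafailov et al., 2023).

    beta scales the implicit reward margin; label_smoothing > 0 gives the
    "conservative" variant that treats a fraction of preference labels as noisy.
    """
    # Log-ratio of policy to reference for chosen and rejected completions.
    chosen_rewards = policy_chosen_logps - ref_chosen_logps
    rejected_rewards = policy_rejected_logps - ref_rejected_logps
    logits = beta * (chosen_rewards - rejected_rewards)
    # With label_smoothing = 0 this reduces to -logsigmoid(logits).
    loss = (
        -F.logsigmoid(logits) * (1 - label_smoothing)
        - F.logsigmoid(-logits) * label_smoothing
    )
    return loss.mean()

The reference log-probabilities in this sketch come from the frozen reference model, which is why the config's ref_checkpointer is documented to always point at the original (pre-DPO) weights.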