Clean configs documentation #1944

Merged
81 commits merged into main from clean-config on Sep 4, 2024

Commits (81)
c2d9a62
Clean BCO
qgallouedec Aug 18, 2024
e3083f1
Optional[int]
qgallouedec Aug 18, 2024
c7b2fbc
fix sft config
qgallouedec Aug 19, 2024
e7a80bb
Merge branch 'main' into clean-config
qgallouedec Aug 19, 2024
50dbc86
alignprop config
qgallouedec Aug 20, 2024
b718fba
Merge branch 'main' into clean-config
qgallouedec Aug 20, 2024
4a8aba6
update tempfile to work with output_dir
qgallouedec Aug 20, 2024
6ae94e9
Merge branch 'clean-config' of https://github.com/huggingface/trl int…
qgallouedec Aug 20, 2024
3ed49fd
Merge branch 'main' into clean-config
qgallouedec Aug 21, 2024
f847f56
clean kto config
qgallouedec Aug 21, 2024
69525f9
intro docstring
qgallouedec Aug 21, 2024
c73f43a
style
qgallouedec Aug 21, 2024
11f6e7e
reward config
qgallouedec Aug 22, 2024
946e2e5
orpo config
qgallouedec Aug 22, 2024
21df122
Merge branch 'main' into clean-config
qgallouedec Aug 26, 2024
a1bff9c
warning in trainer, not in config
qgallouedec Aug 26, 2024
006a454
cpo config
qgallouedec Aug 26, 2024
c9264ee
Merge branch 'main' into clean-config
qgallouedec Aug 27, 2024
01d8814
ppo v2
qgallouedec Aug 27, 2024
5cd9eef
Merge branch 'clean-config' of https://github.com/huggingface/trl int…
qgallouedec Aug 27, 2024
9bef508
model config
qgallouedec Aug 27, 2024
0a49bca
ddpo and per_device_train_batch_size (instead of train_batch_size)
qgallouedec Aug 27, 2024
1c9bba7
Merge branch 'main' into clean-config
qgallouedec Aug 27, 2024
216856a
rloo
qgallouedec Aug 27, 2024
7270936
Online config
qgallouedec Aug 27, 2024
05bacaf
tmp_dir in test_ddpo
qgallouedec Aug 27, 2024
451b4fc
style
qgallouedec Aug 27, 2024
9e6f0a0
remove to_dict and fix post-init
qgallouedec Aug 28, 2024
2aa4544
batch size in test ddpo
qgallouedec Aug 28, 2024
97738c8
Merge branch 'main' into clean-config
qgallouedec Aug 28, 2024
098ca6a
Merge branch 'main' into clean-config
qgallouedec Aug 28, 2024
02b78ec
dpo
qgallouedec Aug 28, 2024
92ff078
style
qgallouedec Aug 28, 2024
63679fe
Merge branch 'main' into clean-config
qgallouedec Aug 29, 2024
4957a8c
`Args` -> `Parameters`
qgallouedec Aug 29, 2024
bd3693b
parameters
qgallouedec Aug 29, 2024
10468e9
ppo config
qgallouedec Aug 29, 2024
d289982
dont overwrite world size
qgallouedec Aug 29, 2024
d94985a
style
qgallouedec Aug 29, 2024
1bc063a
Merge branch 'main' into clean-config
qgallouedec Aug 29, 2024
00d2faf
outputdir in test ppo
qgallouedec Aug 29, 2024
aa98e42
output dir in ppo config
qgallouedec Aug 29, 2024
66dc235
Merge branch 'clean-config' of https://github.com/huggingface/trl int…
qgallouedec Aug 29, 2024
79234d1
revert non-core change (1/n)
qgallouedec Sep 3, 2024
9b3b3a7
revert non-core changes (2/n)
qgallouedec Sep 3, 2024
6aeba64
revert non-core change (3/n)
qgallouedec Sep 3, 2024
fc4d223
Merge branch 'main' into clean-config
qgallouedec Sep 3, 2024
23fbfc6
uniform max_length
qgallouedec Sep 3, 2024
136cfdc
fix uniform max_length
qgallouedec Sep 3, 2024
640999c
beta uniform
qgallouedec Sep 3, 2024
3d5618c
Merge branch 'clean-config' of https://github.com/huggingface/trl int…
qgallouedec Sep 3, 2024
cfe9b22
style
qgallouedec Sep 3, 2024
358b026
link to `ConstantLengthDataset`
qgallouedec Sep 3, 2024
2190bf1
uniform `dataset_num_proc`
qgallouedec Sep 3, 2024
5434969
uniform `disable_dropout`
qgallouedec Sep 3, 2024
a7e537a
`eval_packing` doc
qgallouedec Sep 3, 2024
1a86078
try latex and α in doc
qgallouedec Sep 3, 2024
7065562
try title first
qgallouedec Sep 3, 2024
2d93d3d
doesn't work
qgallouedec Sep 3, 2024
42acd10
reorganize doc
qgallouedec Sep 3, 2024
92a2206
overview
qgallouedec Sep 3, 2024
81d5147
better latex
qgallouedec Sep 3, 2024
71c110a
is_encoder_decoder uniform
qgallouedec Sep 3, 2024
e60c3b0
proper ticks
qgallouedec Sep 3, 2024
a964090
fix latex
qgallouedec Sep 3, 2024
45d4f99
uniform generate_during_eval
qgallouedec Sep 3, 2024
3bc2d30
uniform truncation_mode
qgallouedec Sep 3, 2024
66a4861
ref_model_mixup_alpha
qgallouedec Sep 3, 2024
e2d8f7f
ref_model_mixup_alpha and ref_model_sync_steps
qgallouedec Sep 3, 2024
79347d9
Uniform `model_init_kwargs` and `ref_model_init_kwargs`
qgallouedec Sep 3, 2024
9ba37a9
rpo_alpha
qgallouedec Sep 3, 2024
52f69b1
Update maximum length argument names in config files
qgallouedec Sep 3, 2024
0fabc42
Update loss_type descriptions in config files
qgallouedec Sep 3, 2024
e1abc3a
Update max_target_length to max_completion_length in CPOConfig and CP…
qgallouedec Sep 3, 2024
d618f0c
Update padding value in config files
qgallouedec Sep 3, 2024
594677c
Update precompute_ref_log_probs flag documentation
qgallouedec Sep 3, 2024
5dee9ab
Fix typos and update comments in dpo_config.py and sft_config.py
qgallouedec Sep 3, 2024
47431f8
Merge branch 'main' into clean-config
qgallouedec Sep 4, 2024
19af1fa
post init warning for `max_target_length`
qgallouedec Sep 4, 2024
34b38b0
Merge branch 'clean-config' of https://github.com/huggingface/trl int…
qgallouedec Sep 4, 2024
07c9cab
Merge branch 'main' into clean-config
qgallouedec Sep 4, 2024
68 changes: 35 additions & 33 deletions docs/source/_toctree.yml
@@ -17,44 +17,46 @@
title: Understanding Logs
title: Get started
- sections:
- sections:
- local: trainer
title: Overview
- local: alignprop_trainer
title: AlignProp
- local: bco_trainer
title: BCO
- local: cpo_trainer
title: CPO
- local: ddpo_trainer
title: DDPO
- local: dpo_trainer
title: DPO
- local: online_dpo_trainer
title: Online DPO
- local: orpo_trainer
title: ORPO
- local: kto_trainer
title: KTO
- local: ppo_trainer
title: PPO
- local: ppov2_trainer
title: PPOv2
- local: rloo_trainer
title: RLOO
- local: sft_trainer
title: SFT
- local: iterative_sft_trainer
title: Iterative SFT
- local: reward_trainer
title: Reward Model
title: Trainers
- local: models
title: Model Classes
- local: trainer
title: Trainer Classes
- local: reward_trainer
title: Reward Model Training
- local: sft_trainer
title: Supervised Fine-Tuning
- local: ppo_trainer
title: PPO Trainer
- local: ppov2_trainer
title: PPOv2 Trainer
- local: rloo_trainer
title: RLOO Trainer
- local: best_of_n
title: Best of N Sampling
- local: dpo_trainer
title: DPO Trainer
- local: online_dpo_trainer
title: Online DPO Trainer
- local: kto_trainer
title: KTO Trainer
- local: bco_trainer
title: BCO Trainer
- local: cpo_trainer
title: CPO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: alignprop_trainer
title: AlignProp Trainer
- local: orpo_trainer
title: ORPO Trainer
- local: iterative_sft_trainer
title: Iterative Supervised Fine-Tuning
- local: callbacks
title: Callback Classes
- local: judges
title: Judge Classes
title: Judges
- local: callbacks
title: Callbacks
- local: text_environments
title: Text Environments
title: API
6 changes: 2 additions & 4 deletions tests/test_trainers_args.py
@@ -77,7 +77,6 @@ def test_cpo(self):
max_length=256,
max_prompt_length=64,
max_completion_length=64,
max_target_length=64,
beta=0.5,
label_smoothing=0.5,
loss_type="hinge",
@@ -96,7 +95,6 @@ def test_cpo(self):
self.assertEqual(trainer.args.max_length, 256)
self.assertEqual(trainer.args.max_prompt_length, 64)
self.assertEqual(trainer.args.max_completion_length, 64)
self.assertEqual(trainer.args.max_target_length, 64)
self.assertEqual(trainer.args.beta, 0.5)
self.assertEqual(trainer.args.label_smoothing, 0.5)
self.assertEqual(trainer.args.loss_type, "hinge")
@@ -127,7 +125,7 @@ def test_dpo(self):
truncation_mode="keep_start",
max_length=256,
max_prompt_length=64,
max_target_length=64,
max_completion_length=64,
is_encoder_decoder=True,
disable_dropout=False,
# generate_during_eval=True, # ignore this one, it requires wandb
@@ -155,7 +153,7 @@ def test_dpo(self):
self.assertEqual(trainer.args.truncation_mode, "keep_start")
self.assertEqual(trainer.args.max_length, 256)
self.assertEqual(trainer.args.max_prompt_length, 64)
self.assertEqual(trainer.args.max_target_length, 64)
self.assertEqual(trainer.args.max_completion_length, 64)
self.assertEqual(trainer.args.is_encoder_decoder, True)
self.assertEqual(trainer.args.disable_dropout, False)
# self.assertEqual(trainer.args.generate_during_eval, True)
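Note on the change above: `max_target_length` is renamed to `max_completion_length` in the CPO and DPO configs. A minimal sketch of the new argument name, with illustrative values (the `output_dir` path is hypothetical, not taken from the PR):

from trl import CPOConfig

# Illustrative values; `max_completion_length` replaces the removed
# `max_target_length` argument.
config = CPOConfig(
    output_dir="cpo_model",       # hypothetical output directory
    max_length=256,
    max_prompt_length=64,
    max_completion_length=64,     # formerly max_target_length=64
    beta=0.5,
    loss_type="hinge",
)
print(config.max_completion_length)  # -> 64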
124 changes: 78 additions & 46 deletions trl/trainer/alignprop_config.py
@@ -2,85 +2,117 @@
import sys
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional
from typing import Any, Dict, Literal, Optional, Tuple

from ..core import flatten_dict
from ..import_utils import is_bitsandbytes_available, is_torchvision_available


@dataclass
class AlignPropConfig:
"""
Configuration class for AlignPropTrainer
r"""
Configuration class for the [`AlignPropTrainer`].

Using [`~transformers.HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.

Parameters:
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
Name of this experiment (defaults to the file name without the extension).
run_name (`str`, *optional*, defaults to `""`):
Name of this run.
log_with (`Optional[Literal["wandb", "tensorboard"]]`, *optional*, defaults to `None`):
Log with either `"wandb"` or `"tensorboard"`. Check
[tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
log_image_freq (`int`, *optional*, defaults to `1`):
Frequency for logging images.
tracker_kwargs (`Dict[str, Any]`, *optional*, defaults to `{}`):
Keyword arguments for the tracker (e.g., `wandb_project`).
accelerator_kwargs (`Dict[str, Any]`, *optional*, defaults to `{}`):
Keyword arguments for the accelerator.
project_kwargs (`Dict[str, Any]`, *optional*, defaults to `{}`):
Keyword arguments for the accelerator project config (e.g., `logging_dir`).
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
Name of project to use for tracking.
logdir (`str`, *optional*, defaults to `"logs"`):
Top-level logging directory for checkpoint saving.
num_epochs (`int`, *optional*, defaults to `100`):
Number of epochs to train.
save_freq (`int`, *optional*, defaults to `1`):
Number of epochs between saving model checkpoints.
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
Number of checkpoints to keep before overwriting old ones.
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
Mixed precision training.
allow_tf32 (`bool`, *optional*, defaults to `True`):
Allow `tf32` on Ampere GPUs.
resume_from (`str`, *optional*, defaults to `""`):
Path to resume training from a checkpoint.
sample_num_steps (`int`, *optional*, defaults to `50`):
Number of sampler inference steps.
sample_eta (`float`, *optional*, defaults to `1.0`):
Eta parameter for the DDIM sampler.
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
Classifier-free guidance weight.
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
Whether to use the 8bit Adam optimizer from `bitsandbytes`.
train_learning_rate (`float`, *optional*, defaults to `1e-3`):
Learning rate.
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
Beta1 for Adam optimizer.
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
Beta2 for Adam optimizer.
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
Weight decay for Adam optimizer.
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
Epsilon value for Adam optimizer.
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
Number of gradient accumulation steps.
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
Maximum gradient norm for gradient clipping.
negative_prompts (`Optional[str]`, *optional*, defaults to `None`):
Comma-separated list of prompts to use as negative examples.
truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
If `True`, randomized truncation to different diffusion timesteps is used.
truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
truncated_rand_backprop_minmax (`Tuple[int, int]`, *optional*, defaults to `(0, 50)`):
Range of diffusion timesteps for randomized truncated backpropagation.
"""

# common parameters
exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
"""the name of this experiment (by default is the file name without the extension name)"""
run_name: Optional[str] = ""
"""Run name for wandb logging and checkpoint saving."""
run_name: str = ""
seed: int = 0
"""Seed value for random generations"""
log_with: Optional[Literal["wandb", "tensorboard"]] = None
"""Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
log_image_freq = 1
"""Logging Frequency for images"""
tracker_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the tracker (e.g. wandb_project)"""
accelerator_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the accelerator"""
project_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
log_image_freq: int = 1
tracker_kwargs: Dict[str, Any] = field(default_factory=dict)
accelerator_kwargs: Dict[str, Any] = field(default_factory=dict)
project_kwargs: Dict[str, Any] = field(default_factory=dict)
tracker_project_name: str = "trl"
"""Name of project to use for tracking"""
logdir: str = "logs"
"""Top-level logging directory for checkpoint saving."""

# hyperparameters
num_epochs: int = 100
"""Number of epochs to train."""
save_freq: int = 1
"""Number of epochs between saving model checkpoints."""
num_checkpoint_limit: int = 5
"""Number of checkpoints to keep before overwriting old ones."""
mixed_precision: str = "fp16"
"""Mixed precision training."""
allow_tf32: bool = True
"""Allow tf32 on Ampere GPUs."""
resume_from: Optional[str] = ""
"""Resume training from a checkpoint."""
resume_from: str = ""
sample_num_steps: int = 50
"""Number of sampler inference steps."""
sample_eta: float = 1.0
"""Eta parameter for the DDIM sampler."""
sample_guidance_scale: float = 5.0
"""Classifier-free guidance weight."""
train_batch_size: int = 1
"""Batch size (per GPU!) to use for training."""
train_use_8bit_adam: bool = False
"""Whether to use the 8bit Adam optimizer from bitsandbytes."""
train_learning_rate: float = 1e-3
"""Learning rate."""
train_adam_beta1: float = 0.9
"""Adam beta1."""
train_adam_beta2: float = 0.999
"""Adam beta2."""
train_adam_weight_decay: float = 1e-4
"""Adam weight decay."""
train_adam_epsilon: float = 1e-8
"""Adam epsilon."""
train_gradient_accumulation_steps: int = 1
"""Number of gradient accumulation steps."""
train_max_grad_norm: float = 1.0
"""Maximum gradient norm for gradient clipping."""
negative_prompts: Optional[str] = ""
"""Comma-separated list of prompts to use as negative examples."""
negative_prompts: Optional[str] = None
truncated_backprop_rand: bool = True
"""Truncated Randomized Backpropation randomizes truncation to different diffusion timesteps"""
truncated_backprop_timestep: int = 49
"""Absolute timestep to which the gradients are being backpropagated. If truncated_backprop_rand is False"""
truncated_rand_backprop_minmax: tuple = (0, 50)
"""Range of diffusion timesteps for randomized truncated backprop."""
truncated_rand_backprop_minmax: Tuple[int, int] = (0, 50)

def to_dict(self):
output_dict = {}
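As the rewritten docstring notes, `AlignPropConfig` can be exposed as command-line arguments through `HfArgumentParser`. A minimal sketch of that usage (script name and flag values are illustrative, not part of the PR):

from transformers import HfArgumentParser

from trl import AlignPropConfig

if __name__ == "__main__":
    parser = HfArgumentParser(AlignPropConfig)
    # e.g. `python run_alignprop.py --num_epochs 10 --mixed_precision bf16 --log_with wandb`
    (config,) = parser.parse_args_into_dataclasses()
    print(config.num_epochs, config.mixed_precision, config.truncated_backprop_timestep)

Dict-valued fields such as `tracker_kwargs` or `project_kwargs` are typically set in code rather than on the command line.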