
💧 Generalize disable_dropout #2511

Merged · 1 commit · Dec 22, 2024
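Every hunk below gates an existing call to the same helper, `disable_dropout_in_model` from `trl/trainer/utils.py`, behind a per-config `disable_dropout` flag. For orientation, the helper walks the module tree and zeroes each dropout probability; the sketch below is a close paraphrase for reference, not a verbatim copy of the TRL source:

import torch

def disable_dropout_in_model(model: torch.nn.Module) -> None:
    # Zero the drop probability of every Dropout submodule so that
    # training-mode forward passes become deterministic w.r.t. dropout.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0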
3 changes: 3 additions & 0 deletions trl/trainer/bco_config.py
@@ -46,6 +46,8 @@ class BCOConfig(TrainingArguments):
         truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
             Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
             This argument is required if you want to use the default data collator.
+        disable_dropout (`bool`, *optional*, defaults to `True`):
+            Whether to disable dropout in the model and reference model.
         generate_during_eval (`bool`, *optional*, defaults to `False`):
             If `True`, generates and logs completions from both the model and the reference model to W&B during
             evaluation.
@@ -78,6 +80,7 @@ class BCOConfig(TrainingArguments):
     label_pad_token_id: int = -100
     padding_value: Optional[int] = None
     truncation_mode: str = "keep_end"
+    disable_dropout: bool = True
     generate_during_eval: bool = False
     is_encoder_decoder: Optional[bool] = None
     precompute_ref_log_probs: bool = False
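With the new field in place, the behavior is controlled from the config rather than hard-coded in the trainer. A minimal, hypothetical example (the output path is illustrative):

from trl import BCOConfig

# Hypothetical usage: opt back in to dropout during BCO training.
# The default, disable_dropout=True, matches the previous hard-coded behavior.
training_args = BCOConfig(
    output_dir="bco-model",  # illustrative
    disable_dropout=False,
)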
11 changes: 5 additions & 6 deletions trl/trainer/bco_trainer.py
@@ -309,8 +309,6 @@ class BCOTrainer(Trainer):
             The function to use to preprocess the logits before computing the metrics.
         peft_config (`dict`, defaults to `None`):
             The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
-        disable_dropout (`bool`, defaults to `True`):
-            Whether or not to disable dropouts in `model` and `ref_model`.
         compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
             The function to use to compute the metrics. Must take a `EvalPrediction` and return
             a dictionary string to metric values.
@@ -538,10 +536,11 @@ def make_inputs_require_grad(module, input, output):
         else:
             self.use_dpo_data_collator = False

-        # disable dropout in the model and reference model
-        disable_dropout_in_model(model)
-        if self.ref_model is not None:
-            disable_dropout_in_model(self.ref_model)
+        # Disable dropout in the model and reference model
+        if args.disable_dropout:
+            disable_dropout_in_model(model)
+            if self.ref_model is not None:
+                disable_dropout_in_model(self.ref_model)

         self.max_length = max_length
         self.generate_during_eval = args.generate_during_eval
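A quick way to verify the flag took effect after constructing a trainer is to walk the module tree; this check is a hypothetical snippet, not part of the PR:

import torch

def assert_dropout_disabled(model: torch.nn.Module) -> None:
    # Hypothetical sanity check: with disable_dropout=True, every
    # torch.nn.Dropout module should have had its probability zeroed.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            assert module.p == 0, f"dropout still active in {module}"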
1 change: 1 addition & 0 deletions trl/trainer/cpo_trainer.py
@@ -268,6 +268,7 @@ def make_inputs_require_grad(module, input, output):
         else:
             self.use_dpo_data_collator = False

+        # Disable dropout in the model
         if args.disable_dropout:
             disable_dropout_in_model(model)

1 change: 1 addition & 0 deletions trl/trainer/dpo_trainer.py
@@ -376,6 +376,7 @@ def make_inputs_require_grad(module, input, output):
         if data_collator is None:
             data_collator = PreferenceCollator(pad_token_id=self.padding_value)

+        # Disable dropout in the model and reference model
         if args.disable_dropout:
             disable_dropout_in_model(model)
             if self.ref_model is not None:
2 changes: 1 addition & 1 deletion trl/trainer/gkd_config.py
@@ -41,7 +41,7 @@ class GKDConfig(SFTConfig):
             Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
             from a string.
         disable_dropout (`bool`, *optional*, defaults to `True`):
-            Whether or not to disable dropouts in `model`.
+            Whether to disable dropout in the model.
         seq_kd (`bool`, *optional*, defaults to `False`):
             Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
             on teacher-generated output).
1 change: 1 addition & 0 deletions trl/trainer/gkd_trainer.py
@@ -126,6 +126,7 @@ def __init__(
         else:
             teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

+        # Disable dropout in the model
         if args.disable_dropout:
             disable_dropout_in_model(self.model)

2 changes: 1 addition & 1 deletion trl/trainer/kto_config.py
@@ -77,7 +77,7 @@ class KTOConfig(TrainingArguments):
         dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
             Number of processes to use for processing the dataset.
         disable_dropout (`bool`, *optional*, defaults to `True`):
-            Whether to disable dropout in the model.
+            Whether to disable dropout in the model and reference model.
     """

     learning_rate: float = 1e-6
3 changes: 1 addition & 2 deletions trl/trainer/kto_trainer.py
@@ -304,8 +304,6 @@ class KTOTrainer(Trainer):
             The function to use to preprocess the logits before computing the metrics.
         peft_config (`dict`, defaults to `None`):
             The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
-        disable_dropout (`bool`, defaults to `True`):
-            Whether or not to disable dropouts in `model` and `ref_model`.
         compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
             The function to use to compute the metrics. Must take a `EvalPrediction` and return
             a dictionary string to metric values.
@@ -526,6 +524,7 @@ def make_inputs_require_grad(module, input, output):
         else:
             self.use_dpo_data_collator = False

+        # Disable dropout in the model and reference model
         if args.disable_dropout:
             disable_dropout_in_model(model)
             if self.ref_model is not None:
2 changes: 1 addition & 1 deletion trl/trainer/online_dpo_config.py
@@ -57,7 +57,7 @@ class OnlineDPOConfig(TrainingArguments):
         dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
             Number of processes to use for processing the dataset.
         disable_dropout (`bool`, *optional*, defaults to `True`):
-            Whether to disable dropout in the model.
+            Whether to disable dropout in the model and reference model.
     """

     learning_rate: float = 5e-7
4 changes: 3 additions & 1 deletion trl/trainer/online_dpo_trainer.py
@@ -196,9 +196,11 @@ def __init__(
             # Get peft model with the given config
             model = get_peft_model(model, peft_config)

-        # Disable dropout in the model if specified
+        # Disable dropout in the model and reference model
         if args.disable_dropout:
             disable_dropout_in_model(model)
+            if self.ref_model is not None:
+                disable_dropout_in_model(self.ref_model)

         # Handle the ref_model
         # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
1 change: 1 addition & 0 deletions trl/trainer/orpo_trainer.py
@@ -282,6 +282,7 @@ def make_inputs_require_grad(module, input, output):
         else:
             self.use_dpo_data_collator = False

+        # Disable dropout in the model and reference model
         if args.disable_dropout:
             disable_dropout_in_model(model)

3 changes: 3 additions & 0 deletions trl/trainer/prm_config.py
@@ -35,6 +35,8 @@ class PRMConfig(TrainingArguments):
             Maximum length of the sequences (prompt + completion) used for truncation.
         max_completion_length (`Optional[int]`, *optional*, defaults to `None`):
             Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
+        disable_dropout (`bool`, *optional*, defaults to `True`):
+            Whether to disable dropout in the model.
         step_separator (`str`, *optional*, defaults to `"\n"`):
             Separator used to separate each step of the reasoning process.
         train_on_last_step_only (`bool`, *optional*, defaults to `False`):
@@ -46,6 +48,7 @@ class PRMConfig(TrainingArguments):
     learning_rate: float = 1e-5
     max_length: Optional[int] = None
     max_completion_length: Optional[int] = None
+    disable_dropout: bool = True
     step_separator: str = "\n"
     train_on_last_step_only: bool = False
     dataset_num_proc: Optional[int] = None
6 changes: 5 additions & 1 deletion trl/trainer/prm_trainer.py
@@ -39,7 +39,7 @@
 from transformers.utils import is_peft_available

 from .prm_config import PRMConfig
-from .utils import compute_accuracy, generate_model_card
+from .utils import compute_accuracy, disable_dropout_in_model, generate_model_card


 if is_peft_available():
@@ -130,6 +130,10 @@ def __init__(

             model = get_peft_model(model, peft_config)

+        # Disable dropout in the model
+        if args.disable_dropout:
+            disable_dropout_in_model(model)
+
         if compute_metrics is None:
             compute_metrics = compute_accuracy

3 changes: 3 additions & 0 deletions trl/trainer/reward_config.py
@@ -31,6 +31,8 @@ class RewardConfig(TrainingArguments):
         max_length (`Optional[int]`, *optional*, defaults to `None`):
             Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
             to use the default data collator.
+        disable_dropout (`bool`, *optional*, defaults to `True`):
+            Whether to disable dropout in the model.
         dataset_num_proc (`int`, *optional*, defaults to `None`):
             Number of processes to use for processing the dataset.
         center_rewards_coefficient (`float`, *optional*, defaults to `None`):
@@ -42,6 +44,7 @@ class RewardConfig(TrainingArguments):
     """

     max_length: Optional[int] = None
+    disable_dropout: bool = True
     dataset_num_proc: Optional[int] = None
     center_rewards_coefficient: Optional[float] = None
     remove_unused_columns: bool = False
5 changes: 5 additions & 0 deletions trl/trainer/reward_trainer.py
@@ -47,6 +47,7 @@
     RewardDataCollatorWithPadding,
     compute_accuracy,
     decode_and_strip_padding,
+    disable_dropout_in_model,
     generate_model_card,
     get_comet_experiment_url,
     log_table_to_comet_experiment,
@@ -169,6 +170,10 @@ def __init__(

             model = get_peft_model(model, peft_config)

+        # Disable dropout in the model
+        if args.disable_dropout:
+            disable_dropout_in_model(model)
+
         if compute_metrics is None:
             compute_metrics = compute_accuracy

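Taken together, the reward-model path now honors the same switch as the preference trainers. A minimal end-to-end sketch, assuming an illustrative checkpoint and dataset (the names and hyperparameters are examples, not part of the PR):

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

model_name = "Qwen/Qwen2-0.5B"  # illustrative checkpoint
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.config.pad_token_id = tokenizer.pad_token_id  # ensure padding is defined
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")  # illustrative

# disable_dropout defaults to True; flip it to False to keep dropout active.
training_args = RewardConfig(output_dir="reward-model", disable_dropout=True, max_length=1024)
trainer = RewardTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()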