Adding precompute batch size argument in DPOTrainer for reference model #2421

Closed
SwayamInSync opened this issue Dec 1, 2024 · 2 comments · Fixed by #2426
Labels
🏋 DPO (Related to DPO) · ✨ enhancement (New feature or request)

Comments

@SwayamInSync
Contributor

Feature request

I propose adding a new configuration parameter, precompute_ref_batch_size, that lets users specify a different (typically larger) batch size for the reference-model precomputation phase. This would:

  1. Speed up the precomputation phase by processing more examples per batch
  2. Make better use of available GPU memory since no gradients need to be stored
  3. Maintain backward compatibility by defaulting to current behavior if not specified

The change would affect:

  • DPOConfig: Add new optional parameter precompute_ref_batch_size
  • get_train_dataloader() and get_eval_dataloader(): Use the new batch size when precomputing reference log probabilities

This is particularly useful for large-scale DPO training where the precomputation phase can be a significant bottleneck.
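For illustration, here is a minimal usage sketch, assuming the parameter lands in DPOConfig under the name proposed here (precompute_ref_batch_size may not exist in your installed trl version); the other arguments are existing DPOConfig/TrainingArguments fields:

```python
# Sketch only: precompute_ref_batch_size is the parameter proposed in this
# issue and may not be available in the trl version you have installed.
from trl import DPOConfig

config = DPOConfig(
    output_dir="dpo-output",
    per_device_train_batch_size=8,   # constrained by gradients/optimizer state
    precompute_ref_log_probs=True,   # run the reference model once, up front
    precompute_ref_batch_size=32,    # proposed: larger batch for the no-grad pass
)
```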

Motivation

Currently, when precompute_ref_log_probs=True, DPOTrainer uses per_device_train_batch_size (during training) or per_device_eval_batch_size (during evaluation) to generate the reference log probs.
For efficiency, this batch size can be larger than the training one, since this step requires no gradient computation or storage. Adding a configurable precompute_ref_batch_size parameter would let users speed up this preprocessing step by using larger batch sizes while keeping memory usage under control.
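As a generic illustration of why this works (a self-contained PyTorch sketch, not TRL's actual implementation), the precompute pass is forward-only, so it can run under torch.no_grad() with a batch size larger than the training one:

```python
# Generic, self-contained sketch (not TRL code) of the idea: the reference
# pass is forward-only, so it can use a larger batch size than training.
import torch
from torch.utils.data import DataLoader, TensorDataset

ref_model = torch.nn.Linear(128, 2)          # stand-in for the frozen reference model
dataset = TensorDataset(torch.randn(1024, 128))

precompute_batch_size = 32                   # larger than the train batch size (e.g. 8)
loader = DataLoader(dataset, batch_size=precompute_batch_size)

ref_model.eval()
ref_log_probs = []
with torch.no_grad():                        # no activations kept for backprop
    for (batch,) in loader:
        logits = ref_model(batch)
        ref_log_probs.append(torch.log_softmax(logits, dim=-1))
ref_log_probs = torch.cat(ref_log_probs)     # cached once, reused during training
```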

Your contribution

Yes, I can help by submitting a PR following the contribution guidelines.

@qgallouedec
Member

Thanks for this suggestion @SwayamInSync!
Do you have any idea of the gain in speed?
If you have a working implementation, feel free to submit a PR so that we can test and discuss the code.

@SwayamInSync
Contributor Author

Made a PR at #2426
From a quick test with my settings, I can only fit a batch size of up to 8 (I get OOM otherwise), but with this new parameter the reference-model inference batch size can go up to 32 instead of being tied to the training batch size, so it's a solid improvement.
