Apply quantization during DPO QLoRA #115
Conversation
@@ -1,12 +1,12 @@
 # Model arguments
 model_name_or_path: alignment-handbook/zephyr-7b-sft-qlora
-torch_dtype: float16
+torch_dtype: bfloat16
It turns out that using bfloat16 makes a non-trivial difference to downstream perf! cc @nathan-az :)
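As a quick aside (not from this PR), one reason the dtype choice can matter: bfloat16 keeps float32's exponent range, so values that would overflow float16 during training survive.

```python
import torch

# float16 tops out around 6.5e4, while bfloat16 shares float32's exponent
# range, so large activations/gradients are far less likely to overflow.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38
```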
model_kwargs = dict(
    revision=model_args.base_model_revision,
    trust_remote_code=model_args.trust_remote_code,
    use_flash_attention_2=model_args.use_flash_attention_2,
    torch_dtype=torch_dtype,
    use_cache=False if training_args.gradient_checkpointing else True,
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
Note that this approach of quantizing and then merging in DPOTrainer is what Tim Dettmers suggests: https://twitter.com/Tim_Dettmers/status/1694654191325573456
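For context, a rough sketch of how a 4-bit quantization config can be built and passed alongside model_kwargs like the ones above. The handbook constructs this via its own helpers, so the exact BitsAndBytesConfig fields and the base model name below are illustrative assumptions, not the repo's exact settings.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Illustrative QLoRA-style 4-bit config (fields are assumptions).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# The quantized base model is then loaded with kwargs like the ones above.
base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # base model name is an assumption for this sketch
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    device_map="auto",
)
```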
 # LoRA arguments
 use_peft: true
 load_in_4bit: true
-lora_r: 16
-lora_alpha: 16
+lora_r: 128
Tuning these hparams was necessary to get close to zephyr-7b-beta
perf on MT-Bench
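For reference, a rough sketch of the PEFT LoraConfig these YAML values feed into; everything other than r=128 (alpha, dropout, target modules) is an assumption for illustration rather than the handbook's exact recipe.

```python
from peft import LoraConfig

peft_config = LoraConfig(
    r=128,                      # lora_r from the config above
    lora_alpha=128,             # assumption; alpha is often scaled with r
    lora_dropout=0.05,          # assumption
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumption
    task_type="CAUSAL_LM",
)
```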
LGTM
This PR fixes a bug where we weren't quantising the base model with QLoRA during DPO and thus were actually doing LoRA instead.

Now we first quantise the base model in 4-bit and load the SFT adapter (which later gets merged within the DPOTrainer). Although this isn't as memory efficient as loading two adapters in a single base model (example), it does provide the flexibility to customise the QLoRA config.

I find that with these settings MT-Bench yields a score of 7.212, which is ~0.1 lower than zephyr-7b-beta and could likely be improved with a bit more tuning of the hparams.
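To make the described flow concrete, here is a hedged sketch of quantising the base model, attaching the SFT adapter, and handing everything to DPOTrainer, using the TRL API from around the time of this PR. Model names, hyperparameters, and the training arguments are assumptions; the handbook's run_dpo.py wires this up differently in detail, and the dataset preparation is not shown.

```python
import torch
from peft import LoraConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import DPOTrainer

# 1. Quantise the base model in 4-bit (base model name is an assumption).
base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("alignment-handbook/zephyr-7b-sft-qlora")

# 2. Load the SFT adapter on top of the quantised base.
model = PeftModel.from_pretrained(base_model, "alignment-handbook/zephyr-7b-sft-qlora")

# 3. DPOTrainer merges the SFT adapter and trains a fresh DPO adapter described
#    by peft_config; with PEFT, ref_model can stay None and the frozen weights
#    (adapter disabled) provide the reference logits.
peft_config = LoraConfig(r=128, lora_alpha=128, task_type="CAUSAL_LM")  # as sketched above
training_args = TrainingArguments(output_dir="zephyr-7b-dpo-qlora", bf16=True)  # illustrative
train_dataset = ...  # a dataset with "prompt", "chosen", "rejected" columns (not shown)

trainer = DPOTrainer(
    model,
    ref_model=None,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
trainer.train()
```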