Apply quantization during DPO QLoRA (#115)
* Add QLoRA fix

* Update script
lewtun authored Feb 5, 2024
1 parent d00e6f0 commit 87cc800
Showing 3 changed files with 11 additions and 13 deletions.
8 changes: 4 additions & 4 deletions recipes/zephyr-7b-beta/dpo/config_qlora.yaml
@@ -1,12 +1,12 @@
 # Model arguments
 model_name_or_path: alignment-handbook/zephyr-7b-sft-qlora
-torch_dtype: float16
+torch_dtype: bfloat16

 # LoRA arguments
 use_peft: true
 load_in_4bit: true
-lora_r: 16
-lora_alpha: 16
+lora_r: 128
+lora_alpha: 128
 lora_dropout: 0.05
 lora_target_modules:
 - q_proj
@@ -32,7 +32,7 @@ beta: 0.01
 do_eval: true
 evaluation_strategy: steps
 eval_steps: 100
-gradient_accumulation_steps: 2
+gradient_accumulation_steps: 4
 gradient_checkpointing: true
 gradient_checkpointing_kwargs:
   use_reentrant: false
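With load_in_4bit: true, the DPO run now quantizes the base model at load time instead of training on top of a 16-bit merge. A minimal sketch of how these YAML values typically map onto a transformers BitsAndBytesConfig; the nf4 and double-quantization settings are assumed defaults, not taken from this diff:

import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # from `load_in_4bit: true`
    bnb_4bit_compute_dtype=torch.bfloat16,  # aligned with `torch_dtype: bfloat16`
    bnb_4bit_quant_type="nf4",              # assumed default
    bnb_4bit_use_double_quant=True,         # assumed default
)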
14 changes: 6 additions & 8 deletions scripts/run_dpo.py
@@ -128,28 +128,26 @@ def main():

     model = model_args.model_name_or_path
     if is_adapter_model(model, model_args.model_revision) is True:
-        # Load the base model, merge the adapter weights and unload the adapter
-        # Note: to run QLoRA, you will need to merge the base model separately as the merged model in 16bit
-        logger.info(f"Merging PEFT adapters for {model_args.model_name_or_path=}")
-
+        logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}")
         peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
-
         model_kwargs = dict(
             revision=model_args.base_model_revision,
             trust_remote_code=model_args.trust_remote_code,
             use_flash_attention_2=model_args.use_flash_attention_2,
             torch_dtype=torch_dtype,
             use_cache=False if training_args.gradient_checkpointing else True,
+            device_map=get_kbit_device_map() if quantization_config is not None else None,
+            quantization_config=quantization_config,
         )
         base_model = AutoModelForCausalLM.from_pretrained(
             peft_config.base_model_name_or_path,
             **model_kwargs,
         )
         model = PeftModel.from_pretrained(
-            base_model, model_args.model_name_or_path, revision=model_args.model_revision
+            base_model,
+            model_args.model_name_or_path,
+            revision=model_args.model_revision,
         )
-        model.eval()
-        model = model.merge_and_unload()
         model_kwargs = None

     ref_model = model
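This script change is the core of the fix: the quantization config and k-bit device map are now applied when the base model is loaded, and the SFT adapter stays attached as a PeftModel instead of being merged back into 16-bit weights via merge_and_unload(), which would have discarded the quantization. A self-contained sketch of the new loading path, with an illustrative adapter id and device_map="auto" standing in for get_kbit_device_map():

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

adapter_id = "alignment-handbook/zephyr-7b-sft-qlora"
peft_config = PeftConfig.from_pretrained(adapter_id)

# Load the base model in 4-bit so DPO trains against quantized weights
base_model = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)
# Attach the SFT adapter; no merge_and_unload(), so the base stays quantized
model = PeftModel.from_pretrained(base_model, adapter_id)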
2 changes: 1 addition & 1 deletion setup.py
@@ -66,7 +66,7 @@
     "tensorboard",
     "torch==2.1.2",
     "transformers==4.36.2",
-    "trl==0.7.7",
+    "trl==0.7.10",
     "jinja2>=3.0.0",
     "tqdm>=4.64.1",
 ]
