Skip to content

Commit

Permalink
[Doc] Add how to use Lion optimizer (#152)
Browse files Browse the repository at this point in the history
* add Lion optimizer

* remove files

* revert changes

* update docs

* use html instead

* correct link

* Update customization.mdx

* Update customization.mdx

* Update customization.mdx

* Update docs/source/customization.mdx

Co-authored-by: Leandro von Werra <[email protected]>

* Update docs/source/customization.mdx

---------

Co-authored-by: Leandro von Werra <[email protected]>
  • Loading branch information
younesbelkada and lvwerra authored Feb 21, 2023
1 parent 7953151 commit 9eaea2e
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion docs/source/customization.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer)
```

### Use LION optimizer

You can use the new [LION optimizer from Google](https://arxiv.org/abs/2302.06675) as well, first take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py), and copy it so that you can import the optimizer. Make sure to initialize the optimizer by considering the trainable parameters only for a more memory efficient training:
```python
optimizer = Lion(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.config.learning_rate)

...
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer)
```
We advice you to use the learning rate that you would use for `Adam` divided by 3 as pointed out [here](https://github.com/lucidrains/lion-pytorch#lion---pytorch). We observed an improvement when using this optimizer compared to classic Adam (check the full logs [here](https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada)):

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-lion.png">
</div>


## Add a learning rate scheduler

You can also play with your training by adding learning rate schedulers!
Expand Down Expand Up @@ -101,8 +117,14 @@ ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

## Pass 8-bit reference models

Since `trl` supports all key word arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
<div>

Since `trl` supports all key word arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.

Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#bitsandbytes-integration-for-int8-mixedprecision-matrix-decomposition).

</div>

```python
# 0. imports
# pip install bitsandbytes
Expand Down

0 comments on commit 9eaea2e

Please sign in to comment.