Add ReFT (LoReFT, NoReFT, DiReFT) #705

Merged · 20 commits · Jul 1, 2024
1 change: 1 addition & 0 deletions README.md
@@ -156,6 +156,7 @@ Currently, adapters integrates all architectures and methods listed below:
| UniPELT | [Mao et al. (2022)](https://arxiv.org/pdf/2110.07577.pdf) | [Docs](https://docs.adapterhub.ml/method_combinations.html#unipelt) |
| Prompt Tuning | [Lester et al. (2021)](https://aclanthology.org/2021.emnlp-main.243/) | [Docs](https://docs.adapterhub.ml/methods.html#prompt-tuning) |
| QLoRA | [Dettmers et al. (2023)](https://arxiv.org/pdf/2305.14314.pdf) | [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) |
| ReFT | [Wu et al. (2024)](https://arxiv.org/pdf/2404.03592) | [Docs](https://docs.adapterhub.ml/methods.html#reft) |

## Supported Models

17 changes: 17 additions & 0 deletions docs/classes/adapter_config.rst
@@ -62,6 +62,23 @@ PromptTuningConfig
:members:
:inherited-members: Mapping


ReFT
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: adapters.ReftConfig
:members:
:inherited-members: Mapping

.. autoclass:: adapters.LoReftConfig
:members:

.. autoclass:: adapters.NoReftConfig
:members:

.. autoclass:: adapters.DiReftConfig
:members:

Combined configurations
~~~~~~~~~~~~~~~~~~~~~~~~

47 changes: 47 additions & 0 deletions docs/methods.md
@@ -295,3 +295,50 @@ model.add_adapter("dummy", config=config)
_Papers:_
- [The Power of Scale for Parameter-Efficient Prompt Tuning](https://aclanthology.org/2021.emnlp-main.243/) (Lester et al., 2021)

## ReFT

_Configuration class_: [`ReftConfig`](adapters.ReftConfig)

Representation Fine-Tuning (ReFT), as first proposed by [Wu et al. (2024)](https://arxiv.org/pdf/2404.03592), leverages so-called interventions to adapt the pre-trained representations of a language model.
Within the context of ReFT, these interventions can intuitively be thought of as adapter modules placed after each Transformer layer.
In its general form, an intervention function $\Phi$ can be defined as follows:

$$
\Phi(h) = h + R^T (W h + b - R h)
$$

Here, $R \in \mathbb{R}^{r \times d}$ and $W \in \mathbb{R}^{r \times d}$ are low-rank matrices of rank $r$.
$h$ is the layer output hidden state at a single sequence position, i.e. interventions can be applied independently at each position.
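
To make the formula concrete, the following is a minimal PyTorch sketch of a single intervention module. It is illustrative only: the module name, shapes, and initialization are assumptions, not the implementation used by _Adapters_.

```python
import torch
import torch.nn as nn


class ReftIntervention(nn.Module):
    """Minimal sketch of the general intervention Phi(h) = h + R^T (W h + b - R h)."""

    def __init__(self, d_model: int, r: int):
        super().__init__()
        # R is a low-rank projection of shape (r, d); W h + b is computed by a linear layer.
        self.R = nn.Parameter(torch.randn(r, d_model) / d_model**0.5)
        self.W = nn.Linear(d_model, r)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: hidden states of shape (batch, seq_len, d_model); the update is applied position-wise.
        projected = h @ self.R.T      # R h        -> (batch, seq_len, r)
        source = self.W(h)            # W h + b    -> (batch, seq_len, r)
        return h + (source - projected) @ self.R  # h + R^T (W h + b - R h)
```

For LoReFT (described below), the rows of $R$ would additionally be kept orthonormal, e.g. via an orthogonal parametrization such as `torch.nn.utils.parametrizations.orthogonal`.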

Based on this general form, the ReFT paper proposes multiple instantiations of ReFT methods supported by _Adapters_:

- **LoReFT** enforces orthogonality of rows in $R$. Defined via [`LoReftConfig`](adapters.LoReftConfig) or via the `orthogonality` attribute as in the following example:
```python
config = ReftConfig(
layers="all", prefix_positions=3, suffix_positions=0, r=1, orthogonality=True
) # equivalent to LoReftConfig()
```

- **NoReFT** does not enforce orthogonality in $R$. Defined via [`NoReftConfig`](adapters.NoReftConfig) or equivalently:
```python
config = ReftConfig(
layers="all", prefix_positions=3, suffix_positions=0, r=1, orthogonality=False
) # equivalent to NoReftConfig()
```

- **DiReFT** does not enforce orthogonality in $R$ and additionally removes the subtraction of $R h$ in the intervention. Defined via [`DiReftConfig`](adapters.DiReftConfig) or equivalently:
```python
config = ReftConfig(
layers="all", prefix_positions=3, suffix_positions=0, r=1, orthogonality=False, subtract_projection=False
) # equivalent to DiReftConfig()
```

In addition, _Adapters_ supports configuring multiple hyperparameters tuned in the ReFT paper via `ReftConfig`, including the following (a usage sketch combining them appears after this list):
- `prefix_positions`: the number of prefix positions to intervene on
- `suffix_positions`: the number of suffix positions to intervene on
- `layers`: the layers to intervene on, either `"all"` or a list of layer IDs
- `tied_weights`: whether to tie intervention parameters between prefix and suffix positions
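
For example, the snippet below is a hedged usage sketch that combines these options and attaches a ReFT adapter to a Hugging Face model; the adapter name, model checkpoint, and hyperparameter values are arbitrary choices for illustration.

```python
import adapters
from adapters import ReftConfig
from transformers import AutoModel

model = AutoModel.from_pretrained("roberta-base")
adapters.init(model)  # enable adapter support on a plain Transformers model

config = ReftConfig(
    layers=[8, 9, 10, 11],   # intervene on the last four layers only
    prefix_positions=3,      # first 3 positions of each sequence
    suffix_positions=3,      # last 3 positions of each sequence
    r=4,
    orthogonality=True,
    tied_weights=True,       # share parameters between prefix and suffix interventions
)
model.add_adapter("reft_example", config=config)
model.train_adapter("reft_example")  # freeze the base model, train only the interventions
```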

_Papers:_

* [ReFT: Representation Finetuning for Language Models](https://arxiv.org/pdf/2404.03592) (Wu et al., 2024)
44 changes: 22 additions & 22 deletions docs/model_overview.md
@@ -10,29 +10,29 @@ The table below further shows which model architectures support which adaptation
E.g., for BERT, this means adapters provides a ``BertAdapterModel`` class, but you can also use ``BertModel``, ``BertForSequenceClassification`` etc. together with adapters.
```

| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block | Prompt<br> Tuning |
| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block | Prompt<br> Tuning | ReFT |
| --------------------------------------- | -| - | - | - | - | - | - |- |
| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ |
| [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | |
| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | |
| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [MT5](classes/models/mt5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | ✅ |
| [BERT-Generation](classes/models/bert-generation.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [CLIP](classes/models/clip.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ |
| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DistilBERT](classes/models/distilbert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Electra](classes/models/electra.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | | (*) |
| [GPT-2](classes/models/gpt2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [MT5](classes/models/mt5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |

(*) If the used encoder and decoder model class are supported.

5 changes: 4 additions & 1 deletion docs/overview.md
@@ -58,7 +58,10 @@ Identifiers and configuration classes are explained in more detail in the [next
| `ia3` | `IA3Config()` | [IA³](methods.html#ia-3) |
| `mam` | `MAMConfig()` | [Mix-and-Match Adapters](method_combinations.html#mix-and-match-adapters) |
| `unipelt` | `UniPELTConfig()` | [UniPELT](method_combinations.html#unipelt) |
| `prompt_tuning` | `PromptTuningConfig()` | [Prompt Tuning](methods.html#prompt-tuning)
| `prompt_tuning` | `PromptTuningConfig()` | [Prompt Tuning](methods.html#prompt-tuning) |
| `loreft` | `LoReftConfig()` | [ReFT](methods.html#reft) |
| `noreft` | `NoReftConfig()` | [ReFT](methods.html#reft) |
| `direft` | `DiReftConfig()` | [ReFT](methods.html#reft) |
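
As a brief illustration of how these identifiers are used, either the string or the corresponding config class can be passed to `add_adapter()`. This is a minimal sketch with an arbitrary model checkpoint and adapter names:

```python
import adapters
from transformers import AutoModel

model = AutoModel.from_pretrained("roberta-base")
adapters.init(model)

# The string identifier resolves to the matching config class,
# so these two calls should be equivalent.
model.add_adapter("reft_a", config="loreft")
model.add_adapter("reft_b", config=adapters.LoReftConfig())
```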

## Configuration

8 changes: 8 additions & 0 deletions src/adapters/__init__.py
@@ -45,16 +45,20 @@
"CompacterConfig",
"CompacterPlusPlusConfig",
"ConfigUnion",
"DiReftConfig",
"DoubleSeqBnConfig",
"DoubleSeqBnInvConfig",
"DynamicAdapterFusionConfig",
"IA3Config",
"LoRAConfig",
"LoReftConfig",
"MAMConfig",
"ModelAdaptersConfig",
"NoReftConfig",
"ParBnConfig",
"PrefixTuningConfig",
"PromptTuningConfig",
"ReftConfig",
"SeqBnConfig",
"SeqBnInvConfig",
"StaticAdapterFusionConfig",
@@ -154,16 +158,20 @@
CompacterConfig,
CompacterPlusPlusConfig,
ConfigUnion,
DiReftConfig,
DoubleSeqBnConfig,
DoubleSeqBnInvConfig,
DynamicAdapterFusionConfig,
IA3Config,
LoRAConfig,
LoReftConfig,
MAMConfig,
ModelAdaptersConfig,
NoReftConfig,
ParBnConfig,
PrefixTuningConfig,
PromptTuningConfig,
ReftConfig,
SeqBnConfig,
SeqBnInvConfig,
StaticAdapterFusionConfig,
84 changes: 83 additions & 1 deletion src/adapters/configuration/adapter_config.py
@@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping
from dataclasses import FrozenInstanceError, asdict, dataclass, field, replace
from typing import List, Optional, Union
from typing import List, Literal, Optional, Union

from ..utils import resolve_adapter_config

@@ -86,6 +86,8 @@ def _get_config_class(config_dict):
cls_new = ConfigUnion
elif architecture == "prompt_tuning":
cls_new = PromptTuningConfig
elif architecture == "reft":
cls_new = ReftConfig
else:
cls_new = BnConfig

@@ -497,6 +499,83 @@ class IA3Config(LoRAConfig):
use_gating: bool = False


@dataclass(eq=False)
class ReftConfig(AdapterConfig):
"""
Base class for Representation Fine-Tuning (ReFT) methods proposed in Wu et al. (2024). See https://arxiv.org/pdf/2404.03592.
ReFT methods have in common that they add "interventions" after selected model layers and at selected sequence positions to adapt the representations produced by module outputs.

Args:
layers (Union[Literal["all"], List[int]]): The IDs of the layers where interventions should be added.
If "all", interventions are added after all layers (default).
prefix_positions (int): The number of prefix positions to add interventions to.
suffix_positions (int): The number of suffix positions to add interventions to.
r (int): The rank of the intervention layer.
orthogonality (bool): If True, enforce an orthogonality constraint for the projection matrix.
tied_weights (bool): If True, share intervention parameters between prefix and suffix positions in each layer.
subtract_projection (bool): If True, subtract the projection of the input.
dropout (float): The dropout rate used in the intervention layer.
non_linearity (str): The activation function used in the intervention layer.
"""

layers: Union[Literal["all"], List[int]]
prefix_positions: int
suffix_positions: int
r: int
orthogonality: bool
tied_weights: bool = False
subtract_projection = True
dropout: float = 0.05
non_linearity: Optional[str] = None

architecture: str = "reft"

output_reft: bool = True


@dataclass(eq=False)
class LoReftConfig(ReftConfig):
"""
Low-Rank Linear Subspace ReFT method proposed in Wu et al. (2024). See https://arxiv.org/pdf/2404.03592.
"""

layers: Union[Literal["all"], List[int]] = "all"
prefix_positions: int = 3
suffix_positions: int = 0
r: int = 1
orthogonality: bool = True
tied_weights: bool = False


@dataclass(eq=False)
class NoReftConfig(ReftConfig):
"""
Variation of LoReft without orthogonality constraint.
"""

layers: Union[Literal["all"], List[int]] = "all"
prefix_positions: int = 3
suffix_positions: int = 0
r: int = 1
orthogonality: bool = False
tied_weights: bool = False


@dataclass(eq=False)
class DiReftConfig(ReftConfig):
"""
Variation of LoReft without orthogonality constraint and projection subtraction as proposed in Wu et al. (2024). See https://arxiv.org/pdf/2404.03592.
"""

layers: Union[Literal["all"], List[int]] = "all"
prefix_positions: int = 3
suffix_positions: int = 0
r: int = 1
orthogonality: bool = False
tied_weights: bool = False
subtract_projection = False


class ConfigUnion(AdapterConfig):
"""
Composes multiple adaptation method configurations into one. This class can be used to define complex adaptation
@@ -650,6 +729,9 @@ def __init__(
"prompt_tuning": PromptTuningConfig(),
"lora": LoRAConfig(),
"ia3": IA3Config(),
"loreft": LoReftConfig(),
"noreft": NoReftConfig(),
"direft": DiReftConfig(),
"mam": MAMConfig(),
"unipelt": UniPELTConfig(),
}
4 changes: 4 additions & 0 deletions src/adapters/loading.py
@@ -348,6 +348,7 @@ def filter_func(self, adapter_name):
or ".prefix_tunings.{}.".format(adapter_name) in x
or ".prefix_gates.{}.".format(adapter_name) in x
or ".loras.{}.".format(adapter_name) in x
or ".refts.{}.".format(adapter_name) in x
or ".prompt_tunings.{}.".format(adapter_name) in x
)

@@ -393,6 +394,7 @@ def rename_func(self, old_name, new_name):
.replace(".prefix_tunings.{}.".format(old_name), ".prefix_tunings.{}.".format(new_name))
.replace(".prefix_gates.{}.".format(old_name), ".prefix_gates.{}.".format(new_name))
.replace(".loras.{}.".format(old_name), ".loras.{}.".format(new_name))
.replace(".refts.{}.".format(old_name), ".refts.{}.".format(new_name))
)

def save_to_state_dict(self, name: str):
@@ -446,6 +448,8 @@ def save(self, save_directory, name, meta_dict=None):

adapter_config = self.model.adapters_config.get(name)

self.model.apply_to_adapter_layers(lambda _, layer: layer.pre_save_adapters())

config_dict = build_full_config(
adapter_config,
self.model.config,