Add ReFT (LoReFT, NoReFT, DiReFT) #705
Conversation
@calpt Thanks for the PR! I took a quick look, and it looks promising. Here are two minor questions:
Does this mean if I want untied weights among prefix and suffix, I will create two adapters and set …
Thanks!
@frankaging Thanks for looking over this! Re your questions:

Re 1: In the current implementation, one or two modules per layer will be created depending on whether the `tied_weights` option is set (see adapters/src/adapters/methods/reft.py, lines 50 to 62 in 25f1d1c).
This makes it very easy for a user to tie or not tie weights when adding a single Reft adapter, e.g.:

```python
from adapters import AutoAdapterModel, ReftConfig

model = AutoAdapterModel.from_pretrained("...")
config = ReftConfig(
    layers="all", prefix_positions=1, suffix_positions=1, r=1,
    tied_weights=True,  # set to True or False to share weights
)
model.add_adapter("my_reft", config=config)
model.set_active_adapters("my_reft")
```

Re 2: Currently, the Reft implementation always assumes interventions are added to the residual stream, as you explained, since this is the method proposed in the paper. This is done via a PyTorch hook here: adapters/src/adapters/methods/reft.py, lines 162 to 169 in 25f1d1c.
While no other intervention points are added for now, we can easily extend this with similar hooks for other intervention points where it makes sense to do so. Thanks again for looking over this! Please let us know if you have any suggestions or ideas about what we should add or change for the first version!
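To make the residual-stream hook described above concrete, here is a minimal sketch. It is not the code in `reft.py`: the class and helper names are invented, and for brevity it edits all positions rather than only the prefix/suffix tokens ReFT actually selects.

```python
import torch
from torch import nn

class TinyReftIntervention(nn.Module):
    """LoReFT-style edit h + R^T (W h + b - R h) applied to the residual stream."""

    def __init__(self, hidden_size: int, r: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, r, bias=False)  # projection R onto the edited subspace
        self.source = nn.Linear(hidden_size, r)             # learned source W h + b

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # Project the difference back into the full hidden space and add it to h.
        return hidden + (self.source(hidden) - self.proj(hidden)) @ self.proj.weight

def attach_to_layer(layer: nn.Module, intervention: nn.Module):
    """Register a forward hook that rewrites the layer's hidden states in its output."""

    def hook(module, inputs, output):
        hidden = output[0] if isinstance(output, tuple) else output
        hidden = intervention(hidden)
        return (hidden,) + output[1:] if isinstance(output, tuple) else hidden

    return layer.register_forward_hook(hook)
```

Usage would look like `attach_to_layer(model.roberta.encoder.layer[0], TinyReftIntervention(768, 8))` for a roberta-base-sized model (the attribute path is an assumption about the model class).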
@calpt Thanks for your responses! It makes sense to me. Will the hook work out of the box for accelerated training (e.g., DeepSpeed, etc.)? Any existing tests for this? Thanks!
@frankaging No extensive tests for training at this point. DeepSpeed support in this library is unfortunately flaky in general and not really a focus at the moment, but using e.g. torch distributed or HF Accelerate should work.
This looks great! Just some small comments and questions.
src/adapters/model_mixin.py (outdated)

```
@@ -968,6 +972,17 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):
    if hasattr(self.base_model, "prefix_tuning"):
        context.prefix_states = self.base_model.prefix_tuning(*args, **kwargs)

    # TODO this does not support padding on the left
```
Do we want to leave this TODO open?
Added, please check if this is correct.
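For context on what left-padding support involves: with left padding, the first real token of each sequence no longer sits at index 0, so prefix/suffix positions have to be derived from the attention mask. A hypothetical helper (not the code added in this PR) could look like this:

```python
import torch

def reft_positions(attention_mask: torch.Tensor, prefix: int, suffix: int):
    """Return (batch, prefix) and (batch, suffix) index tensors for left-padded input.

    attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding,
    where all padding sits on the left.
    """
    batch_size, seq_len = attention_mask.shape
    device = attention_mask.device
    # Index of the first real token per sequence (argmax returns the first maximal entry).
    first_token = attention_mask.argmax(dim=-1)
    prefix_idx = first_token.unsqueeze(-1) + torch.arange(prefix, device=device)
    # With left padding, the last `suffix` positions are always real tokens.
    suffix_idx = torch.arange(seq_len - suffix, seq_len, device=device).expand(batch_size, suffix)
    return prefix_idx, suffix_idx
```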
Commented on some minor things; everything else looks good & correctly implemented to me. Once the open comments are resolved & left padding is implemented, this is ready to merge.
Looks good to me. This is ready to merge.
Thanks! I've added some quick training results on GLUE tasks in the description; based on those, the implementation looks good. @frankaging re distributed training: I've verified that it works with torch distributed & HF Accelerate via the Trainer class, e.g. for GLUE.
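A minimal sketch of such a Trainer-based run, for orientation only: `AdapterTrainer` and `LoReftConfig` come from this library/PR, but the script name, dataset handling, and hyperparameters below are illustrative assumptions, not the exact setup behind the reported numbers.

```python
# train_reft.py -- illustrative only; launch with e.g.
#   torchrun --nproc_per_node=4 train_reft.py
# or
#   accelerate launch train_reft.py
# and the Trainer picks up the distributed environment automatically.
from datasets import load_dataset
from transformers import AutoTokenizer, TrainingArguments
from adapters import AdapterTrainer, AutoAdapterModel, LoReftConfig

model = AutoAdapterModel.from_pretrained("roberta-base")
model.add_classification_head("rte", num_labels=2)
model.add_adapter("rte", config=LoReftConfig())  # assumes the default preset works out of the box
model.train_adapter("rte")  # freeze the base model, train only the ReFT parameters

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
dataset = load_dataset("glue", "rte").map(
    lambda batch: tokenizer(
        batch["sentence1"], batch["sentence2"],
        truncation=True, padding="max_length", max_length=128,
    ),
    batched=True,
).rename_column("label", "labels")

trainer = AdapterTrainer(
    model=model,
    args=TrainingArguments(output_dir="reft_rte", per_device_train_batch_size=32, num_train_epochs=10),
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)
trainer.train()
```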
@calpt Thanks! It's great to see this approach works for different kinds of parallel training! I was re-looking into that orthogonal matrix initialization question (i.e., I was referencing the PEFT repo ticket and asking whether we should remove a redundant init), and I found that in some cases, removing that init step might cause unstable results. Have you looked into this again by doing some tests on your side? Thanks.
Interesting, I haven't tested this specifically. From looking at the code, it makes sense that the orthogonal init is redundant; would you still suggest we re-add it?
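For readers following the init discussion: the redundancy argument assumes the low-rank projection is wrapped in PyTorch's orthogonal parametrization, which re-orthogonalizes the weight on every access, so the effective projection is orthonormal regardless of how the raw weight was initialized. A quick illustration under that assumption (dimensions are arbitrary):

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

# Wrap a low-rank projection in the orthogonal parametrization; no explicit
# orthogonal init is applied to the underlying weight.
proj = orthogonal(nn.Linear(768, 8, bias=False))

R = proj.weight  # the effective weight, re-materialized through the parametrization
print(torch.allclose(R @ R.T, torch.eye(8), atol=1e-5))  # True: rows are orthonormal
```

If instability shows up only when the explicit init is removed, the difference would come from where the parametrization starts, which is worth a targeted test.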
This PR integrates multiple ReFT variants as new adapter methods.

Paper: https://arxiv.org/pdf/2404.03592
Original code: https://github.com/stanfordnlp/pyreft

## Changes

- Add ReFT module implementation via `ReftLayer`, integrated into all models supported by Adapters. Integration via `init_reft()` method & PyTorch hook.
- Add new `ReftConfig` as base config class with three default instances: `LoReftConfig`, `NoReftConfig` and `DiReftConfig`.
- Method documentation can be found here: https://github.com/adapter-hub/adapters/blob/6c19ea06c143621a735226e477bf772068e55be3/docs/methods.md#reft

## Compatibility

Tested that Pyreft & Adapters produce the same outputs on inference by converting Pyreft checkpoints to Adapters checkpoints (tested settings: LoReft, NoReft, DiReft, weight tying, prefix, suffix, rank, mostly using roberta-base). Script for testing & checkpoint conversion here: https://github.com/calpt/pyreft/blob/main/compatibility.py.

## Evaluation

Roberta-base with LoReFT on GLUE, using hyperparameters similar to the paper:

| Task | Score |
| --- | --- |
| CoLA (Matthews Corr.) | 53.95 |
| MNLI (Acc.) | 83.23 |
| MRPC (F1) | 91.70 |
| QNLI (Acc.) | 90.94 |
| QQP (Acc.) | 86.82 |
| RTE (Acc.) | 76.53 |
| SST-2 (Acc.) | 93.81 |
| STS-B (Spearmanr) | 88.99 |

## Todos

- [x] Modeling implementations
- [x] Add test methods
- [x] Make all checks pass
- [x] Add documentation
- [x] Make sure implementation produces same outputs as original code
- [x] Sanity check training runs
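For reference, a minimal usage sketch of the three default config classes listed under Changes (the checkpoint and adapter names are arbitrary, and it is assumed each config can be instantiated with its defaults):

```python
from adapters import AutoAdapterModel, LoReftConfig, NoReftConfig, DiReftConfig

model = AutoAdapterModel.from_pretrained("roberta-base")

# Each default instance is a preset ReftConfig for one ReFT variant from the paper.
model.add_adapter("loreft_adapter", config=LoReftConfig())
model.add_adapter("noreft_adapter", config=NoReftConfig())
model.add_adapter("direft_adapter", config=DiReftConfig())

# Activate and train one of them; the base model weights stay frozen.
model.train_adapter("loreft_adapter")
model.set_active_adapters("loreft_adapter")
```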