Skip to content

Commit

Permalink
Fix review issues
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt committed Jun 8, 2024
1 parent 6c19ea0 commit f94dbeb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/adapters/methods/adapter_layer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ def freeze_adapter(self, adapter_name: str, freeze: bool = True):
adapter_name (str): The name of the adapter to freeze/ unfreeze.
freeze (bool, optional): Whether to freeze the adapter. Defaults to True.
"""
if adapter_name in self.refts:
self.refts[adapter_name].train(not freeze)
for param in self.refts[adapter_name].parameters():
if adapter_name in self.adapter_modules:
self.adapter_modules[adapter_name].train(not freeze)
for param in self.adapter_modules[adapter_name].parameters():
param.requires_grad = not freeze

def get_adapter(self, adapter_name: str) -> nn.Module:
Expand Down
3 changes: 2 additions & 1 deletion tests/methods/test_reft.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from adapters import DiReftConfig, LoReftConfig
from adapters import DiReftConfig, LoReftConfig, NoReftConfig
from transformers.testing_utils import require_torch

from .base import AdapterMethodBaseTestMixin
Expand All @@ -8,6 +8,7 @@
class ReftTestMixin(AdapterMethodBaseTestMixin):
reft_configs_to_test = [
(LoReftConfig(), ["refts.{name}."]),
(NoReftConfig(prefix_positions=2, suffix_positions=2), ["refts.{name}."]),
(DiReftConfig(tied_weights=True), ["refts.{name}."]),
]

Expand Down

0 comments on commit f94dbeb

Please sign in to comment.