Support for gemma architecture models (#125)
* added support for gemma architecture models

* using the model_surgery.Norm in tests

* fixed mockmodel/gemma-tiny

* fixed failing norm test

* fixed pytorch dependency

* fixed codecov secret
levmckinney authored Jun 2, 2024
1 parent d64f28a commit 550644a
Showing 5 changed files with 28 additions and 7 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pre-merge.yml
@@ -93,6 +93,8 @@ jobs:
pytest --cov=./ --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
with:
directory: ./coverage/reports/
env_vars: OS,PYTHON
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -10,12 +10,12 @@ requires-python = ">=3.9"
keywords = ["nlp", "interpretability", "language-models", "explainable-ai"]
license = {text = "MIT License"}
dependencies = [
"accelerate",
"datasets",
"accelerate>=0.27.0",
"datasets>=2.17.1",
"plotly>=5.13.1",
"torchdata>=0.6.0",
"torch>=1.13.0",
"transformers>=4.28.1",
"torch>=2.0,!=2.3.0",
"transformers>=4.38.1",
"huggingface_hub>=0.16.4",
"simple-parsing>=0.1.4",
"flatten-dict>=0.4.1",
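The updated torch requirement pins out 2.3.0 specifically while still allowing later 2.x releases. A minimal sketch of how that specifier behaves, using the third-party packaging library purely for illustration (it is not a dependency touched by this commit):

from packaging.specifiers import SpecifierSet

# The same constraint string as in pyproject.toml above.
torch_pin = SpecifierSet(">=2.0,!=2.3.0")
print(torch_pin.contains("2.2.2"))  # True: satisfies the lower bound
print(torch_pin.contains("2.3.0"))  # False: the one excluded release
print(torch_pin.contains("2.3.1"))  # True: only 2.3.0 is blocked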
10 changes: 10 additions & 0 deletions tests/conftest.py
@@ -27,6 +27,7 @@ def text_dataset(text_dataset_path: Path) -> Dataset:
"EleutherAI/gpt-neo-125M",
"facebook/opt-125m",
"mockmodel/llama-tiny",
"mockmodel/gemma-tiny",
"gpt2",
],
)
@@ -43,6 +44,15 @@ def random_small_model(request: str) -> tr.PreTrainedModel:
num_hidden_layers=4,
num_attention_heads=4,
)
elif small_model_name == "mockmodel/gemma-tiny":
config = tr.GemmaConfig(
vocab_size=32_000,
hidden_size=128,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=4,
head_dim=32,
)
else:
config = tr.AutoConfig.from_pretrained(small_model_name)

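For context, a config like the one added above can be turned into a small randomly initialized Gemma model without downloading any weights. This is a rough sketch of what the fixture effectively does, not its exact code, and it assumes transformers>=4.38 (hence the bump in pyproject.toml):

import transformers as tr

config = tr.GemmaConfig(
    vocab_size=32_000,
    hidden_size=128,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=4,
    head_dim=32,
)
# Build the model from the config alone, so the weights are random and no
# checkpoint is fetched from the Hub.
model = tr.AutoModelForCausalLM.from_config(config)
print(sum(p.numel() for p in model.parameters()))  # total parameter count of the toy model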
4 changes: 2 additions & 2 deletions tests/test_model_surgery.py
@@ -1,6 +1,6 @@
import pytest
import torch as th
from transformers import PreTrainedModel, models
from transformers import PreTrainedModel

from tuned_lens import model_surgery

@@ -13,7 +13,7 @@ def test_get_final_layer_norm_raises(opt_random_model: PreTrainedModel):

def test_get_final_layer_norm(random_small_model: PreTrainedModel):
ln = model_surgery.get_final_norm(random_small_model)
assert isinstance(ln, (th.nn.LayerNorm, models.llama.modeling_llama.LlamaRMSNorm))
assert any(isinstance(ln, Norm) for Norm in model_surgery.Norm.__args__)


def test_get_layers_from_model(random_small_model: PreTrainedModel):
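The rewritten assertion leans on the fact that a typing.Union exposes its member types through __args__, so the test keeps passing whenever a new norm class (such as GemmaRMSNorm) is added to model_surgery.Norm. A self-contained illustration of the pattern, using a stand-in Union rather than the real one:

from typing import Union

import torch as th

# Stand-in for model_surgery.Norm; the real Union also includes the
# LlamaRMSNorm and GemmaRMSNorm classes from transformers.
Norm = Union[th.nn.LayerNorm, th.nn.GroupNorm]

ln = th.nn.LayerNorm(8)
# Check the instance against every member of the Union, exactly as the
# updated test does.
assert any(isinstance(ln, NormCls) for NormCls in Norm.__args__)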
11 changes: 10 additions & 1 deletion tuned_lens/model_surgery.py
@@ -63,7 +63,12 @@ def assign_key_path(model: T, key_path: str, value: Any) -> Generator[T, None, N


Model = Union[tr.PreTrainedModel, "tl.HookedTransformer"]
Norm = Union[th.nn.LayerNorm, models.llama.modeling_llama.LlamaRMSNorm, nn.Module]
Norm = Union[
th.nn.LayerNorm,
models.llama.modeling_llama.LlamaRMSNorm,
models.gemma.modeling_gemma.GemmaRMSNorm,
nn.Module,
]


def get_unembedding_matrix(model: Model) -> nn.Linear:
@@ -114,6 +119,8 @@ def get_final_norm(model: Model) -> Norm:
final_layer_norm = base_model.ln_f
elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
final_layer_norm = base_model.norm
elif isinstance(base_model, models.gemma.modeling_gemma.GemmaModel):
final_layer_norm = base_model.norm
else:
raise NotImplementedError(f"Unknown model type {type(base_model)}")

@@ -159,6 +166,8 @@ def get_transformer_layers(model: Model) -> tuple[str, th.nn.ModuleList]:
path_to_layers += ["h"]
elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
path_to_layers += ["layers"]
elif isinstance(base_model, models.gemma.modeling_gemma.GemmaModel):
path_to_layers += ["layers"]
else:
raise NotImplementedError(f"Unknown model type {type(base_model)}")

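Taken together, the two new isinstance branches mean both entry points now handle Gemma models. A rough usage sketch, assuming a tiny random Gemma model like the one in tests/conftest.py (exact printed values depend on the installed transformers version):

import transformers as tr

from tuned_lens import model_surgery

config = tr.GemmaConfig(
    vocab_size=32_000,
    hidden_size=128,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=4,
    head_dim=32,
)
model = tr.AutoModelForCausalLM.from_config(config)

norm = model_surgery.get_final_norm(model)                  # a GemmaRMSNorm module
path, layers = model_surgery.get_transformer_layers(model)  # key path string + ModuleList
print(type(norm).__name__, path, len(layers))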
