Merge pull request #848 from TransformerLensOrg/dev

Release 2.13.0

bryce13950 authored Feb 5, 2025
2 parents db0f191 + 0c78adb commit 53dee84
Showing 14 changed files with 872 additions and 385 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/checks.yml
@@ -70,7 +70,7 @@ jobs:
       - name: Unit Test
         run: make unit-test
         env:
-          HF_TOKEN: ${{ vars.HF_TOKEN }}
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
       - name: Acceptance Test
         run: make acceptance-test
       - name: Build check
@@ -109,11 +109,11 @@ jobs:
       - name: Test Suite with Coverage Report
         run: make coverage-report-test
         env:
-          HF_TOKEN: ${{ vars.HF_TOKEN }}
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
       - name: Build check
         run: poetry build
       - name: Upload Coverage Report Artifact
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: test-coverage
           path: htmlcov
@@ -192,16 +192,16 @@ jobs:
       - name: Install dependencies
         run: poetry install --with docs
       - name: Download Test Coverage Artifact
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
         with:
           name: test-coverage
           path: docs/source/_static/coverage
       - name: Build Docs
         run: poetry run build-docs
         env:
-          HF_TOKEN: ${{ vars.HF_TOKEN }}
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
       - name: Upload Docs Artifact
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: documentation
           path: docs/build
@@ -215,7 +215,7 @@
     steps:
       - uses: actions/checkout@v4
       - name: Download Docs Artifact
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
         with:
           name: documentation
           path: docs/build
449 changes: 449 additions & 0 deletions demos/LLaVA.ipynb

Large diffs are not rendered by default.

307 changes: 65 additions & 242 deletions poetry.lock

Large diffs are not rendered by default.

11 changes: 3 additions & 8 deletions pyproject.toml
@@ -29,18 +29,13 @@
 python=">=3.8,<4.0"
 rich=">=12.6.0"
 sentencepiece="*"
-torch=[
-    {platform="!=linux", version=">=1.10,!=2.0,!=2.1.0"}, # Pin >=2.1.1 on Apple devices due to known MPS errors on 2.1.0
-    {platform="linux", version=">=1.10"}, # We can use any torch version on Linux (e.g colab)
-]
+torch=">=2.2,<2.5"
 tqdm=">=4.64.1"
-transformers=[
-    {version=">=4.37", python=">=3.8,<3.9"},
-    {version=">=4.41,<4.42", python=">=3.9,<4"},
-]
+transformers=">=4.43"
 typing-extensions="*"
 wandb=">=0.13.5"
 typeguard = "^4.2"
 transformers-stream-generator = "^0.0.5"

 [tool.poetry.group]
 [tool.poetry.group.dev.dependencies]
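A quick way to check that a local environment satisfies the new flat pins (a minimal sketch, not part of this commit; it assumes the packaging library is installed):

# Verify installed versions against the new constraints: torch>=2.2,<2.5 and transformers>=4.43.
from packaging import version

import torch
import transformers

torch_version = version.parse(torch.__version__.split("+")[0])  # drop local tags such as "+cu121"
assert version.parse("2.2") <= torch_version < version.parse("2.5"), torch.__version__
assert version.parse(transformers.__version__) >= version.parse("4.43"), transformers.__version__
print("Environment satisfies the 2.13.0 dependency pins.")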
45 changes: 34 additions & 11 deletions tests/acceptance/test_hooked_encoder_decoder.py
@@ -122,7 +122,8 @@ def test_relative_attention_bias(our_model, huggingface_model, hello_world_tokens):

     embed_out = huggingface_embed(hello_world_tokens)

-    huggingface_attn_out = huggingface_attn(embed_out)[0]
+    cache_position = torch.arange(input_len)
+    huggingface_attn_out = huggingface_attn(embed_out, cache_position=cache_position)[0]
     our_attn_out = our_attn(embed_out, embed_out, embed_out, position_bias=our_bias)

     assert_close(our_attn_out, huggingface_attn_out, rtol=7.4e-4, atol=1e-5)
@@ -139,7 +140,8 @@ def test_relative_attention_layer(our_model, huggingface_model, hello_world_tokens):
     resid_norm = our_block.ln1(resid)
     our_out = resid + our_block.attn(resid_norm, resid_norm, resid_norm, position_bias=our_bias)

-    hf_out = hf_block(resid)[0]
+    cache_position = torch.arange(input_len)
+    hf_out = hf_block(resid, cache_position=cache_position)[0]
     assert_close(our_out, hf_out, rtol=1.3e-6, atol=4e-5)


@@ -151,7 +153,10 @@ def test_attention(our_model, huggingface_model, hello_world_tokens):
     our_attn = our_model.encoder[1].attn

     our_attn_out = our_attn(embed_out, embed_out, embed_out)
-    huggingface_attn_out = huggingface_attn(embed_out)[0]
+
+    input_len = hello_world_tokens.shape[1]
+    cache_position = torch.arange(input_len)
+    huggingface_attn_out = huggingface_attn(embed_out, cache_position=cache_position)[0]

     assert_close(our_attn_out, huggingface_attn_out, rtol=5e-4, atol=1e-5)

@@ -164,7 +169,10 @@ def test_decoder_attention(our_model, huggingface_model, hello_world_tokens):
     our_attn = our_model.decoder[1].attn

     our_attn_out = our_attn(embed_out, embed_out, embed_out)
-    huggingface_attn_out = huggingface_attn(embed_out)[0]
+
+    input_len = hello_world_tokens.shape[1]
+    cache_position = torch.arange(input_len)
+    huggingface_attn_out = huggingface_attn(embed_out, cache_position=cache_position)[0]
     assert_close(our_attn_out, huggingface_attn_out, rtol=3e-4, atol=1e-5)


@@ -177,7 +185,9 @@ def test_attention_layer(our_model, huggingface_model, hello_world_tokens):
     norm_embed = our_model.encoder[1].ln1(embed_out)
     our_attn_out = our_attn(norm_embed, norm_embed, norm_embed) + embed_out

-    huggingface_attn_out = huggingface_attn(embed_out)[0]
+    input_len = hello_world_tokens.shape[1]
+    cache_position = torch.arange(input_len)
+    huggingface_attn_out = huggingface_attn(embed_out, cache_position=cache_position)[0]
     assert_close(our_attn_out, huggingface_attn_out, rtol=2e-4, atol=1e-5)


@@ -190,7 +200,9 @@ def test_decoder_attention_layer(our_model, huggingface_model, hello_world_tokens):
     norm_embed = our_model.decoder[1].ln1(embed_out)
     our_attn_out = our_attn(norm_embed, norm_embed, norm_embed) + embed_out

-    huggingface_attn_out = huggingface_attn(embed_out)[0]
+    input_len = hello_world_tokens.shape[1]
+    cache_position = torch.arange(input_len)
+    huggingface_attn_out = huggingface_attn(embed_out, cache_position=cache_position)[0]
     assert_close(our_attn_out, huggingface_attn_out, rtol=3e-4, atol=4e-5)


@@ -203,7 +215,7 @@ def test_cross_attention(our_model, huggingface_model, hello_world_tokens, decoder_input_ids):

     our_cross_attn_out = our_cross_attn(decoder_hidden, encoder_hidden, encoder_hidden)
     huggingface_cross_attn_out = huggingface_cross_attn(
-        decoder_hidden, key_value_states=encoder_hidden
+        decoder_hidden, key_value_states=encoder_hidden, cache_position=torch.arange(decoder_hidden.shape[1])
     )[0]
     assert_close(our_cross_attn_out, huggingface_cross_attn_out, rtol=2e-4, atol=1e-5)

@@ -221,7 +233,9 @@ def test_cross_attention_layer(our_model, huggingface_model, hello_world_tokens, decoder_input_ids):
     our_cross_attn_out = (
         our_layer.cross_attn(our_layer.ln2(decoder_hidden), encoder_hidden, encoder_hidden)
         + decoder_hidden
     )
-    huggingface_cross_attn_out = hf_layer(decoder_hidden, key_value_states=encoder_hidden)[0]
+    huggingface_cross_attn_out = hf_layer(
+        decoder_hidden, key_value_states=encoder_hidden, cache_position=torch.arange(decoder_hidden.shape[1])
+    )[0]
     assert_close(our_cross_attn_out, huggingface_cross_attn_out, rtol=2e-4, atol=1e-5)


@@ -232,7 +246,9 @@ def test_encoder_block(our_model, huggingface_model, hello_world_tokens):

     embed_out = huggingface_embed(hello_world_tokens)

-    hf_out = huggingface_block(embed_out)[0]
+    input_len = hello_world_tokens.shape[1]
+    cache_position = torch.arange(input_len)
+    hf_out = huggingface_block(embed_out, cache_position=cache_position)[0]
     our_out = our_block(embed_out)

     assert_close(our_out, hf_out, rtol=2e-4, atol=2e-5)
@@ -244,10 +260,17 @@ def test_decoder_block(our_model, huggingface_model, hello_world_tokens, decoder_input_ids):
     our_block = our_model.decoder[1]

     encoder_hidden = huggingface_model.encoder(hello_world_tokens)[0]
-    decoder_hidden = huggingface_model.decoder.block[0](huggingface_embed(decoder_input_ids))[0]
+
+    input_len = decoder_input_ids.shape[1]
+    cache_position = torch.arange(input_len)
+    decoder_hidden = huggingface_model.decoder.block[0](
+        huggingface_embed(decoder_input_ids), cache_position=cache_position
+    )[0]

     our_out = our_block(decoder_hidden, encoder_hidden_states=encoder_hidden)
-    hf_out = huggingface_block(decoder_hidden, encoder_hidden_states=encoder_hidden)[0]
+    hf_out = huggingface_block(
+        decoder_hidden, encoder_hidden_states=encoder_hidden, cache_position=cache_position
+    )[0]

     assert_close(hf_out, our_out, rtol=2e-4, atol=2e-5)
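The changes throughout this file follow one pattern: direct calls into Hugging Face T5 submodules now pass an explicit cache_position tensor, which the transformers versions pinned above expect when a module is invoked outside the full model. A self-contained sketch of the pattern (the tiny config values are illustrative assumptions, not taken from the tests):

import torch
from transformers import T5Config
from transformers.models.t5.modeling_t5 import T5Attention

config = T5Config(d_model=8, d_kv=2, num_heads=4, relative_attention_num_buckets=8)
attn = T5Attention(config, has_relative_attention_bias=True)

hidden = torch.randn(1, 5, config.d_model)      # (batch, seq_len, d_model)
cache_position = torch.arange(hidden.shape[1])  # positions 0..seq_len-1, as in the tests above
attn_out = attn(hidden, cache_position=cache_position)[0]
print(attn_out.shape)  # torch.Size([1, 5, 8])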
74 changes: 74 additions & 0 deletions tests/acceptance/test_hooked_transformer.py
@@ -553,3 +553,77 @@ def test_all_pythia_models_exist():
             f"Could not download model '{model}' from Huggingface."
             " Maybe the name was changed or the model has been removed."
         )
+
+
+@pytest.mark.parametrize(
+    "input_type,return_type",
+    [
+        ("str", "input"),
+        ("str", "str"),
+        ("str", "tokens"),
+        ("str", "embeds"),
+        ("tokens", "input"),
+        ("tokens", "str"),
+        ("tokens", "tokens"),
+        ("tokens", "embeds"),
+        ("embeds", "input"),
+        ("embeds", "str"),
+        ("embeds", "tokens"),
+        ("embeds", "embeds"),
+    ],
+)
+def test_different_inputs_for_generation(
+    input_type, return_type, print_output=False, max_new_tokens=3
+):
+    from typing import List
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    hooked_llm = HookedTransformer.from_pretrained("gpt2", device=device)
+
+    hooked_llm.eval()
+    for text_input in [
+        "What is the meaning of life?",
+        ["AI will destroy world", "AI will save us"],
+    ]:
+        is_batched = not isinstance(text_input, str)
+
+        tokens_input = hooked_llm.to_tokens(text_input)
+        embeddings_input = hooked_llm.embed(tokens_input)
+
+        if input_type == "str":
+            model_input = text_input
+        elif input_type == "tokens":
+            model_input = tokens_input
+        elif input_type == "embeds":
+            model_input = embeddings_input
+        else:
+            raise ValueError(f"Unknown input_type: {input_type}")
+
+        output = hooked_llm.generate(
+            input=model_input, max_new_tokens=max_new_tokens, return_type=return_type, verbose=False
+        )
+
+        if return_type == "str" or (return_type == "input" and input_type == "str"):
+            if is_batched:
+                assert isinstance(output, List), f"Expected list output but got {type(output)}"
+                assert isinstance(
+                    output[0], str
+                ), f"Expected list of strings but got list of {type(output[0])}"
+            else:
+                assert isinstance(output, str), f"Expected string output but got {type(output)}"
+        elif return_type == "tokens" or (return_type == "input" and input_type == "tokens"):
+            assert isinstance(
+                output, torch.Tensor
+            ), f"Expected tensor output but got {type(output)}"
+            assert output.ndim == 2, f"Expected 2D tensor but got {output.ndim}D"
+        elif return_type == "embeds" or (return_type == "input" and input_type == "embeds"):
+            assert isinstance(
+                output, torch.Tensor
+            ), f"Expected tensor output but got {type(output)}"
+            assert output.ndim == 3, f"Expected 3D tensor but got {output.ndim}D"
+
+        if print_output:
+            print(f"Input type: {input_type}, return type: {return_type}, output:\n{output}")
+
+    if print_output:
+        print()
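A minimal usage sketch of the generate API this test exercises (it mirrors the assertions above rather than documenting anything beyond them):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
tokens = model.to_tokens("What is the meaning of life?")  # (1, seq_len) token tensor

text = model.generate(input=tokens, max_new_tokens=3, return_type="str", verbose=False)
embeds = model.generate(input=tokens, max_new_tokens=3, return_type="embeds", verbose=False)

print(text)          # the prompt plus three generated tokens, decoded to a string
print(embeds.shape)  # a 3D tensor, matching the "embeds" assertion above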
4 changes: 3 additions & 1 deletion tests/integration/test_match_huggingface.py
@@ -40,6 +40,8 @@ def test_compare_huggingface_attention_match_local_implementation(self, model_name):
             past_kv_cache_entry=None,
             attention_mask=None,
         )
-        hf_out, _ = hf_model.transformer.h[layer_n].attn(hidden_states=input)
+        hf_out, _, _ = hf_model.transformer.h[layer_n].attn(
+            hidden_states=input, output_attentions=True
+        )

         assert torch.sum(tl_out == hf_out) == math.prod(tl_out.shape)
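The extra underscore reflects that, with output_attentions=True, the Hugging Face GPT-2 attention block returns a third element holding the attention weights. A standalone sketch, assuming a transformers release contemporary with this commit and illustrative config sizes:

import torch
from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

config = GPT2Config(n_embd=8, n_head=2)
attn = GPT2Attention(config)

hidden = torch.randn(1, 4, config.n_embd)
attn_out, present, attn_weights = attn(hidden_states=hidden, output_attentions=True)
# present is None here because use_cache defaults to False.
print(attn_out.shape)      # torch.Size([1, 4, 8])
print(attn_weights.shape)  # torch.Size([1, 2, 4, 4]), i.e. (batch, heads, query, key)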
4 changes: 2 additions & 2 deletions transformer_lens/HookedEncoder.py
@@ -53,10 +53,10 @@ def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
         if tokenizer is not None:
             self.tokenizer = tokenizer
         elif self.cfg.tokenizer_name is not None:
-            huggingface_token = os.environ.get("HF_TOKEN", None)
+            huggingface_token = os.environ.get("HF_TOKEN", "")
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.cfg.tokenizer_name,
-                token=huggingface_token,
+                token=huggingface_token if len(huggingface_token) > 0 else None,
             )
         else:
             self.tokenizer = None
4 changes: 2 additions & 2 deletions transformer_lens/HookedEncoderDecoder.py
@@ -57,10 +57,10 @@ def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
         if tokenizer is not None:
             self.tokenizer = tokenizer
         elif self.cfg.tokenizer_name is not None:
-            huggingface_token = os.environ.get("HF_TOKEN", None)
+            huggingface_token = os.environ.get("HF_TOKEN", "")
             self.tokenizer = AutoTokenizer.from_pretrained(
                 self.cfg.tokenizer_name,
-                token=huggingface_token,
+                token=huggingface_token if len(huggingface_token) > 0 else None,
             )
         else:
             self.tokenizer = None
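Both HookedEncoder and HookedEncoderDecoder now normalize an unset or empty HF_TOKEN to None before it reaches the tokenizer. The same logic in isolation (a sketch; t5-small is an arbitrary example checkpoint):

import os

from transformers import AutoTokenizer

# Falls back to "" when the variable is unset, matching the new default above.
huggingface_token = os.environ.get("HF_TOKEN", "")

# Pass None rather than an empty string so unauthenticated downloads still work.
tokenizer = AutoTokenizer.from_pretrained(
    "t5-small",
    token=huggingface_token if len(huggingface_token) > 0 else None,
)
print(tokenizer("hello world").input_ids)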
The remaining changed files in this commit are not rendered here.