ignore special tokens in evals too
chanind committed Jan 31, 2025
1 parent 8c57675 commit 7593e4b
Showing 3 changed files with 80 additions and 3 deletions.
10 changes: 8 additions & 2 deletions sae_lens/cache_activations_runner.py
@@ -22,14 +22,15 @@
 def _mk_activations_store(
     model: HookedRootModule,
     cfg: CacheActivationsRunnerConfig,
+    override_dataset: Dataset | None = None,
 ) -> ActivationsStore:
     """
     Internal method used in CacheActivationsRunner. Used to create a cached dataset
     from a ActivationsStore.
     """
     return ActivationsStore(
         model=model,
-        dataset=cfg.dataset_path,
+        dataset=override_dataset or cfg.dataset_path,
         streaming=cfg.streaming,
         hook_name=cfg.hook_name,
         hook_layer=cfg.hook_layer,
@@ -53,7 +54,11 @@ def _mk_activations_store(


 class CacheActivationsRunner:
-    def __init__(self, cfg: CacheActivationsRunnerConfig):
+    def __init__(
+        self,
+        cfg: CacheActivationsRunnerConfig,
+        override_dataset: Dataset | None = None,
+    ):
         self.cfg = cfg
         self.model: HookedRootModule = load_model(
             model_class_name=self.cfg.model_class_name,
@@ -66,6 +71,7 @@ def __init__(self, cfg: CacheActivationsRunnerConfig):
         self.activations_store = _mk_activations_store(
             self.model,
             self.cfg,
+            override_dataset=override_dataset,
         )
         self.context_size = self._get_sliced_context_size(
             self.cfg.context_size, self.cfg.seqpos_slice
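For context, a minimal usage sketch of the new override_dataset parameter (not part of this commit): it assumes an in-memory Hugging Face Dataset with a "text" column and an already-constructed CacheActivationsRunnerConfig; the helper name and example text are placeholders.

from datasets import Dataset

from sae_lens.cache_activations_runner import (
    CacheActivationsRunner,
    CacheActivationsRunnerConfig,
)


def cache_from_in_memory_dataset(cfg: CacheActivationsRunnerConfig) -> Dataset:
    # Build a small in-memory dataset instead of relying on cfg.dataset_path;
    # override_dataset takes precedence (dataset=override_dataset or cfg.dataset_path).
    dataset = Dataset.from_list([{"text": "some example text to tokenize and cache"}])
    runner = CacheActivationsRunner(cfg, override_dataset=dataset)
    return runner.run()  # returns a datasets.Dataset of cached activations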
6 changes: 6 additions & 0 deletions sae_lens/training/sae_trainer.py
@@ -332,11 +332,17 @@ def _run_and_log_evals(self):
             self.cfg.wandb_log_frequency * self.cfg.eval_every_n_wandb_logs
         ) == 0:
             self.sae.eval()
+            ignore_tokens = set()
+            if self.activations_store.exclude_special_tokens is not None:
+                ignore_tokens = set(
+                    self.activations_store.exclude_special_tokens.tolist()
+                )
             eval_metrics, _ = run_evals(
                 sae=self.sae,
                 activation_store=self.activations_store,
                 model=self.model,
                 eval_config=self.trainer_eval_config,
+                ignore_tokens=ignore_tokens,
                 model_kwargs=self.cfg.model_kwargs,
             ) # not calculating featurwise metrics here.

Codecov / codecov/patch: added lines sae_lens/training/sae_trainer.py#L335 and #L337 were not covered by tests.
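The new ignore_tokens argument is passed straight into run_evals. As a minimal sketch of the intent, assuming per-token metric values and matching token ids (this is not the library's run_evals implementation; the function name and shapes are illustrative):

import torch


def masked_mean(
    per_token_values: torch.Tensor,  # e.g. per-token reconstruction error, shape [batch, seq]
    token_ids: torch.Tensor,  # token ids at the same positions, shape [batch, seq]
    ignore_tokens: set[int],  # special-token ids (e.g. BOS/EOS/PAD) to exclude
) -> torch.Tensor:
    # Average only over positions whose token id is not in ignore_tokens,
    # mirroring how excluded special tokens should not affect eval metrics.
    if not ignore_tokens:
        return per_token_values.mean()
    ignore = torch.tensor(sorted(ignore_tokens), device=token_ids.device)
    keep = ~torch.isin(token_ids, ignore)
    return per_token_values[keep].mean()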
67 changes: 66 additions & 1 deletion tests/unit/training/test_cache_activations_runner.py
@@ -27,6 +27,7 @@ def _default_cfg(
     context_size: int = 8,
     dataset_num_rows: int = 128,
     n_buffers: int = 4,
+    shuffle: bool = False,
     **kwargs: Any,
 ) -> CacheActivationsRunnerConfig:
     d_in = 512
@@ -64,7 +65,7 @@
         context_size=context_size,
         ###
         d_in=d_in,
-        shuffle=False,
+        shuffle=shuffle,
         prepend_bos=False,
         device=device,
         seed=42,
@@ -357,3 +358,67 @@ def test_cache_activations_runner_stores_token_ids(tmp_path: Path):
     assert "token_ids" in dataset.features
     assert dataset["token_ids"].shape[1] == cfg.context_size  # type: ignore
     assert dataset["blocks.0.hook_mlp_out"].shape[:2] == dataset["token_ids"].shape  # type: ignore
+
+
+def test_cache_activations_runner_shuffling(tmp_path: Path):
+    """Test that when shuffle=True, activations and token IDs remain aligned after shuffling."""
+    # Create test dataset with arbitrary unique tokens
+    tokenizer = HookedTransformer.from_pretrained("gelu-1l").tokenizer
+    text = "".join(
+        [
+            " " + word[1:]
+            for word in tokenizer.vocab  # type: ignore
+            if word[0] == "Ġ" and word[1:].isascii() and word.isalnum()
+        ]
+    )
+    dataset = Dataset.from_list([{"text": text}])
+
+    # Create configs for unshuffled and shuffled versions
+    base_cfg = _default_cfg(
+        tmp_path / "base",
+        context_size=3,
+        batch_size=2,
+        dataset_num_rows=8,
+        shuffle=False,
+    )
+    shuffle_cfg = _default_cfg(
+        tmp_path / "shuffled",
+        context_size=3,
+        batch_size=2,
+        dataset_num_rows=8,
+        shuffle=True,
+    )
+
+    # Get unshuffled dataset
+    unshuffled_runner = CacheActivationsRunner(base_cfg, override_dataset=dataset)
+    unshuffled_ds = unshuffled_runner.run()
+    unshuffled_ds.set_format("torch")
+
+    # Get shuffled dataset
+    shuffled_runner = CacheActivationsRunner(shuffle_cfg, override_dataset=dataset)
+    shuffled_ds = shuffled_runner.run()
+    shuffled_ds.set_format("torch")
+
+    # Get activations and tokens
+    hook_name = base_cfg.hook_name
+    unshuffled_acts: torch.Tensor = unshuffled_ds[hook_name]  # type: ignore
+    unshuffled_tokens: torch.Tensor = unshuffled_ds["token_ids"]  # type: ignore
+    shuffled_acts: torch.Tensor = shuffled_ds[hook_name]  # type: ignore
+    shuffled_tokens: torch.Tensor = shuffled_ds["token_ids"]  # type: ignore
+
+    # Verify shapes are preserved
+    assert unshuffled_acts.shape == shuffled_acts.shape
+    assert unshuffled_tokens.shape == shuffled_tokens.shape
+
+    # Verify data is actually shuffled
+    assert not (unshuffled_acts == shuffled_acts).all()
+    assert not (unshuffled_tokens == shuffled_tokens).all()
+
+    # For each token in unshuffled, find its position in shuffled
+    # and verify the activations were moved together
+    for i in range(len(unshuffled_tokens)):
+        token = unshuffled_tokens[i]
+        # Find where this token went in shuffled version
+        shuffled_idx = torch.where(shuffled_tokens == token)[0][0]
+        # Verify activations moved with it
+        assert torch.allclose(unshuffled_acts[i], shuffled_acts[shuffled_idx])