
Use separate (larger) dataset for gram (and mean) matrices (#344)
## Description
Allow the gram matrix computation to use a separate dataset from the one used for the RIB basis (Cs) computation.

I also allow using an already-tokenized dataset rather than an untokenized one, to skip the (kinda slow) tokenization.

Also added an option to store the computed gram matrix to a file!
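
For illustration, here is roughly what such a config could look like, written as a Python dict (actual configs are YAML; the field values below are made up, and only `gram_dataset` and its fallback behaviour come from this PR):

```python
# Hypothetical RIB build config. Only the `gram_dataset` field and its
# default-to-`dataset` behaviour are from this PR; names/values are made up.
config = {
    "dataset": {
        # Used for the RIB basis (Cs) computation.
        "name": "roneneldan/TinyStories",
        "n_samples": 10_000,
    },
    "gram_dataset": {
        # Optional; omit to reuse `dataset` for the gram (and mean) matrices.
        # A pre-tokenized dataset skips the slow tokenization step.
        "name": "apollo-research/sae-skeskinen-TinyStories-hf-tokenizer-gpt2",
        "n_samples": 100_000,  # gram matrices are cheap, so use more samples
    },
}
```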

## Motivation and Context
We noticed that the gram (PCA) computation is a lot more sensitive to the number of samples than the RIB basis computation is, and is also a lot cheaper to run.

## How Has This Been Tested?
Did runs and made scaling plots. Added a test making sure this config option runs.

## Does this PR introduce a breaking change?
No. Not giving a `gram_dataset` defaults to using the same dataset as for the Cs.
stefan-apollo authored Mar 8, 2024
1 parent 65dc3d2 commit 4613bad
Showing 21 changed files with 2,723 additions and 133 deletions.
108 changes: 80 additions & 28 deletions rib/loader.py
@@ -8,9 +8,12 @@
import yaml
from datasets import load_dataset as hf_load_dataset
from torch.utils.data import Dataset, TensorDataset
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformers import AutoTokenizer

from rib.log import logger

if TYPE_CHECKING:
    from rib.rib_builder import RibBuildConfig

@@ -214,7 +217,7 @@ def create_modular_arithmetic_dataset(
    return dataset_subset


def tokenize_dataset(
def prepare_dataset(
    dataset: Dataset,
    tokenizer: AutoTokenizer,
    n_ctx: int,
@@ -241,23 +244,44 @@ def tokenize_dataset(
        TensorDataset: The tokenized dataset.
    """
    assert tokenizer.eos_token_id is not None, "Tokenizer must have an eos token id"
    # May already be tokenized. Use try/except because not all datasets (in particular those in
    # TestTokenizeDataset) have the features attribute.
    try:
        tokenized = "input_ids" in dataset.features.keys()  # type: ignore
    except AttributeError:
        tokenized = False
    # Tokenize all samples and merge them into one long list of tokens
    all_tokens = []
    for example in dataset:  # type: ignore
        tokens = tokenizer(example["text"])["input_ids"]
        # Add the eos token to the end of each sample as was done in the original training
        # https://github.com/EleutherAI/pythia/issues/123#issuecomment-1791136253
        all_tokens.extend(tokens + [tokenizer.eos_token_id])

    # There shouldn't be any padding tokens, so ensure that there are len(dataset) eos tokens
    all_tokens: list[int] = []
    for example in tqdm(dataset, desc="Processing dataset", total=len(dataset)):  # type: ignore
        if not tokenized:
            tokens = tokenizer(example["text"])["input_ids"]
            # Add the eos token to the end of each sample as was done in the original training
            # https://github.com/EleutherAI/pythia/issues/123#issuecomment-1791136253
            all_tokens.extend(tokens + [tokenizer.eos_token_id])
        else:
            # Time: This loop can take around 2min for 450M tokens (99% TinyStories)
            tokens = example["input_ids"]
            all_tokens.extend(tokens)
    # Check number of eos tokens
    len_dataset = len(dataset)  # type: ignore
    assert all_tokens.count(tokenizer.eos_token_id) == len_dataset, (
        f"Number of eos tokens ({all_tokens.count(tokenizer.eos_token_id)}) does not match "
        f"number of samples ({len_dataset})."
    )
    if not tokenized:
        n_eos_tokens = all_tokens.count(tokenizer.eos_token_id)
        # When we tokenize the dataset ourselves, len_dataset is the number of actual stories
        # and thus the number of eos tokens should match len_dataset.
        assert n_eos_tokens == len_dataset, (
            f"Number of eos tokens ({n_eos_tokens}) does not match "
            f"number of samples ({len_dataset})."
        )
    # Note: You really should check the tokenized dataset you used, and make sure it behaves
    # correctly. We can't really assert this kind of thing.
    # TinyStories (apollo-research/sae-skeskinen-TinyStories-hf-tokenizer-gpt2): Confirmed with
    # tokenizer.decode(all_tokens) and tokenizer.decode(tokens) that the chunks that get
    # concatenated are indeed follow-ups of each other. <eos> tokens are correctly placed at the
    # end of each story.

    # Split the merged tokens into chunks that fit the context length
    # Time: This line can take around 45s for 450M tokens (99% TinyStories)
    raw_chunks = [all_tokens[i : i + n_ctx] for i in range(0, len(all_tokens), n_ctx)]

    # Note that we ignore the final raw_chunk, as we get the label for the final token in a chunk
    # from the subsequent chunk.
    n_raw_chunks = len(raw_chunks) - 1
@@ -275,6 +299,7 @@ def tokenize_dataset(

    chunks = [raw_chunks[i] for i in chunk_idxs]

    # Time: This loop can take around 30s for 450M tokens (99% TinyStories)
    all_labels: list[list[int]] = []
    for i, chunk in enumerate(chunks):
        # Get the label for the last token using the next chunk in raw_chunks
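
For reference, a minimal self-contained sketch of the chunking-and-labels scheme this code implements (toy numbers, and ignoring the random sampling via `chunk_idxs`):

```python
# Toy version of the chunk/label construction: split the merged token stream
# into n_ctx-sized chunks, and take each chunk's final label from the first
# token of the following raw chunk.
all_tokens = list(range(10))  # stand-in for the merged token stream
n_ctx = 3

raw_chunks = [all_tokens[i : i + n_ctx] for i in range(0, len(all_tokens), n_ctx)]
# raw_chunks == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

# Ignore the final raw chunk; it only supplies the label for its predecessor.
chunks = raw_chunks[:-1]
all_labels = [chunk[1:] + [raw_chunks[i + 1][0]] for i, chunk in enumerate(chunks)]
# all_labels == [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
```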
@@ -347,17 +372,21 @@ def create_hf_dataset(
    # Sample n_samples from all documents in return_set
    data_split = dataset_config.return_set

logger.info(f"Loading HuggingFace dataset {dataset_config.name} split {data_split}")
raw_dataset = hf_load_dataset(dataset_config.name, split=data_split)
logger.info(f"Loaded {len(raw_dataset)} documents from HuggingFace dataset")

    tokenizer = AutoTokenizer.from_pretrained(dataset_config.tokenizer_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenized_dataset = tokenize_dataset(
    logger.info(f"Tokenizing HuggingFace dataset with tokenizer {dataset_config.tokenizer_name}")
    tokenized_dataset = prepare_dataset(
        dataset=raw_dataset,
        tokenizer=tokenizer,
        n_ctx=n_ctx,
        n_samples=dataset_config.n_samples,
        seed=dataset_config.seed,
    )
    logger.info(f"Tokenized {len(tokenized_dataset)} samples from HuggingFace dataset")
    return tokenized_dataset
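
As a usage sketch: passing an already-tokenized dataset (one with an `input_ids` column) through `prepare_dataset` skips re-tokenization. The dataset name below comes from the comments in this diff; the remaining argument values are illustrative:

```python
from datasets import load_dataset as hf_load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Pre-tokenized TinyStories; prepare_dataset detects the `input_ids` column
# and skips the slow tokenization step.
pre_tokenized = hf_load_dataset(
    "apollo-research/sae-skeskinen-TinyStories-hf-tokenizer-gpt2", split="train"
)
dataset = prepare_dataset(
    dataset=pre_tokenized,
    tokenizer=tokenizer,
    n_ctx=1024,
    n_samples=1_000,
    seed=0,
)
```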


@@ -425,31 +454,25 @@ def load_dataset(
    return create_block_vector_dataset(dataset_config=dataset_config)


def load_model_and_dataset_from_rib_config(
def load_model(
    rib_config: "RibBuildConfig",
    device: str,
    dtype: torch.dtype,
    dataset_config: Optional[DatasetConfig] = None,
    node_layers: Optional[list[str]] = None,
) -> Tuple[Union[SequentialTransformer, MLP], Dataset]:
    """Loads the model and dataset for a rib build based on the config.
    Combines both model and dataset loading in one function as the dataset conditionally needs
    extra arguments depending on the dataset type.
) -> Union[SequentialTransformer, MLP]:
    """Loads the model for a rib build based on the config.
    Args:
        rib_config (RibBuildConfig): The rib build config.
        device (str): The device to use for the model.
        dtype (torch.dtype): The dtype to use for the model.
        dataset_config (Optional[DatasetConfig]): The dataset config to use. If None, uses the
            dataset config from the rib_config.
        node_layers (Optional[list[str]]): The node layers to use for the model. If None, uses the
            node layers from the rib_config. Note that changing the sections in the model has no
            effect on the model computation, so we allow specifying any node_layers for the
            convenience of hooking different sections of the model.
    Returns:
        tuple[Union[SequentialTransformer, MLP], Dataset]: The model and dataset.
        Union[SequentialTransformer, MLP]: The model.
    """
    model: Union[SequentialTransformer, MLP]
    if rib_config.mlp_path is not None or rib_config.modular_mlp_config is not None:
@@ -481,12 +504,41 @@ def load_model_and_dataset_from_rib_config(
            device=device,
        )
    model.eval()
    dataset_config = dataset_config or rib_config.dataset
    return model


def load_model_and_dataset_from_rib_config(
    rib_config: "RibBuildConfig",
    device: str,
    dtype: torch.dtype,
    dataset_config: Optional[DatasetConfig] = None,
    node_layers: Optional[list[str]] = None,
) -> Tuple[Union[SequentialTransformer, MLP], Dataset]:
    """Loads the model and dataset for a rib build based on the config.
    Combines both model and dataset loading in one function as the dataset conditionally needs
    extra arguments depending on the dataset type.
    Args:
        rib_config (RibBuildConfig): The rib build config.
        device (str): The device to use for the model.
        dtype (torch.dtype): The dtype to use for the model.
        dataset_config (Optional[DatasetConfig]): The dataset config to use. If None, uses the
            dataset config from the rib_config.
        node_layers (Optional[list[str]]): The node layers to use for the model. If None, uses the
            node layers from the rib_config. Note that changing the sections in the model has no
            effect on the model computation, so we allow specifying any node_layers for the
            convenience of hooking different sections of the model.
    Returns:
        tuple[Union[SequentialTransformer, MLP], Dataset]: The model and dataset.
    """
    dataset_config = dataset_config or rib_config.dataset
    assert dataset_config is not None, "dataset_config or rib_config.dataset must be provided"
    model = load_model(rib_config, device, dtype, node_layers)
    dataset = load_dataset(
        dataset_config=dataset_config,
        dataset_config,
        model_n_ctx=model.cfg.n_ctx if isinstance(model, SequentialTransformer) else None,
        tlens_model_path=rib_config.tlens_model_path,
    )

    return model, dataset
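
Taken together, the split means callers can load the model once and pair it with different datasets, e.g. a larger one for the gram matrices. A sketch of that composition (the `gram_dataset` config field is from this PR's description; the rest is illustrative):

```python
import torch

model = load_model(rib_config, device="cuda", dtype=torch.float64)
n_ctx = model.cfg.n_ctx if isinstance(model, SequentialTransformer) else None

basis_dataset = load_dataset(
    rib_config.dataset,
    model_n_ctx=n_ctx,
    tlens_model_path=rib_config.tlens_model_path,
)
# Fall back to the basis dataset when no separate gram dataset is configured.
gram_config = rib_config.gram_dataset or rib_config.dataset  # assumed field
gram_dataset = load_dataset(
    gram_config,
    model_n_ctx=n_ctx,
    tlens_model_path=rib_config.tlens_model_path,
)
```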