Commit

Positional embeddings are taken into account when generating, make separate functions get_pos_offset and get_residual

zazamrykh committed Dec 26, 2024
1 parent bbd664c commit 3466104
Showing 3 changed files with 117 additions and 60 deletions.
173 changes: 114 additions & 59 deletions transformer_lens/HookedTransformer.py
@@ -255,6 +255,91 @@ def check_hooks_to_add(
self.cfg.use_attn_in
), f"Cannot add hook {hook_point_name} if use_attn_in is False"

def get_pos_offset(self, past_kv_cache, batch_size):
# If we're doing caching, then we reuse keys and values from previous runs, as that's the
# only way that past activations will affect the final logits. The cache contains those so
# we don't need to recompute them. This is useful for generating text. As we have absolute
# positional encodings, to implement this we have a `pos_offset` variable, defaulting to
# zero, which says to offset which positional encodings are used (cached keys and values
# were calculated with their own positional encodings).
if past_kv_cache is None:
pos_offset = 0
else:
(
cached_batch_size,
cache_ctx_length,
num_heads_in_cache,
d_head_in_cache,
) = past_kv_cache[0].past_keys.shape
assert cached_batch_size == batch_size
if self.cfg.n_key_value_heads is None:
assert num_heads_in_cache == self.cfg.n_heads
else:
assert num_heads_in_cache == self.cfg.n_key_value_heads
assert d_head_in_cache == self.cfg.d_head
pos_offset = cache_ctx_length
return pos_offset
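
Usage sketch (not part of this diff): how the new get_pos_offset helper behaves with and without a key-value cache. It assumes a loaded HookedTransformer called `model` (GPT-2 chosen arbitrarily) and TransformerLens's HookedTransformerKeyValueCache; with no cache the offset is 0, and once keys/values have been cached it equals the cached context length.

from transformer_lens import HookedTransformer
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache

model = HookedTransformer.from_pretrained("gpt2")
tokens = model.to_tokens("Hello world")  # [batch, pos]
batch_size = tokens.shape[0]

# No cache: positional encodings start from position 0.
assert model.get_pos_offset(None, batch_size) == 0

# Prime a cache with one forward pass; the offset then equals the cached context length.
cache = HookedTransformerKeyValueCache.init_cache(model.cfg, model.cfg.device, batch_size)
model(tokens, past_kv_cache=cache)
assert model.get_pos_offset(cache, batch_size) == tokens.shape[1]
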

def get_residual(
self,
embed,
pos_offset,
prepend_bos=USE_DEFAULT_VALUE,
attention_mask=None,
tokens=None,
return_shortformer_pos_embed=True,
device=None,
):
if device is None:
device = devices.get_device_for_block_index(0, self.cfg)

if tokens is None:
# Because tokens are only needed to define the batch size and sequence length, we can simply synthesize them
tokens = torch.ones((embed.size(0), embed.size(1))).int().to(device)

if attention_mask is None:
# If the padding side is left or we are using caching, we need to compute the attention
# mask for the adjustment of absolute positional embeddings and attention masking so
# that pad tokens are not attended to.
if prepend_bos is USE_DEFAULT_VALUE:
prepend_bos = self.cfg.default_prepend_bos
attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos).to(
device
)

if self.cfg.positional_embedding_type == "standard":
pos_embed = self.hook_pos_embed(
self.pos_embed(tokens, pos_offset, attention_mask)
) # [batch, pos, d_model]
residual = embed + pos_embed # [batch, pos, d_model]
shortformer_pos_embed = None
elif self.cfg.positional_embedding_type == "shortformer":
# If we're using shortformer style attention, we don't add the positional embedding to
# the residual stream. See HookedTransformerConfig for details
pos_embed = self.hook_pos_embed(
self.pos_embed(tokens, pos_offset, attention_mask)
) # [batch, pos, d_model]
residual = embed
shortformer_pos_embed = pos_embed
elif self.cfg.positional_embedding_type == "rotary":
# Rotary doesn't add positional embeddings to the residual stream; instead, rotations are
# applied when taking the dot product of keys and queries. See HookedTransformerConfig for details
residual = embed
shortformer_pos_embed = None
elif self.cfg.positional_embedding_type == "alibi":
# ALiBi does not add positional embeddings to word embeddings; instead, it biases QK attention scores.
residual = embed
shortformer_pos_embed = None
else:
raise ValueError(
f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
)

if return_shortformer_pos_embed:
return residual, shortformer_pos_embed
else:
return residual
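
A minimal sketch of how get_residual pairs with forward(..., start_at_layer=0): precompute the residual stream from token embeddings, then run only the transformer blocks on it. It continues the assumed `model` from the sketch above; the prompt is illustrative.

tokens = model.to_tokens("The cat sat on the mat")
embeds = model.embed(tokens)  # [batch, pos, d_model]
pos_offset = model.get_pos_offset(None, embeds.shape[0])  # no cache, so 0

residual, shortformer_pos_embed = model.get_residual(
    embeds, pos_offset, return_shortformer_pos_embed=True
)
# start_at_layer=0 tells forward() to treat its input as a residual stream rather than tokens.
logits = model.forward(
    residual,
    return_type="logits",
    start_at_layer=0,
    shortformer_pos_embed=shortformer_pos_embed,
)
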

def input_to_embed(
self,
input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
@@ -325,59 +410,21 @@ def input_to_embed(
# We separate out this case for computational efficiency.
attention_mask = None

# If we're doing caching, then we reuse keys and values from previous runs, as that's the
# only way that past activations will affect the final logits. The cache contains those so
# we don't need to recompute them. This is useful for generating text. As we have absolute
# positional encodings, to implement this we have a `pos_offset` variable, defaulting to
# zero, which says to offset which positional encodings are used (cached keys and values
# were calculated with their own positional encodings).
if past_kv_cache is None:
pos_offset = 0
else:
batch_size, ctx_length = tokens.shape
(
cached_batch_size,
cache_ctx_length,
num_heads_in_cache,
d_head_in_cache,
) = past_kv_cache[0].past_keys.shape
assert cached_batch_size == batch_size
if self.cfg.n_key_value_heads is None:
assert num_heads_in_cache == self.cfg.n_heads
else:
assert num_heads_in_cache == self.cfg.n_key_value_heads
assert d_head_in_cache == self.cfg.d_head
pos_offset = cache_ctx_length
batch_size = tokens.shape[0]
pos_offset = self.get_pos_offset(past_kv_cache, batch_size)

if self.cfg.use_hook_tokens:
tokens = self.hook_tokens(tokens)

embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model]
if self.cfg.positional_embedding_type == "standard":
pos_embed = self.hook_pos_embed(
self.pos_embed(tokens, pos_offset, attention_mask)
) # [batch, pos, d_model]
residual = embed + pos_embed # [batch, pos, d_model]
shortformer_pos_embed = None
elif self.cfg.positional_embedding_type == "shortformer":
# If we're using shortformer style attention, we don't add the positional embedding to
# the residual stream. See HookedTransformerConfig for details
pos_embed = self.hook_pos_embed(
self.pos_embed(tokens, pos_offset, attention_mask)
) # [batch, pos, d_model]
residual = embed
shortformer_pos_embed = pos_embed
elif self.cfg.positional_embedding_type == "rotary":
# Rotary doesn't add positional embeddings to the residual stream; instead, rotations are
# applied when taking the dot product of keys and queries. See HookedTransformerConfig for details
residual = embed
shortformer_pos_embed = None
elif self.cfg.positional_embedding_type == "alibi":
# ALiBi does not add positional embeddings to word embeddings; instead, it biases QK attention scores.
residual = embed
shortformer_pos_embed = None
else:
raise ValueError(
f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
)
residual, shortformer_pos_embed = self.get_residual(
embed,
pos_offset,
prepend_bos,
attention_mask,
tokens,
return_shortformer_pos_embed=True,
)
return residual, tokens, shortformer_pos_embed, attention_mask
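
For reference, a quick sketch of the refactored call site (same assumed `model` as above): the return signature of input_to_embed is unchanged.

residual, tokens, shortformer_pos_embed, attention_mask = model.input_to_embed(
    "An example prompt"
)
print(residual.shape)  # [batch, pos, d_model]
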

@overload
@@ -2139,9 +2186,6 @@ def generate(
input_type = "embeds"

input_tokens = input if input_type in ["str", "tokens"] else None
input = input if input_type == "embeds" else self.embed(input)

assert isinstance(input, torch.Tensor) and input.ndim == 3
batch_size, ctx_length = input.shape[0], input.shape[1]
device = devices.get_device_for_block_index(0, self.cfg)
input = input.to(device)
@@ -2152,6 +2196,11 @@
else:
past_kv_cache = None

shortformer_pos_embed = None
embeds = input if input_type == "embeds" else self.embed(input)

assert isinstance(embeds, torch.Tensor) and embeds.ndim == 3

stop_tokens: List[int] = []
eos_token_for_padding = 0
assert self.tokenizer is not None
@@ -2184,6 +2233,10 @@
self.eval()
sampled_tokens_list = []
for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
pos_offset = self.get_pos_offset(past_kv_cache, batch_size)
residual, shortformer_pos_embed = self.get_residual(
embeds, pos_offset, return_shortformer_pos_embed=True, device=device
)
# While generating, we keep generating logits, throw away all but the final logits,
# and then use those logits to sample from the distribution. We keep adding the
# sampled tokens to the end of tokens.
@@ -2192,31 +2245,34 @@
# We just take the final position of the residual, as a [batch, 1, d_model] tensor
if index > 0:
logits = self.forward(
input[:, -1:],
residual[:, -1:],
return_type="logits",
prepend_bos=prepend_bos,
padding_side=padding_side,
past_kv_cache=past_kv_cache,
start_at_layer=start_at_layer,
shortformer_pos_embed=shortformer_pos_embed,
)
else:
logits = self.forward(
input,
residual,
return_type="logits",
prepend_bos=prepend_bos,
padding_side=padding_side,
past_kv_cache=past_kv_cache,
start_at_layer=start_at_layer,
shortformer_pos_embed=shortformer_pos_embed,
)
else:
# We pass the residual for the entire sequence, as a [batch, pos, d_model] tensor, since
# we aren't using the cache.
logits = self.forward(
input,
residual,
return_type="logits",
prepend_bos=prepend_bos,
padding_side=padding_side,
start_at_layer=start_at_layer,
shortformer_pos_embed=shortformer_pos_embed,
)
final_logits = logits[:, -1, :]

@@ -2258,7 +2314,7 @@
)
)

input = torch.hstack([input, self.embed(sampled_tokens.unsqueeze(-1))])
embeds = torch.hstack([embeds, self.embed(sampled_tokens.unsqueeze(-1))])

if stop_at_eos and finished_sequences.all():
break
@@ -2278,8 +2334,7 @@
elif return_type == "tokens":
return output_tokens
else:
input = input.to("cpu")
return input.to("cpu")
return embeds
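
To see why the loop above fixes generation with absolute positional embeddings, here is a sketch of a manual greedy decode that mirrors the new internals: each step recomputes pos_offset from the cache and rebuilds the residual, so the newest position receives the correctly offset positional embedding. It reuses `model` and the cache class assumed in the earlier sketches; the prompt and step count are illustrative only.

import torch
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache

prompt_tokens = model.to_tokens("Once upon a time")
embeds = model.embed(prompt_tokens)
batch_size = prompt_tokens.shape[0]
cache = HookedTransformerKeyValueCache.init_cache(model.cfg, model.cfg.device, batch_size)

for index in range(5):
    pos_offset = model.get_pos_offset(cache, batch_size)
    residual, shortformer_pos_embed = model.get_residual(
        embeds, pos_offset, return_shortformer_pos_embed=True
    )
    # After the first step, only the newest position needs to run through the blocks;
    # earlier positions are served from the key-value cache.
    block_input = residual if index == 0 else residual[:, -1:]
    logits = model.forward(
        block_input,
        return_type="logits",
        past_kv_cache=cache,
        start_at_layer=0,
        shortformer_pos_embed=shortformer_pos_embed,
    )
    next_token = logits[:, -1, :].argmax(dim=-1)  # greedy choice instead of sampling
    embeds = torch.hstack([embeds, model.embed(next_token.unsqueeze(-1))])
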

# Give access to all weights as properties.
@property
2 changes: 1 addition & 1 deletion transformer_lens/components/pos_embed.py
@@ -27,7 +27,7 @@ def forward(
tokens: Int[torch.Tensor, "batch pos"],
past_kv_pos_offset: int = 0,
attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
) -> Float[torch.Tensor, "batch pos d_model"]:
) -> Float[torch.Tensor, "batch new_pos d_model"]:
"""
Forward pass for positional embeddings.
2 changes: 2 additions & 0 deletions transformer_lens/utils.py
@@ -1040,6 +1040,8 @@ def get_attention_mask(tokenizer, tokens: torch.Tensor, prepend_bos: bool) -> to

# Initialize the attention mask with ones (indicating all tokens should be attended to)
attention_mask = torch.ones_like(tokens)
if tokenizer is None:
return attention_mask
is_not_pad_token = tokens.ne(tokenizer.pad_token_id)

if tokenizer.padding_side == "right":
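
A small sketch of the behaviour the two added lines give get_attention_mask (illustrative, not from the diff): when no tokenizer is available, every position is simply attended to instead of failing on tokenizer.pad_token_id.

import torch
from transformer_lens import utils

tokens = torch.tensor([[11, 22, 33]])
mask = utils.get_attention_mask(None, tokens, prepend_bos=False)
assert torch.equal(mask, torch.ones_like(tokens))
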
