RoPE updates #412

Merged: 5 commits, merged on Oct 23, 2024
8 changes: 4 additions & 4 deletions ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb
@@ -426,7 +426,7 @@
" assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
"\n",
" # Compute the inverse frequencies\n",
" inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
" inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
"\n",
" # Generate position indices\n",
" positions = torch.arange(context_length)\n",
@@ -493,8 +493,8 @@
"\n",
"# Dummy query and key tensors\n",
"torch.manual_seed(123)\n",
"queries = torch.randn(batch_size, context_len, num_heads, head_dim)\n",
"keys = torch.randn(batch_size, context_len, num_heads, head_dim)\n",
"queries = torch.randn(batch_size, num_heads, context_len, head_dim)\n",
"keys = torch.randn(batch_size, num_heads, context_len, head_dim)\n",
"\n",
"# Apply rotary position embeddings\n",
"queries_rot = compute_rope(queries, cos, sin)\n",
@@ -1691,7 +1691,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
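Note on the `inv_freq` change above: the new expression is numerically identical to the old one whenever `head_dim` is even (the only case the preceding assert allows), but it now follows the `torch.arange(0, head_dim, 2).float() / head_dim` convention used by the Hugging Face and LitGPT reference implementations in `tests/tests.py` below. A minimal sketch of the equivalence with an illustrative even `head_dim` (the full sweep over head dimensions is in `tests/Untitled.ipynb` further down):

```python
import torch

theta_base = 10_000
head_dim = 16  # illustrative even value; the assert above requires head_dim % 2 == 0

# Old formulation: indices 0 .. head_dim//2 - 1, normalized by head_dim//2
inv_freq_old = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))

# New formulation: even indices 0, 2, ..., head_dim - 2, normalized by head_dim
inv_freq_new = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))

print(torch.equal(inv_freq_old, inv_freq_new))  # expected: True for any even head_dim
```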
8 changes: 4 additions & 4 deletions ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb
@@ -278,7 +278,7 @@
" assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
"\n",
" # Compute the inverse frequencies\n",
" inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
" inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
"\n",
" ################################ NEW ###############################################\n",
" # Frequency adjustments\n",
@@ -383,8 +383,8 @@
"\n",
"# Dummy query and key tensors\n",
"torch.manual_seed(123)\n",
"queries = torch.randn(batch_size, llama_3_context_len, num_heads, head_dim)\n",
"keys = torch.randn(batch_size, llama_3_context_len, num_heads, head_dim)\n",
"queries = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)\n",
"keys = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)\n",
"\n",
"# Apply rotary position embeddings\n",
"queries_rot = compute_rope(queries, cos, sin)\n",
@@ -2701,7 +2701,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
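Note on the dummy query/key change above: the tensors now use the `(batch_size, num_heads, context_len, head_dim)` layout that the LitGPT reference in `tests/tests.py` documents as `(B, nh, T, hs)` and that the Hugging Face reference is applied with. With that layout, a `(context_len, head_dim)` cos/sin cache broadcasts over the batch and head dimensions, whereas the old `(batch, context_len, num_heads, head_dim)` layout generally would not (unless `num_heads` happened to equal `context_len`). A rough broadcasting-only check with illustrative sizes (this is not the actual rotation):

```python
import torch

# Illustrative sizes only
batch_size, num_heads, context_len, head_dim = 2, 4, 8, 16
cos = torch.randn(context_len, head_dim)  # stands in for the (context_len, head_dim) RoPE cos cache

# New layout (batch, num_heads, context_len, head_dim): the cache broadcasts cleanly
q_new = torch.randn(batch_size, num_heads, context_len, head_dim)
print((q_new * cos).shape)  # torch.Size([2, 4, 8, 16])

# Old layout (batch, context_len, num_heads, head_dim): elementwise multiply fails
q_old = torch.randn(batch_size, context_len, num_heads, head_dim)
try:
    _ = q_old * cos
except RuntimeError as err:
    print("old layout does not broadcast:", err)
```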
4 changes: 2 additions & 2 deletions ch05/07_gpt_to_llama/standalone-llama32.ipynb
@@ -133,7 +133,7 @@
" assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
"\n",
" # Compute the inverse frequencies\n",
" inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
" inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
"\n",
" # Frequency adjustments\n",
" if freq_config is not None:\n",
@@ -1061,7 +1061,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.6"
}
},
"nbformat": 4,
74 changes: 74 additions & 0 deletions ch05/07_gpt_to_llama/tests/Untitled.ipynb
@@ -0,0 +1,74 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"id": "40d2405d-ee10-44ad-b20e-cf32078f926a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True | head dim: 1, tensor([]), tensor([])\n",
"True | head dim: 2, tensor([1.]), tensor([1.])\n",
"True | head dim: 3, tensor([1.]), tensor([1.])\n",
"True | head dim: 4, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0100])\n",
"False | head dim: 5, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0251])\n",
"True | head dim: 6, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0464, 0.0022])\n",
"False | head dim: 7, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0720, 0.0052])\n",
"True | head dim: 8, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1000, 0.0100, 0.0010])\n",
"False | head dim: 9, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1292, 0.0167, 0.0022])\n",
"True | head dim: 10, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04])\n",
"False | head dim: 11, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000, 0.1874, 0.0351, 0.0066, 0.0012])\n"
]
}
],
"source": [
"import torch\n",
"\n",
"theta_base = 10_000\n",
"\n",
"for head_dim in range(1, 12):\n",
"\n",
" before = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
" after = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
" \n",
" s = f\"{torch.equal(before, after)} | head dim: {head_dim}, {before}, {after}\"\n",
" print(s)\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0abfbf38-93a4-4994-8e7e-a543477268a8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
3 changes: 2 additions & 1 deletion ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
@@ -1 +1,2 @@
-transformers>=4.44.2
+transformers>=4.44.2
+litgpt>=0.5.0
118 changes: 116 additions & 2 deletions ch05/07_gpt_to_llama/tests/tests.py
@@ -10,11 +10,82 @@
import sys
import types
import nbformat
from typing import Optional, Tuple
import torch
import pytest
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb


# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
# LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
def litgpt_build_rope_cache(
seq_len: int,
n_elem: int,
device: Optional[torch.device] = None,
base: int = 10000,
condense_ratio: int = 1,
extra_config: Optional[dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Enhanced Transformer with Rotary Position Embedding.

Args:
seq_len (int): Sequence length.
n_elem (int): Number of elements (head dimension).
device (torch.device, optional): Device for tensor allocations.
base (int, optional): Base for computing inverse frequencies.
condense_ratio (int, optional): Ratio to condense the position indices.
extra_config (dict, optional): Configuration parameters for frequency adjustments (used by Llama 3.1 and 3.2)

Returns:
Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE.
"""

# Compute the inverse frequencies theta
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

if extra_config is not None:
orig_context_len = extra_config["original_max_seq_len"]
factor = extra_config["factor"]
low_freq_factor = extra_config["low_freq_factor"]
high_freq_factor = extra_config["high_freq_factor"]

wavelen = 2 * torch.pi / theta
ratio = orig_context_len / wavelen
smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)

# Compute adjusted_theta without masked indexing
adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta
theta = adjusted_theta

# Create position indices `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=device) / condense_ratio

# Calculate the product of position index and $\theta_i$
idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)

return torch.cos(idx_theta), torch.sin(idx_theta)


# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
# LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE
def litgpt_apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
head_size = x.size(-1)
x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
x2 = x[..., head_size // 2:] # (B, nh, T, hs/2)
rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
if cos.dim() > 1:
# batch dimensions must align
# sin/cos are (B, T, hs) so we unsqueeze -3 for nh
# we count from back because all of apply_rope does
cos = cos.unsqueeze(-3)
sin = sin.unsqueeze(-3)

roped = (x * cos) + (rotated * sin)
return roped.to(dtype=x.dtype)


@pytest.fixture(scope="module")
def notebook():
def import_definitions_from_notebook(notebooks):
@@ -84,21 +155,30 @@ def test_rope_llama2(notebook):
queries_rot = this_nb.compute_rope(queries, cos, sin)
keys_rot = this_nb.compute_rope(keys, cos, sin)

# Generate reference RoPE via HF
rot_emb = LlamaRotaryEmbedding(
dim=head_dim,
max_position_embeddings=context_len,
base=10_000
)

position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
ref_cos, ref_sin = rot_emb(queries, position_ids)
ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

torch.testing.assert_close(sin, ref_sin.squeeze(0))
torch.testing.assert_close(cos, ref_cos.squeeze(0))
torch.testing.assert_close(keys_rot, ref_keys_rot)
torch.testing.assert_close(queries_rot, ref_queries_rot)

# Generate reference RoPE via LitGPT
litgpt_cos, litgpt_sin = litgpt_build_rope_cache(context_len, n_elem=head_dim, base=10_000)
litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin)
litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin)

torch.testing.assert_close(sin, litgpt_sin)
torch.testing.assert_close(cos, litgpt_cos)
torch.testing.assert_close(keys_rot, litgpt_keys_rot)
torch.testing.assert_close(queries_rot, litgpt_queries_rot)


def test_rope_llama3(notebook):

@@ -128,6 +208,7 @@ def test_rope_llama3(notebook):
queries_rot = nb1.compute_rope(queries, cos, sin)
keys_rot = nb1.compute_rope(keys, cos, sin)

# Generate reference RoPE via HF
rot_emb = LlamaRotaryEmbedding(
dim=head_dim,
max_position_embeddings=context_len,
@@ -143,6 +224,16 @@ def test_rope_llama3(notebook):
torch.testing.assert_close(keys_rot, ref_keys_rot)
torch.testing.assert_close(queries_rot, ref_queries_rot)

# Generate reference RoPE via LitGPT
litgpt_cos, litgpt_sin = litgpt_build_rope_cache(context_len, n_elem=head_dim, base=theta_base)
litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin)
litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin)

torch.testing.assert_close(sin, litgpt_sin)
torch.testing.assert_close(cos, litgpt_cos)
torch.testing.assert_close(keys_rot, litgpt_keys_rot)
torch.testing.assert_close(queries_rot, litgpt_queries_rot)


def test_rope_llama3_12(notebook):

@@ -180,6 +271,7 @@ def test_rope_llama3_12(notebook):
queries_rot = nb1.compute_rope(queries, cos, sin)
keys_rot = nb1.compute_rope(keys, cos, sin)

# Generate reference RoPE via HF
hf_rope_params = {
"factor": 8.0,
"low_freq_factor": 1.0,
@@ -210,6 +302,28 @@ class RoPEConfig:
torch.testing.assert_close(keys_rot, ref_keys_rot)
torch.testing.assert_close(queries_rot, ref_queries_rot)

# Generate reference RoPE via LitGPT
litgpt_rope_config = {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_seq_len": 8192
}

litgpt_cos, litgpt_sin = litgpt_build_rope_cache(
context_len,
n_elem=head_dim,
base=rope_theta,
extra_config=litgpt_rope_config
)
litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin)
litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin)

torch.testing.assert_close(sin, litgpt_sin)
torch.testing.assert_close(cos, litgpt_cos)
torch.testing.assert_close(keys_rot, litgpt_keys_rot)
torch.testing.assert_close(queries_rot, litgpt_queries_rot)


def test_silu(notebook):
example_batch = torch.randn(2, 3, 4)
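Aside on the LitGPT reference added to `tests.py`: the `extra_config` branch of `litgpt_build_rope_cache` implements the Llama 3.1/3.2-style frequency adjustment exercised by `test_rope_llama3_12`. As a standalone sketch of what that adjustment does, re-deriving the adjusted theta inline with the same factor/low/high/original-length values as the test config above (`head_dim` and `theta_base` here are illustrative placeholders, not the notebook's actual `rope_theta`):

```python
import torch

# Illustrative values; the test uses the notebook's own head_dim and rope_theta
head_dim, theta_base = 16, 10_000

# Same adjustment parameters as litgpt_rope_config / hf_rope_params above
factor, low_freq_factor, high_freq_factor = 8.0, 1.0, 4.0
orig_context_len = 8192

theta = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))

wavelen = 2 * torch.pi / theta
ratio = orig_context_len / wavelen
smooth = torch.clamp(
    (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor), min=0.0, max=1.0
)
adjusted_theta = (1 - smooth) * (theta / factor) + smooth * theta

# High-frequency (short-wavelength) components keep their original theta (smooth == 1);
# low-frequency (long-wavelength) components are divided by `factor` (smooth == 0)
print(torch.stack([theta, adjusted_theta], dim=1))
```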