From 7cd6a670ed16f657d81f03a78e70684cf0b4f7f7 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Wed, 23 Oct 2024 18:07:49 -0500 Subject: [PATCH] RoPE updates (#412) * RoPE updates * Apply suggestions from code review * updates * updates * updates --- .../converting-gpt-to-llama2.ipynb | 8 +- .../converting-llama2-to-llama3.ipynb | 8 +- ch05/07_gpt_to_llama/standalone-llama32.ipynb | 4 +- ch05/07_gpt_to_llama/tests/Untitled.ipynb | 74 +++++++++++ .../tests/test-requirements-extra.txt | 3 +- ch05/07_gpt_to_llama/tests/tests.py | 118 +++++++++++++++++- 6 files changed, 202 insertions(+), 13 deletions(-) create mode 100644 ch05/07_gpt_to_llama/tests/Untitled.ipynb diff --git a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb index e8c5bf68..e7f459ea 100644 --- a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb +++ b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb @@ -426,7 +426,7 @@ " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", "\n", " # Compute the inverse frequencies\n", - " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", + " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n", "\n", " # Generate position indices\n", " positions = torch.arange(context_length)\n", @@ -493,8 +493,8 @@ "\n", "# Dummy query and key tensors\n", "torch.manual_seed(123)\n", - "queries = torch.randn(batch_size, context_len, num_heads, head_dim)\n", - "keys = torch.randn(batch_size, context_len, num_heads, head_dim)\n", + "queries = torch.randn(batch_size, num_heads, context_len, head_dim)\n", + "keys = torch.randn(batch_size, num_heads, context_len, head_dim)\n", "\n", "# Apply rotary position embeddings\n", "queries_rot = compute_rope(queries, cos, sin)\n", @@ -1691,7 +1691,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.6" }, "widgets": { "application/vnd.jupyter.widget-state+json": { diff --git a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb index 4b4459fc..bf62d9fc 100644 --- a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb +++ b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb @@ -278,7 +278,7 @@ " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", "\n", " # Compute the inverse frequencies\n", - " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", + " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n", "\n", " ################################ NEW ###############################################\n", " # Frequency adjustments\n", @@ -383,8 +383,8 @@ "\n", "# Dummy query and key tensors\n", "torch.manual_seed(123)\n", - "queries = torch.randn(batch_size, llama_3_context_len, num_heads, head_dim)\n", - "keys = torch.randn(batch_size, llama_3_context_len, num_heads, head_dim)\n", + "queries = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)\n", + "keys = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)\n", "\n", "# Apply rotary position embeddings\n", "queries_rot = compute_rope(queries, cos, sin)\n", @@ -2701,7 +2701,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.6" }, "widgets": { "application/vnd.jupyter.widget-state+json": { diff --git 
a/ch05/07_gpt_to_llama/standalone-llama32.ipynb b/ch05/07_gpt_to_llama/standalone-llama32.ipynb index 4201f959..b3d80c9e 100644 --- a/ch05/07_gpt_to_llama/standalone-llama32.ipynb +++ b/ch05/07_gpt_to_llama/standalone-llama32.ipynb @@ -133,7 +133,7 @@ " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", "\n", " # Compute the inverse frequencies\n", - " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", + " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n", "\n", " # Frequency adjustments\n", " if freq_config is not None:\n", @@ -1061,7 +1061,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/ch05/07_gpt_to_llama/tests/Untitled.ipynb b/ch05/07_gpt_to_llama/tests/Untitled.ipynb new file mode 100644 index 00000000..1375a9e9 --- /dev/null +++ b/ch05/07_gpt_to_llama/tests/Untitled.ipynb @@ -0,0 +1,74 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 9, + "id": "40d2405d-ee10-44ad-b20e-cf32078f926a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True | head dim: 1, tensor([]), tensor([])\n", + "True | head dim: 2, tensor([1.]), tensor([1.])\n", + "True | head dim: 3, tensor([1.]), tensor([1.])\n", + "True | head dim: 4, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0100])\n", + "False | head dim: 5, tensor([1.0000, 0.0100]), tensor([1.0000, 0.0251])\n", + "True | head dim: 6, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0464, 0.0022])\n", + "False | head dim: 7, tensor([1.0000, 0.0464, 0.0022]), tensor([1.0000, 0.0720, 0.0052])\n", + "True | head dim: 8, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1000, 0.0100, 0.0010])\n", + "False | head dim: 9, tensor([1.0000, 0.1000, 0.0100, 0.0010]), tensor([1.0000, 0.1292, 0.0167, 0.0022])\n", + "True | head dim: 10, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04])\n", + "False | head dim: 11, tensor([1.0000e+00, 1.5849e-01, 2.5119e-02, 3.9811e-03, 6.3096e-04]), tensor([1.0000, 0.1874, 0.0351, 0.0066, 0.0012])\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "theta_base = 10_000\n", + "\n", + "for head_dim in range(1, 12):\n", + "\n", + " before = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", + " after = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n", + " \n", + " s = f\"{torch.equal(before, after)} | head dim: {head_dim}, {before}, {after}\"\n", + " print(s)\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0abfbf38-93a4-4994-8e7e-a543477268a8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ch05/07_gpt_to_llama/tests/test-requirements-extra.txt b/ch05/07_gpt_to_llama/tests/test-requirements-extra.txt index 8828ccea..2b9fd336 100644 --- 
a/ch05/07_gpt_to_llama/tests/test-requirements-extra.txt +++ b/ch05/07_gpt_to_llama/tests/test-requirements-extra.txt @@ -1 +1,2 @@ -transformers>=4.44.2 \ No newline at end of file +transformers>=4.44.2 +litgpt>=0.5.0 \ No newline at end of file diff --git a/ch05/07_gpt_to_llama/tests/tests.py b/ch05/07_gpt_to_llama/tests/tests.py index 6620b4ea..395f9ec3 100644 --- a/ch05/07_gpt_to_llama/tests/tests.py +++ b/ch05/07_gpt_to_llama/tests/tests.py @@ -10,11 +10,82 @@ import sys import types import nbformat +from typing import Optional, Tuple import torch import pytest from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb +# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py +# LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE +def litgpt_build_rope_cache( + seq_len: int, + n_elem: int, + device: Optional[torch.device] = None, + base: int = 10000, + condense_ratio: int = 1, + extra_config: Optional[dict] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Enhanced Transformer with Rotary Position Embedding. + + Args: + seq_len (int): Sequence length. + n_elem (int): Number of elements (head dimension). + device (torch.device, optional): Device for tensor allocations. + base (int, optional): Base for computing inverse frequencies. + condense_ratio (int, optional): Ratio to condense the position indices. + extra_config (dict, optional): Configuration parameters for frequency adjustments (used by Llama 3.1 and 3.2) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE. + """ + + # Compute the inverse frequencies theta + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) + + if extra_config is not None: + orig_context_len = extra_config["original_max_seq_len"] + factor = extra_config["factor"] + low_freq_factor = extra_config["low_freq_factor"] + high_freq_factor = extra_config["high_freq_factor"] + + wavelen = 2 * torch.pi / theta + ratio = orig_context_len / wavelen + smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor) + smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0) + + # Compute adjusted_theta without masked indexing + adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta + theta = adjusted_theta + + # Create position indices `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=device) / condense_ratio + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) + + return torch.cos(idx_theta), torch.sin(idx_theta) + + +# LitGPT code from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py +# LitGPT is licensed under Apache v2: https://github.com/Lightning-AI/litgpt/blob/main/LICENSE +def litgpt_apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + head_size = x.size(-1) + x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) + x2 = x[..., head_size // 2:] # (B, nh, T, hs/2) + rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) + if cos.dim() > 1: + # batch dimensions must align + # sin/cos are (B, T, hs) so we unsqeeze -3 for nh + # we count from back because all of apply_rope does + cos = cos.unsqueeze(-3) + sin = sin.unsqueeze(-3) + + roped = (x * cos) + (rotated * sin) + return roped.to(dtype=x.dtype) + + @pytest.fixture(scope="module") def notebook(): def import_definitions_from_notebook(notebooks): @@ -84,21 +155,30 @@ def 
test_rope_llama2(notebook): queries_rot = this_nb.compute_rope(queries, cos, sin) keys_rot = this_nb.compute_rope(keys, cos, sin) + # Generate reference RoPE via HF rot_emb = LlamaRotaryEmbedding( dim=head_dim, max_position_embeddings=context_len, base=10_000 ) - position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0) ref_cos, ref_sin = rot_emb(queries, position_ids) ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin) - torch.testing.assert_close(sin, ref_sin.squeeze(0)) torch.testing.assert_close(cos, ref_cos.squeeze(0)) torch.testing.assert_close(keys_rot, ref_keys_rot) torch.testing.assert_close(queries_rot, ref_queries_rot) + # Generate reference RoPE via LitGPT + litgpt_cos, litgpt_sin = litgpt_build_rope_cache(context_len, n_elem=head_dim, base=10_000) + litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin) + litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin) + + torch.testing.assert_close(sin, litgpt_sin) + torch.testing.assert_close(cos, litgpt_cos) + torch.testing.assert_close(keys_rot, litgpt_keys_rot) + torch.testing.assert_close(queries_rot, litgpt_queries_rot) + def test_rope_llama3(notebook): @@ -128,6 +208,7 @@ def test_rope_llama3(notebook): queries_rot = nb1.compute_rope(queries, cos, sin) keys_rot = nb1.compute_rope(keys, cos, sin) + # Generate reference RoPE via HF rot_emb = LlamaRotaryEmbedding( dim=head_dim, max_position_embeddings=context_len, @@ -143,6 +224,16 @@ def test_rope_llama3(notebook): torch.testing.assert_close(keys_rot, ref_keys_rot) torch.testing.assert_close(queries_rot, ref_queries_rot) + # Generate reference RoPE via LitGPT + litgpt_cos, litgpt_sin = litgpt_build_rope_cache(context_len, n_elem=head_dim, base=theta_base) + litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin) + litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin) + + torch.testing.assert_close(sin, litgpt_sin) + torch.testing.assert_close(cos, litgpt_cos) + torch.testing.assert_close(keys_rot, litgpt_keys_rot) + torch.testing.assert_close(queries_rot, litgpt_queries_rot) + def test_rope_llama3_12(notebook): @@ -180,6 +271,7 @@ def test_rope_llama3_12(notebook): queries_rot = nb1.compute_rope(queries, cos, sin) keys_rot = nb1.compute_rope(keys, cos, sin) + # Generate reference RoPE via HF hf_rope_params = { "factor": 8.0, "low_freq_factor": 1.0, @@ -210,6 +302,28 @@ class RoPEConfig: torch.testing.assert_close(keys_rot, ref_keys_rot) torch.testing.assert_close(queries_rot, ref_queries_rot) + # Generate reference RoPE via LitGPT + litgpt_rope_config = { + "factor": 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_max_seq_len": 8192 + } + + litgpt_cos, litgpt_sin = litgpt_build_rope_cache( + context_len, + n_elem=head_dim, + base=rope_theta, + extra_config=litgpt_rope_config + ) + litgpt_queries_rot = litgpt_apply_rope(queries, litgpt_cos, litgpt_sin) + litgpt_keys_rot = litgpt_apply_rope(keys, litgpt_cos, litgpt_sin) + + torch.testing.assert_close(sin, litgpt_sin) + torch.testing.assert_close(cos, litgpt_cos) + torch.testing.assert_close(keys_rot, litgpt_keys_rot) + torch.testing.assert_close(queries_rot, litgpt_queries_rot) + def test_silu(notebook): example_batch = torch.randn(2, 3, 4)
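
Note on the inverse-frequency change: all three notebooks swap `torch.arange(0, head_dim // 2) / (head_dim // 2)` for `torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim`. For an even `head_dim` (which the surrounding assert guarantees) the exponents are i / (head_dim / 2) and 2i / head_dim for the same i, i.e. identical, so the new form is a pure reparameterization that matches the convention used by Hugging Face's LlamaRotaryEmbedding and by the LitGPT helper above; the two only diverge for odd head dimensions, which is what the loop in the new tests/Untitled.ipynb demonstrates. A minimal sketch of that check for even dimensions (the helper names inv_freq_old / inv_freq_new are made up for illustration):

import torch

theta_base = 10_000

def inv_freq_old(head_dim):
    # exponent i / (head_dim / 2) for i = 0, ..., head_dim // 2 - 1
    return 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))

def inv_freq_new(head_dim):
    # exponent 2 * i / head_dim for the same i -- identical when head_dim is even
    return 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))

for head_dim in (4, 8, 16, 64, 128):  # even head dims, as enforced by the assert in the notebooks
    assert torch.equal(inv_freq_old(head_dim), inv_freq_new(head_dim))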
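Note on the tensor layout change: the LitGPT reference helpers copied into tests.py, litgpt_build_rope_cache and litgpt_apply_rope, work on queries and keys in (batch, num_heads, seq_len, head_dim) layout (the "(B, nh, T, hs)" shapes in the litgpt_apply_rope comments), and the dummy query/key tensors in both conversion notebooks now use that same layout instead of (batch, seq_len, num_heads, head_dim). A minimal usage sketch, assuming the two helpers defined in tests.py above are in scope and using illustrative sizes rather than the exact values from the tests:

import torch

# Illustrative sizes; the actual tests define their own batch_size, num_heads, etc.
batch_size, num_heads, context_len, head_dim = 2, 4, 8, 16

torch.manual_seed(123)
# (batch, num_heads, seq_len, head_dim) layout expected by litgpt_apply_rope
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)

# cos and sin caches of shape (context_len, head_dim)
cos, sin = litgpt_build_rope_cache(context_len, n_elem=head_dim, base=10_000)

queries_rot = litgpt_apply_rope(queries, cos, sin)
keys_rot = litgpt_apply_rope(keys, cos, sin)

# RoPE rotates pairs of dimensions in place, so the shapes are unchanged
assert queries_rot.shape == queries.shape and keys_rot.shape == keys.shape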
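Note on the frequency-adjusted path: test_rope_llama3_12 additionally passes an extra_config dict so that litgpt_build_rope_cache applies the Llama 3.1/3.2 frequency adjustments before building the cos/sin cache. A short sketch of that path, reusing the rope parameters from the test; context_len, head_dim, and rope_theta below are illustrative stand-ins, since the test's own values are defined outside the hunks shown here:

import torch

# RoPE scaling parameters mirrored from test_rope_llama3_12
litgpt_rope_config = {
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_seq_len": 8192,
}

# Illustrative values; not taken from the hunks above
context_len = 8192
head_dim = 64
rope_theta = 500_000  # assumed Llama 3-style base frequency

cos, sin = litgpt_build_rope_cache(
    context_len,
    n_elem=head_dim,
    base=rope_theta,
    extra_config=litgpt_rope_config,
)

# With extra_config set, low-frequency (long-wavelength) components are divided by
# `factor`, high-frequency components are left untouched, and the band in between
# is blended via the clamped smooth_factor -- the adjustment exercised by the test.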