diff --git a/README.md b/README.md
index d09e7605..80aa32dd 100644
--- a/README.md
+++ b/README.md
@@ -120,6 +120,7 @@ Several folders contain optional materials as a bonus for interested readers:
- [Converting GPT to Llama](ch05/07_gpt_to_llama)
- [Llama 3.2 From Scratch](ch05/07_gpt_to_llama/standalone-llama32.ipynb)
- [Memory-efficient Model Weight Loading](ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb)
+ - [Extending the Tiktoken BPE Tokenizer with New Tokens](ch05/09_extending-tokenizers/extend-tiktoken.ipynb)
- **Chapter 6: Finetuning for classification**
- [Additional experiments finetuning different layers and using larger models](ch06/02_bonus_additional-experiments)
- [Finetuning different models on 50k IMDB movie review dataset](ch06/03_bonus_imdb-classification)
diff --git a/ch02/05_bpe-from-scratch/README.md b/ch02/05_bpe-from-scratch/README.md
new file mode 100644
index 00000000..0d684679
--- /dev/null
+++ b/ch02/05_bpe-from-scratch/README.md
@@ -0,0 +1,3 @@
+# Byte Pair Encoding (BPE) Tokenizer From Scratch
+
+- [bpe-from-scratch.ipynb](bpe-from-scratch.ipynb) contains optional (bonus) code that explains and shows how the BPE tokenizer works under the hood.
diff --git a/ch05/09_extending-tokenizers/README.md b/ch05/09_extending-tokenizers/README.md
new file mode 100644
index 00000000..886a880b
--- /dev/null
+++ b/ch05/09_extending-tokenizers/README.md
@@ -0,0 +1,3 @@
+# Extending the Tiktoken BPE Tokenizer with New Tokens
+
+- [extend-tiktoken.ipynb](extend-tiktoken.ipynb) contains optional (bonus) code to explain how we can add special tokens to a tokenizer implemented via `tiktoken` and how to update the LLM accordingly
\ No newline at end of file
diff --git a/ch05/09_extending-tokenizers/extend-tiktoken.ipynb b/ch05/09_extending-tokenizers/extend-tiktoken.ipynb
new file mode 100644
index 00000000..83d40a49
--- /dev/null
+++ b/ch05/09_extending-tokenizers/extend-tiktoken.ipynb
@@ -0,0 +1,771 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "cbbc1fe3-bff1-4631-bf35-342e19c54cc0",
+ "metadata": {},
+ "source": [
+ "
\n",
+ "\n",
+ "\n",
+ "\n",
+ "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka \n",
+ " Code repository: https://github.com/rasbt/LLMs-from-scratch\n",
+ "\n",
+ " | \n",
+ "\n",
+ "\n",
+ " | \n",
+ "
\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2b022374-e3f6-4437-b86f-e6f8f94cbebc",
+ "metadata": {},
+ "source": [
+ "# Extending the Tiktoken BPE Tokenizer with New Tokens"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bcd624b1-2060-49af-bbf6-40517a58c128",
+ "metadata": {},
+ "source": [
+ "- This notebook explains how we can extend an existing BPE tokenizer; specifically, we will focus on how to do it for the popular [tiktoken](https://github.com/openai/tiktoken) implementation\n",
+ "- For a general introduction to tokenization, please refer to [Chapter 2](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/01_main-chapter-code/ch02.ipynb) and the BPE from Scratch [link] tutorial\n",
+ "- For example, suppose we have a GPT-2 tokenizer and want to encode the following text:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "798d4355-a146-48a8-a1a5-c5cec91edf2c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[15496, 11, 2011, 3791, 30642, 62, 16, 318, 257, 649, 11241, 13, 220, 50256]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import tiktoken\n",
+ "\n",
+ "base_tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
+ "sample_text = \"Hello, MyNewToken_1 is a new token. <|endoftext|>\"\n",
+ "\n",
+ "token_ids = base_tokenizer.encode(sample_text, allowed_special={\"<|endoftext|>\"})\n",
+ "print(token_ids)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5b09b19b-772d-4449-971b-8ab052ee726d",
+ "metadata": {},
+ "source": [
+ "- Iterating over each token ID can give us a better understanding of how the token IDs are decoded via the vocabulary:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "21fd634b-bb4c-4ba3-8b69-9322b727bf58",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "15496 -> Hello\n",
+ "11 -> ,\n",
+ "2011 -> My\n",
+ "3791 -> New\n",
+ "30642 -> Token\n",
+ "62 -> _\n",
+ "16 -> 1\n",
+ "318 -> is\n",
+ "257 -> a\n",
+ "649 -> new\n",
+ "11241 -> token\n",
+ "13 -> .\n",
+ "220 -> \n",
+ "50256 -> <|endoftext|>\n"
+ ]
+ }
+ ],
+ "source": [
+ "for token_id in token_ids:\n",
+ " print(f\"{token_id} -> {base_tokenizer.decode([token_id])}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fd5b1b9b-b1a9-489e-9711-c15a8e081813",
+ "metadata": {},
+ "source": [
+ "- As we can see above, the `\"MyNewToken_1\"` is broken down into 5 individual subword tokens -- this is normal behavior for BPE when handling unknown words\n",
+ "- However, suppose that it's a special token that we want to encode as a single token, similar to some of the other words or `\"<|endoftext|>\"`; this notebook explains how"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "65f62ab6-df96-4f88-ab9a-37702cd30f5f",
+ "metadata": {},
+ "source": [
+ " \n",
+ "## 1. Adding special tokens"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c4379fdb-57ba-4a75-9183-0aee0836c391",
+ "metadata": {},
+ "source": [
+ "- Note that we have to add new tokens as special tokens; the reason is that we don't have the \"merges\" for the new tokens that are created during the tokenizer training process -- even if we had them, it would be very challenging to incorporate them without breaking the existing tokenization scheme (see the BPE from scratch notebook [link] to understand the \"merges\")\n",
+ "- Suppose we want to add 2 new tokens:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "265f1bba-c478-497d-b7fc-f4bd191b7d55",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define custom tokens and their token IDs\n",
+ "custom_tokens = [\"MyNewToken_1\", \"MyNewToken_2\"]\n",
+ "custom_token_ids = {\n",
+ " token: base_tokenizer.n_vocab + i for i, token in enumerate(custom_tokens)\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1c6f3d98-1ab6-43cf-9ae2-2bf53860f99e",
+ "metadata": {},
+ "source": [
+ "- Next, we create a custom `Encoding` object that holds our special tokens as follows:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "1f519852-59ea-4069-a8c7-0f647bfaea09",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a new Encoding object with extended tokens\n",
+ "extended_tokenizer = tiktoken.Encoding(\n",
+ " name=\"gpt2_custom\",\n",
+ " pat_str=base_tokenizer._pat_str,\n",
+ " mergeable_ranks=base_tokenizer._mergeable_ranks,\n",
+ " special_tokens={**base_tokenizer._special_tokens, **custom_token_ids},\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "90af6cfa-e0cc-4c80-89dc-3a824e7bdeb2",
+ "metadata": {},
+ "source": [
+ "- That's it, we can now check that it can encode the sample text:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "153e8e1d-c4cb-41ff-9c55-1701e9bcae1c",
+ "metadata": {},
+ "source": [
+ "- As we can see, the new tokens `50257` and `50258` are now encoded in the output:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "eccc78a4-1fd4-47ba-a114-83ee0a3aec31",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[36674, 2420, 351, 220, 50257, 290, 220, 50258, 13, 220, 50256]\n"
+ ]
+ }
+ ],
+ "source": [
+ "special_tokens_set = set(custom_tokens) | {\"<|endoftext|>\"}\n",
+ "\n",
+ "token_ids = extended_tokenizer.encode(\n",
+ " \"Sample text with MyNewToken_1 and MyNewToken_2. <|endoftext|>\",\n",
+ " allowed_special=special_tokens_set\n",
+ ")\n",
+ "print(token_ids)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dc0547c1-bbb5-4915-8cf4-caaebcf922eb",
+ "metadata": {},
+ "source": [
+ "- Again, we can also look at it on a per-token level:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "7583eff9-b10d-4e3d-802c-f0464e1ef030",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "36674 -> Sample\n",
+ "2420 -> text\n",
+ "351 -> with\n",
+ "220 -> \n",
+ "50257 -> MyNewToken_1\n",
+ "290 -> and\n",
+ "220 -> \n",
+ "50258 -> MyNewToken_2\n",
+ "13 -> .\n",
+ "220 -> \n",
+ "50256 -> <|endoftext|>\n"
+ ]
+ }
+ ],
+ "source": [
+ "for token_id in token_ids:\n",
+ " print(f\"{token_id} -> {extended_tokenizer.decode([token_id])}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "17f0764e-e5a9-4226-a384-18c11bd5fec3",
+ "metadata": {},
+ "source": [
+ "- As we can see above, we have successfully updated the tokenizer\n",
+ "- However, to use it with a pretrained LLM, we also have to update the embedding and output layers of the LLM, which is discussed in the next section"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8ec7f98d-8f09-4386-83f0-9bec68ef7f66",
+ "metadata": {},
+ "source": [
+ " \n",
+ "## 2. Updating a pretrained LLM"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b8a4f68b-04e9-4524-8df4-8718c7b566f2",
+ "metadata": {},
+ "source": [
+ "- In this section, we will take a look at how we have to update an existing pretrained LLM after updating the tokenizer\n",
+ "- For this, we are using the original pretrained GPT-2 model that is used in the main book"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1a9b252e-1d1d-4ddf-b9f3-95bd6ba505a9",
+ "metadata": {},
+ "source": [
+ " \n",
+ "### 2.1 Loading a pretrained GPT model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "ded29b4e-9b39-4191-b61c-29d6b2360bae",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "checkpoint: 100%|███████████████████████████| 77.0/77.0 [00:00<00:00, 34.4kiB/s]\n",
+ "encoder.json: 100%|███████████████████████| 1.04M/1.04M [00:00<00:00, 4.78MiB/s]\n",
+ "hparams.json: 100%|█████████████████████████| 90.0/90.0 [00:00<00:00, 24.7kiB/s]\n",
+ "model.ckpt.data-00000-of-00001: 100%|███████| 498M/498M [00:33<00:00, 14.7MiB/s]\n",
+ "model.ckpt.index: 100%|███████████████████| 5.21k/5.21k [00:00<00:00, 1.05MiB/s]\n",
+ "model.ckpt.meta: 100%|██████████████████████| 471k/471k [00:00<00:00, 2.33MiB/s]\n",
+ "vocab.bpe: 100%|████████████████████████████| 456k/456k [00:00<00:00, 2.45MiB/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Relative import from the gpt_download.py contained in this folder\n",
+ "from gpt_download import download_and_load_gpt2\n",
+ "\n",
+ "settings, params = download_and_load_gpt2(model_size=\"124M\", models_dir=\"gpt2\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "93dc0d8e-b549-415b-840e-a00023bddcf9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Relative import from the gpt_download.py contained in this folder\n",
+ "from previous_chapters import GPTModel\n",
+ "\n",
+ "GPT_CONFIG_124M = {\n",
+ " \"vocab_size\": 50257, # Vocabulary size\n",
+ " \"context_length\": 256, # Shortened context length (orig: 1024)\n",
+ " \"emb_dim\": 768, # Embedding dimension\n",
+ " \"n_heads\": 12, # Number of attention heads\n",
+ " \"n_layers\": 12, # Number of layers\n",
+ " \"drop_rate\": 0.1, # Dropout rate\n",
+ " \"qkv_bias\": False # Query-key-value bias\n",
+ "}\n",
+ "\n",
+ "# Define model configurations in a dictionary for compactness\n",
+ "model_configs = {\n",
+ " \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
+ " \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
+ " \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
+ " \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
+ "}\n",
+ "\n",
+ "# Copy the base configuration and update with specific model settings\n",
+ "model_name = \"gpt2-small (124M)\" # Example model name\n",
+ "NEW_CONFIG = GPT_CONFIG_124M.copy()\n",
+ "NEW_CONFIG.update(model_configs[model_name])\n",
+ "NEW_CONFIG.update({\"context_length\": 1024, \"qkv_bias\": True})\n",
+ "\n",
+ "gpt = GPTModel(NEW_CONFIG)\n",
+ "gpt.eval();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "83f898c0-18f4-49ce-9b1f-3203a277b29e",
+ "metadata": {},
+ "source": [
+ "### 2.2 Using the pretrained GPT model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a5a1f5e1-e806-4c60-abaa-42ae8564908c",
+ "metadata": {},
+ "source": [
+ "- Next, consider our sample text below, which we tokenize using the original and the new tokenizer:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "9a88017d-cc8f-4ba1-bba9-38161a30f673",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "sample_text = \"Sample text with MyNewToken_1 and MyNewToken_2. <|endoftext|>\"\n",
+ "\n",
+ "original_token_ids = base_tokenizer.encode(\n",
+ " sample_text, allowed_special={\"<|endoftext|>\"}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "1ee01bc3-ca24-497b-b540-3d13c52c29ed",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_token_ids = extended_tokenizer.encode(\n",
+ " \"Sample text with MyNewToken_1 and MyNewToken_2. <|endoftext|>\",\n",
+ " allowed_special=special_tokens_set\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1143106b-68fe-4234-98ad-eaff420a4d08",
+ "metadata": {},
+ "source": [
+ "- Now, let's feed the original token IDs to the GPT model:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "6b06827f-b411-42cc-b978-5c1d568a3200",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([[[ 0.2204, 0.8901, 1.0138, ..., 0.2585, -0.9192, -0.2298],\n",
+ " [ 0.6745, -0.0726, 0.8218, ..., -0.1768, -0.4217, 0.0703],\n",
+ " [-0.2009, 0.0814, 0.2417, ..., 0.3166, 0.3629, 1.3400],\n",
+ " ...,\n",
+ " [ 0.1137, -0.1258, 2.0193, ..., -0.0314, -0.4288, -0.1487],\n",
+ " [-1.1983, -0.2050, -0.1337, ..., -0.0849, -0.4863, -0.1076],\n",
+ " [-1.0675, -0.5905, 0.2873, ..., -0.0979, -0.8713, 0.8415]]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " out = gpt(torch.tensor([original_token_ids]))\n",
+ "\n",
+ "print(out)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "082c7a78-35a8-473e-a08d-b099a6348a74",
+ "metadata": {},
+ "source": [
+ "- As we can see above, this works without problems (note that the code shows the raw output without converting the outputs back into text for simplicity; for more details on that, please check out the `generate` function in Chapter 5 [link] section 5.3.3"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "628265b5-3dde-44e7-bde2-8fc594a2547d",
+ "metadata": {},
+ "source": [
+ "- What happens if we try the same on the token IDs generated by the updated tokenizer now?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9796ad09-787c-4c25-a7f5-6d1dfe048ac3",
+ "metadata": {},
+ "source": [
+ "```python\n",
+ "with torch.no_grad():\n",
+ " gpt(torch.tensor([new_token_ids]))\n",
+ "\n",
+ "print(out)\n",
+ "\n",
+ "...\n",
+ "# IndexError: index out of range in self\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "77d00244-7e40-4de0-942e-e15cdd8e3b18",
+ "metadata": {},
+ "source": [
+ "- As we can see, this results in an index error\n",
+ "- The reason is that the GPT model expects a fixed vocabulary size via its input embedding layer and its output layer:\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dec38b24-c845-4090-96a4-0d3c4ec241d6",
+ "metadata": {},
+ "source": [
+ " \n",
+ "### 2.3 Updating the embedding layer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b1328726-8297-4162-878b-a5daff7de742",
+ "metadata": {},
+ "source": [
+ "- Let's start with updating the embedding layer\n",
+ "- First, notice that the embedding layer has 50,257 entries, which corresponds to the vocabulary size:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "23ecab6e-1232-47c7-a318-042f90e1dff3",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Embedding(50257, 768)"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "gpt.tok_emb"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d760c683-d082-470a-bff8-5a08b30d3b61",
+ "metadata": {},
+ "source": [
+ "- We want to extend this embedding layer by adding 2 more entries\n",
+ "- In short, we create a new embedding layer with a bigger size, and then we copy over the old embedding layer values"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "4ec5c48e-c6fe-4e84-b290-04bd4da9483f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Embedding(50259, 768)\n"
+ ]
+ }
+ ],
+ "source": [
+ "num_tokens, emb_size = gpt.tok_emb.weight.shape\n",
+ "new_num_tokens = num_tokens + 2\n",
+ "\n",
+ "# Create a new embedding layer\n",
+ "new_embedding = torch.nn.Embedding(new_num_tokens, emb_size)\n",
+ "\n",
+ "# Copy weights from the old embedding layer\n",
+ "new_embedding.weight.data[:num_tokens] = gpt.tok_emb.weight.data\n",
+ "\n",
+ "# Replace the old embedding layer with the new one in the model\n",
+ "gpt.tok_emb = new_embedding\n",
+ "\n",
+ "print(gpt.tok_emb)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "63954928-31a5-4e7e-9688-2e0c156b7302",
+ "metadata": {},
+ "source": [
+ "- As we can see above, we now have an increased embedding layer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6e68bea5-255b-47bb-b352-09ea9539bc25",
+ "metadata": {},
+ "source": [
+ " \n",
+ "### 2.4 Updating the output layer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "90a4a519-bf0f-4502-912d-ef0ac7a9deab",
+ "metadata": {},
+ "source": [
+ "- Next, we have to extend the output layer, which has 50,257 output features corresponding to the vocabulary size similar to the embedding layer (by the way, you may find the bonus material, which discusses the similarity between Linear and Embedding layers in PyTorch, useful)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "6105922f-d889-423e-bbcc-bc49156d78df",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Linear(in_features=768, out_features=50257, bias=False)"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "gpt.out_head"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "29f1ff24-9c00-40f6-a94f-82d03aaf0890",
+ "metadata": {},
+ "source": [
+ "- The procedure for extending the output layer is similar to extending the embedding layer:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "354589db-b148-4dae-8068-62132e3fb38e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Linear(in_features=768, out_features=50259, bias=True)\n"
+ ]
+ }
+ ],
+ "source": [
+ "original_out_features, original_in_features = gpt.out_head.weight.shape\n",
+ "\n",
+ "# Define the new number of output features (e.g., adding 2 new tokens)\n",
+ "new_out_features = original_out_features + 2\n",
+ "\n",
+ "# Create a new linear layer with the extended output size\n",
+ "new_linear = torch.nn.Linear(original_in_features, new_out_features)\n",
+ "\n",
+ "# Copy the weights and biases from the original linear layer\n",
+ "with torch.no_grad():\n",
+ " new_linear.weight[:original_out_features] = gpt.out_head.weight\n",
+ " if gpt.out_head.bias is not None:\n",
+ " new_linear.bias[:original_out_features] = gpt.out_head.bias\n",
+ "\n",
+ "# Replace the original linear layer with the new one\n",
+ "gpt.out_head = new_linear\n",
+ "\n",
+ "print(gpt.out_head)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "df5d2205-1fae-4a4f-a7bd-fa8fc37eeec2",
+ "metadata": {},
+ "source": [
+ "- Let's try this updated model on the original token IDs first:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "df604bbc-6c13-4792-8ba8-ecb692117c25",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([[[ 0.2267, 0.9132, 1.0494, ..., -0.2330, -0.3008, -1.1458],\n",
+ " [ 0.6808, -0.0495, 0.8574, ..., 0.0671, 0.5572, -0.7873],\n",
+ " [-0.1947, 0.1045, 0.2773, ..., 1.3368, 0.8479, -0.9660],\n",
+ " ...,\n",
+ " [ 0.1200, -0.1027, 2.0549, ..., -0.1519, -0.2096, 0.5651],\n",
+ " [-1.1920, -0.1819, -0.0981, ..., -0.1108, 0.8435, -0.3771],\n",
+ " [-1.0612, -0.5674, 0.3229, ..., 0.8383, -0.7121, -0.4850]]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " output = gpt(torch.tensor([original_token_ids]))\n",
+ "print(output)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3d80717e-50e6-4927-8129-0aadfa2628f5",
+ "metadata": {},
+ "source": [
+ "- Next, let's try it on the updated tokens:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "75f11ec9-bdd2-440f-b8c8-6646b75891c6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([[[ 0.2267, 0.9132, 1.0494, ..., -0.2330, -0.3008, -1.1458],\n",
+ " [ 0.6808, -0.0495, 0.8574, ..., 0.0671, 0.5572, -0.7873],\n",
+ " [-0.1947, 0.1045, 0.2773, ..., 1.3368, 0.8479, -0.9660],\n",
+ " ...,\n",
+ " [-0.0656, -1.2451, 0.7957, ..., -1.2124, 0.1044, 0.5088],\n",
+ " [-1.1561, -0.7380, -0.0645, ..., -0.4373, 1.1401, -0.3903],\n",
+ " [-0.8961, -0.6437, -0.1667, ..., 0.5663, -0.5862, -0.4020]]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " output = gpt(torch.tensor([new_token_ids]))\n",
+ "print(output)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d88a1bba-db01-4090-97e4-25dfc23ed54c",
+ "metadata": {},
+ "source": [
+ "- As we can see, the model works on the extended token set\n",
+ "- In practice, we want to now finetune (or continually pretrain) the model (specifically the new embedding and output layers) on data containing the new tokens"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6de573ad-0338-40d9-9dad-de60ae349c4f",
+ "metadata": {},
+ "source": [
+ "**A note about weight tying**\n",
+ "\n",
+ "- If the model uses weight tying, which means that the embedding layer and output layer share the same weights, similar to Llama 3 [link], updating the output layer is much simpler\n",
+ "- In this case, we can simply copy over the weights from the embedding layer:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "4cbc5f51-c7a8-49d0-b87f-d3d87510953b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gpt.out_head.weight = gpt.tok_emb.weight"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "d0d553a8-edff-40f0-bdc4-dff900e16caf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with torch.no_grad():\n",
+ " output = gpt(torch.tensor([new_token_ids]))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/ch05/09_extending-tokenizers/gpt_download.py b/ch05/09_extending-tokenizers/gpt_download.py
new file mode 100644
index 00000000..aa0ea1e3
--- /dev/null
+++ b/ch05/09_extending-tokenizers/gpt_download.py
@@ -0,0 +1,142 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+
+import os
+import urllib.request
+
+# import requests
+import json
+import numpy as np
+import tensorflow as tf
+from tqdm import tqdm
+
+
+def download_and_load_gpt2(model_size, models_dir):
+ # Validate model size
+ allowed_sizes = ("124M", "355M", "774M", "1558M")
+ if model_size not in allowed_sizes:
+ raise ValueError(f"Model size not in {allowed_sizes}")
+
+ # Define paths
+ model_dir = os.path.join(models_dir, model_size)
+ base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
+ filenames = [
+ "checkpoint", "encoder.json", "hparams.json",
+ "model.ckpt.data-00000-of-00001", "model.ckpt.index",
+ "model.ckpt.meta", "vocab.bpe"
+ ]
+
+ # Download files
+ os.makedirs(model_dir, exist_ok=True)
+ for filename in filenames:
+ file_url = os.path.join(base_url, model_size, filename)
+ file_path = os.path.join(model_dir, filename)
+ download_file(file_url, file_path)
+
+ # Load settings and params
+ tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
+ settings = json.load(open(os.path.join(model_dir, "hparams.json")))
+ params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
+
+ return settings, params
+
+
+def download_file(url, destination):
+ # Send a GET request to download the file
+
+ try:
+ with urllib.request.urlopen(url) as response:
+ # Get the total file size from headers, defaulting to 0 if not present
+ file_size = int(response.headers.get("Content-Length", 0))
+
+ # Check if file exists and has the same size
+ if os.path.exists(destination):
+ file_size_local = os.path.getsize(destination)
+ if file_size == file_size_local:
+ print(f"File already exists and is up-to-date: {destination}")
+ return
+
+ # Define the block size for reading the file
+ block_size = 1024 # 1 Kilobyte
+
+ # Initialize the progress bar with total file size
+ progress_bar_description = os.path.basename(url) # Extract filename from URL
+ with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
+ # Open the destination file in binary write mode
+ with open(destination, "wb") as file:
+ # Read the file in chunks and write to destination
+ while True:
+ chunk = response.read(block_size)
+ if not chunk:
+ break
+ file.write(chunk)
+ progress_bar.update(len(chunk)) # Update progress bar
+ except urllib.error.HTTPError:
+ s = (
+ f"The specified URL ({url}) is incorrect, the internet connection cannot be established,"
+ "\nor the requested file is temporarily unavailable.\nPlease visit the following website"
+ " for help: https://github.com/rasbt/LLMs-from-scratch/discussions/273")
+ print(s)
+
+
+# Alternative way using `requests`
+"""
+def download_file(url, destination):
+ # Send a GET request to download the file in streaming mode
+ response = requests.get(url, stream=True)
+
+ # Get the total file size from headers, defaulting to 0 if not present
+ file_size = int(response.headers.get("content-length", 0))
+
+ # Check if file exists and has the same size
+ if os.path.exists(destination):
+ file_size_local = os.path.getsize(destination)
+ if file_size == file_size_local:
+ print(f"File already exists and is up-to-date: {destination}")
+ return
+
+ # Define the block size for reading the file
+ block_size = 1024 # 1 Kilobyte
+
+ # Initialize the progress bar with total file size
+ progress_bar_description = url.split("/")[-1] # Extract filename from URL
+ with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
+ # Open the destination file in binary write mode
+ with open(destination, "wb") as file:
+ # Iterate over the file data in chunks
+ for chunk in response.iter_content(block_size):
+ progress_bar.update(len(chunk)) # Update progress bar
+ file.write(chunk) # Write the chunk to the file
+"""
+
+
+def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
+ # Initialize parameters dictionary with empty blocks for each layer
+ params = {"blocks": [{} for _ in range(settings["n_layer"])]}
+
+ # Iterate over each variable in the checkpoint
+ for name, _ in tf.train.list_variables(ckpt_path):
+ # Load the variable and remove singleton dimensions
+ variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))
+
+ # Process the variable name to extract relevant parts
+ variable_name_parts = name.split("/")[1:] # Skip the 'model/' prefix
+
+ # Identify the target dictionary for the variable
+ target_dict = params
+ if variable_name_parts[0].startswith("h"):
+ layer_number = int(variable_name_parts[0][1:])
+ target_dict = params["blocks"][layer_number]
+
+ # Recursively access or create nested dictionaries
+ for key in variable_name_parts[1:-1]:
+ target_dict = target_dict.setdefault(key, {})
+
+ # Assign the variable array to the last key
+ last_key = variable_name_parts[-1]
+ target_dict[last_key] = variable_array
+
+ return params
diff --git a/ch05/09_extending-tokenizers/previous_chapters.py b/ch05/09_extending-tokenizers/previous_chapters.py
new file mode 100644
index 00000000..369e3700
--- /dev/null
+++ b/ch05/09_extending-tokenizers/previous_chapters.py
@@ -0,0 +1,279 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+#
+# This file collects all the relevant code that we covered thus far
+# throughout Chapters 2-4.
+# This file can be run as a standalone script.
+
+import tiktoken
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset, DataLoader
+
+#####################################
+# Chapter 2
+#####################################
+
+
+class GPTDatasetV1(Dataset):
+ def __init__(self, txt, tokenizer, max_length, stride):
+ self.input_ids = []
+ self.target_ids = []
+
+ # Tokenize the entire text
+ token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
+
+ # Use a sliding window to chunk the book into overlapping sequences of max_length
+ for i in range(0, len(token_ids) - max_length, stride):
+ input_chunk = token_ids[i:i + max_length]
+ target_chunk = token_ids[i + 1: i + max_length + 1]
+ self.input_ids.append(torch.tensor(input_chunk))
+ self.target_ids.append(torch.tensor(target_chunk))
+
+ def __len__(self):
+ return len(self.input_ids)
+
+ def __getitem__(self, idx):
+ return self.input_ids[idx], self.target_ids[idx]
+
+
+def create_dataloader_v1(txt, batch_size=4, max_length=256,
+ stride=128, shuffle=True, drop_last=True, num_workers=0):
+ # Initialize the tokenizer
+ tokenizer = tiktoken.get_encoding("gpt2")
+
+ # Create dataset
+ dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
+
+ # Create dataloader
+ dataloader = DataLoader(
+ dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)
+
+ return dataloader
+
+
+#####################################
+# Chapter 3
+#####################################
+class MultiHeadAttention(nn.Module):
+ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
+ super().__init__()
+ assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
+
+ self.d_out = d_out
+ self.num_heads = num_heads
+ self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
+
+ self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
+ self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
+ self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
+ self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
+ self.dropout = nn.Dropout(dropout)
+ self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+
+ def forward(self, x):
+ b, num_tokens, d_in = x.shape
+
+ keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
+ queries = self.W_query(x)
+ values = self.W_value(x)
+
+ # We implicitly split the matrix by adding a `num_heads` dimension
+ # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
+ keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
+ values = values.view(b, num_tokens, self.num_heads, self.head_dim)
+ queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
+
+ # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
+ keys = keys.transpose(1, 2)
+ queries = queries.transpose(1, 2)
+ values = values.transpose(1, 2)
+
+ # Compute scaled dot-product attention (aka self-attention) with a causal mask
+ attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
+
+ # Original mask truncated to the number of tokens and converted to boolean
+ mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+
+ # Use the mask to fill attention scores
+ attn_scores.masked_fill_(mask_bool, -torch.inf)
+
+ attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
+ attn_weights = self.dropout(attn_weights)
+
+ # Shape: (b, num_tokens, num_heads, head_dim)
+ context_vec = (attn_weights @ values).transpose(1, 2)
+
+ # Combine heads, where self.d_out = self.num_heads * self.head_dim
+ context_vec = context_vec.reshape(b, num_tokens, self.d_out)
+ context_vec = self.out_proj(context_vec) # optional projection
+
+ return context_vec
+
+
+#####################################
+# Chapter 4
+#####################################
+class LayerNorm(nn.Module):
+ def __init__(self, emb_dim):
+ super().__init__()
+ self.eps = 1e-5
+ self.scale = nn.Parameter(torch.ones(emb_dim))
+ self.shift = nn.Parameter(torch.zeros(emb_dim))
+
+ def forward(self, x):
+ mean = x.mean(dim=-1, keepdim=True)
+ var = x.var(dim=-1, keepdim=True, unbiased=False)
+ norm_x = (x - mean) / torch.sqrt(var + self.eps)
+ return self.scale * norm_x + self.shift
+
+
+class GELU(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return 0.5 * x * (1 + torch.tanh(
+ torch.sqrt(torch.tensor(2.0 / torch.pi)) *
+ (x + 0.044715 * torch.pow(x, 3))
+ ))
+
+
+class FeedForward(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.layers = nn.Sequential(
+ nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
+ GELU(),
+ nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.att = MultiHeadAttention(
+ d_in=cfg["emb_dim"],
+ d_out=cfg["emb_dim"],
+ context_length=cfg["context_length"],
+ num_heads=cfg["n_heads"],
+ dropout=cfg["drop_rate"],
+ qkv_bias=cfg["qkv_bias"])
+ self.ff = FeedForward(cfg)
+ self.norm1 = LayerNorm(cfg["emb_dim"])
+ self.norm2 = LayerNorm(cfg["emb_dim"])
+ self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
+
+ def forward(self, x):
+ # Shortcut connection for attention block
+ shortcut = x
+ x = self.norm1(x)
+ x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
+ x = self.drop_shortcut(x)
+ x = x + shortcut # Add the original input back
+
+ # Shortcut connection for feed-forward block
+ shortcut = x
+ x = self.norm2(x)
+ x = self.ff(x)
+ x = self.drop_shortcut(x)
+ x = x + shortcut # Add the original input back
+
+ return x
+
+
+class GPTModel(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
+ self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
+ self.drop_emb = nn.Dropout(cfg["drop_rate"])
+
+ self.trf_blocks = nn.Sequential(
+ *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+
+ self.final_norm = LayerNorm(cfg["emb_dim"])
+ self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+
+ def forward(self, in_idx):
+ batch_size, seq_len = in_idx.shape
+ tok_embeds = self.tok_emb(in_idx)
+ pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
+ x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
+ x = self.drop_emb(x)
+ x = self.trf_blocks(x)
+ x = self.final_norm(x)
+ logits = self.out_head(x)
+ return logits
+
+
+def generate_text_simple(model, idx, max_new_tokens, context_size):
+ # idx is (B, T) array of indices in the current context
+ for _ in range(max_new_tokens):
+
+ # Crop current context if it exceeds the supported context size
+ # E.g., if LLM supports only 5 tokens, and the context size is 10
+ # then only the last 5 tokens are used as context
+ idx_cond = idx[:, -context_size:]
+
+ # Get the predictions
+ with torch.no_grad():
+ logits = model(idx_cond)
+
+ # Focus only on the last time step
+ # (batch, n_token, vocab_size) becomes (batch, vocab_size)
+ logits = logits[:, -1, :]
+
+ # Get the idx of the vocab entry with the highest logits value
+ idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
+
+ # Append sampled index to the running sequence
+ idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
+
+ return idx
+
+
+if __name__ == "__main__":
+
+ GPT_CONFIG_124M = {
+ "vocab_size": 50257, # Vocabulary size
+ "context_length": 1024, # Context length
+ "emb_dim": 768, # Embedding dimension
+ "n_heads": 12, # Number of attention heads
+ "n_layers": 12, # Number of layers
+ "drop_rate": 0.1, # Dropout rate
+ "qkv_bias": False # Query-Key-Value bias
+ }
+
+ torch.manual_seed(123)
+ model = GPTModel(GPT_CONFIG_124M)
+ model.eval() # disable dropout
+
+ start_context = "Hello, I am"
+
+ tokenizer = tiktoken.get_encoding("gpt2")
+ encoded = tokenizer.encode(start_context)
+ encoded_tensor = torch.tensor(encoded).unsqueeze(0)
+
+ print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
+ print("\nInput text:", start_context)
+ print("Encoded input text:", encoded)
+ print("encoded_tensor.shape:", encoded_tensor.shape)
+
+ out = generate_text_simple(
+ model=model,
+ idx=encoded_tensor,
+ max_new_tokens=10,
+ context_size=GPT_CONFIG_124M["context_length"]
+ )
+ decoded_text = tokenizer.decode(out.squeeze(0).tolist())
+
+ print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
+ print("\nOutput:", out)
+ print("Output length:", len(out[0]))
+ print("Output text:", decoded_text)
diff --git a/ch07/01_main-chapter-code/exercise-solutions.ipynb b/ch07/01_main-chapter-code/exercise-solutions.ipynb
index 4533fc0d..b054203f 100644
--- a/ch07/01_main-chapter-code/exercise-solutions.ipynb
+++ b/ch07/01_main-chapter-code/exercise-solutions.ipynb
@@ -309,7 +309,30 @@
"Average score: 48.87\n",
"```\n",
"\n",
- "The score is close to 50, which is in the same ballpark as the score we previously achieved with the Alpaca-style prompts."
+ "The score is close to 50, which is in the same ballpark as the score we previously achieved with the Alpaca-style prompts.\n",
+ "\n",
+ "There is no inherent advantage or rationale why the Phi prompt-style should be better, but it can be more concise and efficient, except for the caveat mentioned in the *Tip* section below."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "156bc574-3f3e-4479-8f58-c8c8c472416e",
+ "metadata": {},
+ "source": [
+ "#### Tip: Considering special tokens"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "65cacf90-21c2-48f2-8f21-5c0c86749ff2",
+ "metadata": {},
+ "source": [
+ "- Note that the Phi-3 prompt template contains special tokens such as `<|user|>` and `<|assistant|>`, which can be suboptimal for the GPT-2 tokenizer\n",
+ "- While the GPT-2 tokenizer recognizes `<|endoftext|>` as a special token (encoded into token ID 50256), it is inefficient at handling other special tokens, such as the aforementioned ones\n",
+ "- For instance, `<|user|>` is encoded into 5 individual token IDs (27, 91, 7220, 91, 29), which is very inefficient\n",
+ "- We could add `<|user|>` as a new special token in `tiktoken` via the `allowed_special` argument, but please keep in mind that the GPT-2 vocabulary would not be able to handle it without additional modification\n",
+ "- If you are curious about how a tokenizer and LLM can be extended to handle special tokens, please see the [extend-tiktoken.ipynb](../../ch05/09_extending-tokenizers/extend-tiktoken.ipynb) bonus materials (note that this is not required here but is just an interesting/bonus consideration for curious readers)\n",
+ "- Furthermore, we can hypothesize that models that support these special tokens of a prompt template via their vocabulary may perform more efficiently and better overall"
]
},
{
@@ -994,7 +1017,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.11"
+ "version": "3.11.4"
}
},
"nbformat": 4,