From 76b8a5aadc2974e98d27a239c16b99f878edbc83 Mon Sep 17 00:00:00 2001
From: vmpuri <45368418+vmpuri@users.noreply.github.com>
Date: Wed, 7 Aug 2024 19:13:10 -0700
Subject: [PATCH] Fix tokenization of chat interfaces
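
Previously, the OpenAI-compatible chunked completion path encoded only
the content of the last message and accumulated start_pos on the
generator object across requests. This change:

* encodes the full dialog with a chat formatter and keeps start_pos
  local to each chunked_completion call,
* stops streaming once the EOS token is generated,
* replaces ChatFormat with a _ChatFormatter ABC plus Llama3ChatFormatter
  and Llama2ChatFormatter implementations, and
* flushes the server-side token printout in server.py.

A rough sketch of the new formatter API (names as defined in
generate.py; the tokenizer is whatever _initialize_tokenizer returns):

    formatter = Llama3ChatFormatter(tokenizer)
    tokens = formatter.encode_dialog_prompt(
        dialog=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ]
    )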
---
api/api.py | 48 +++++++++++++++++++++-----------------
generate.py | 67 +++++++++++++++++++++++++++++++++++++++++++----------
server.py | 8 +++++--
3 files changed, 88 insertions(+), 35 deletions(-)
diff --git a/api/api.py b/api/api.py
index 63135133b..e46c6a33e 100644
--- a/api/api.py
+++ b/api/api.py
@@ -10,6 +10,8 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
+import torch
+
from build.utils import device_sync
from generate import Generator, GeneratorArgs
@@ -222,7 +224,6 @@ def __init__(self, *args, **kwargs):
"""
super().__init__(*args, **kwargs)
- self.start_pos = 0
self.max_seq_length = (
self.model.config.max_seq_length
+ self.speculative_builder_args.speculate_k
@@ -257,20 +258,25 @@ def chunked_completion(self, completion_request: CompletionRequest):
CompletionResponseChunk objects in response to completion_request as tokens are generated.
"""
- device_sync(device=self.builder_args.device)
# Initialize counters for chunk responses and encode the prompt.
id = str(uuid.uuid4())
idx = 0
- buffer = []
- encoded = self.encode_tokens(
- completion_request.messages[-1].get("content"),
- bos=True,
- device=self.builder_args.device,
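+ # Encode the entire dialog (every message in the request), not just the
+ # last message, using the model-appropriate chat formatter.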
+ tokens = self.chat_formatter.encode_dialog_prompt(
+ dialog=[
+ {"role": message["role"], "content": message["content"]}
+ for message in completion_request.messages
+ ]
)
+
+ encoded = torch.tensor(tokens, dtype=torch.int, device=self.builder_args.device)
+ print(self.tokenizer.decode(tokens))
+
+ start_pos = 0
+
generator_args = GeneratorArgs(
- completion_request.messages[-1].get("content"),
+ None,
max_new_tokens=(
int(completion_request.max_tokens)
if completion_request.max_tokens
@@ -279,20 +285,23 @@ def chunked_completion(self, completion_request: CompletionRequest):
encoded_prompt=encoded,
temperature=float(completion_request.temperature),
chat_mode=False,
+ sequential_prefill=True,
)
def callback(x, *, done_generating=False):
return self._callback(
x,
- buffer=buffer,
+ buffer=None,
done_generating=done_generating,
)
+ device_sync(device=self.builder_args.device)
+
# Process each token, metrics tuple yielded by Generator.generate.
for y, _ in self.generate(
- self.model,
- encoded,
- generator_args.max_new_tokens,
+ model=self.model,
+ prompt=encoded,
+ max_new_tokens=generator_args.max_new_tokens,
draft_model=self.draft_model,
speculate_k=generator_args.speculate_k,
chat_mode=generator_args.chat_mode,
@@ -300,12 +309,15 @@ def callback(x, *, done_generating=False):
temperature=generator_args.temperature,
top_k=generator_args.top_k,
sequential_prefill=generator_args.sequential_prefill,
- start_pos=self.start_pos,
+ start_pos=start_pos,
max_seq_length=self.max_seq_length,
seed=int(completion_request.seed),
):
if y is None:
continue
+ elif y.item() == self.tokenizer.eos_id:
+ # Stop generation if the EOS token is generated.
+ break
- # Decode the torch.Tensor token to a string and append to the buffer. Separate the sequences with a period token.
+ # Decode the torch.Tensor token to a string for this chunk.
content = "".join(
@@ -330,7 +342,7 @@ def callback(x, *, done_generating=False):
system_fingerprint=self.system_fingerprint,
)
yield chunk_response
- self.start_pos += y.size(0)
+ start_pos += y.size(0)
idx += 1
# Yield an ending chunk indicating the generation has completed.
@@ -369,10 +381,4 @@ def sync_completion(self, request: CompletionRequest):
)
def _callback(self, x, *, buffer, done_generating):
- period_id = self.tokenizer.encode(".")[0]
- buffer.append(self.tokenizer.decode([period_id] + x.tolist())[1:])
- if (
- self.is_llama3_model
- and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]
- ):
- buffer = buffer[:-1] # drop the eot_id from the output buffer
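+ # Chunk decoding now happens inline in chunked_completion, so the
+ # callback is intentionally a no-op.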
+ pass
diff --git a/generate.py b/generate.py
index 48a77ba29..a0153291d 100644
--- a/generate.py
+++ b/generate.py
@@ -9,6 +9,8 @@
import os
import textwrap
import time
+
+from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
@@ -28,24 +30,33 @@
from cli import add_arguments_for_verb, arg_init, check_args
from utils.device_info import get_device_info
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
-
-class ChatFormat:
+class _ChatFormatter(ABC):
def __init__(self, tokenizer):
self.tokenizer = tokenizer
- def encode_header(self, message) -> List[int]:
+ @abstractmethod
+ def encode_dialog_prompt(self, dialog) -> List[int]:
+ raise NotImplementedError()
+
+
+class Llama3ChatFormatter(_ChatFormatter):
+ """Format a chat prompt using special tokens to demarcate roles and messages.
+
+ Refer to the Llama 3 documentation for more details: https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
+
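+ Each message is rendered roughly as
+ <|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>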
+ """
+
+ def encode_header(self, role) -> List[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
- tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
+ tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens
def encode_message(self, message) -> List[int]:
- tokens = self.encode_header(message)
+ tokens = self.encode_header(message["role"])
tokens.extend(
self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
)
@@ -62,9 +73,37 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
return tokens
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
+
+
+class Llama2ChatFormatter(_ChatFormatter):
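+ """Format a chat prompt using Llama 2's [INST]/[/INST] and <<SYS>> delimiters."""
+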
+ def encode_dialog_prompt(self, dialog) -> List[int]:
+ tokens = self.tokenizer.encode(f"{B_INST} ")
+ # B_INST placement is asymmetric: the system prompt carries the leading
+ # B_INST, the first user message after it does not, and every later user
+ # message does. If there is no system prompt, the first user message gets it.
+ first_message = True
+ for message in dialog:
+ content = message["content"].strip()
+ if message["role"] == "system":
+ encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
+ first_message = False
+ elif message["role"] == "user":
+ encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
+ f"{B_INST if first_message else ''} {content} {E_INST} "
+ )
+ first_message = True
+ elif message["role"] == "assistant":
+ encoded = self.tokenizer.encode(f"{content}\n\n") + [
+ self.tokenizer.eos_id()
+ ]
+ tokens += encoded
+ return tokens
+
+
@dataclass
class GeneratorArgs:
- prompt: str = "torchchat is pronounced torch-chat and is so cool because"
+ prompt: Optional[str] = (
+ None # When passed into the Generator, this will be used as the system prompt
+ )
encoded_prompt: Optional[torch.Tensor] = None
chat_mode: bool = False
gui_mode: bool = False
@@ -188,7 +227,7 @@ def __init__(
))
# fmt: on
# raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")
-
+ self.system_prompt = generator_args.prompt
self.tokenizer = _initialize_tokenizer(self.tokenizer_args)
# Right now the assumption is only llama3 uses tiktokenizer and it
@@ -200,6 +239,11 @@ def __init__(
logging.debug(
"Llama3 model detected in chat mode. Using updated sentence schemas"
)
+ self.chat_formatter = (
+ Llama3ChatFormatter(self.tokenizer)
+ if self.is_llama3_model
+ else Llama2ChatFormatter(self.tokenizer)
+ )
self.builder_args.setup_caches = False
self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
@@ -641,8 +685,7 @@ def chat(
)
if get_system_prompt == "y" or get_system_prompt == "Y":
self.system_prompt = input("What is your system prompt? \n")
- if self.is_llama3_model:
- self.chat_formatter = ChatFormat(self.tokenizer)
+
else:
max_seq_length = min(
encoded.size(0) + generator_args.max_new_tokens,
@@ -685,7 +728,7 @@ def chat(
prompt, bos=True, device=self.builder_args.device
)
else:
- if self.system_prompt is not None:
+ if self.system_prompt:
encoded = self.chat_formatter.encode_dialog_prompt(
[
{"role": "system", "content": self.system_prompt},
diff --git a/server.py b/server.py
index 106c30134..074df6646 100644
--- a/server.py
+++ b/server.py
@@ -6,6 +6,10 @@
import json
+import logging
+
+logger = logging.getLogger(__name__)
+
from dataclasses import asdict
from typing import Dict, List, Union
@@ -21,7 +25,7 @@
OPENAI_API_VERSION = "v1"
-def create_app(args):
+def create_app(args): # noqa: C901
"""
Creates a flask app that can be used to serve the model as a chat API.
"""
@@ -69,7 +73,7 @@ def chunk_processor(chunked_completion_generator):
for chunk in chunked_completion_generator:
if (next_tok := chunk.choices[0].delta.content) is None:
next_tok = ""
- print(next_tok, end="")
+ print(next_tok, end="", flush=True)
yield json.dumps(_del_none(asdict(chunk)))
return Response(