Chat tokenization fixes in generate.py & API #1035

Merged: 2 commits, Aug 19, 2024

Changes from all commits
48 changes: 27 additions & 21 deletions api/api.py
@@ -10,6 +10,8 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch

from build.utils import device_sync

from generate import Generator, GeneratorArgs
@@ -222,7 +224,6 @@ def __init__(self, *args, **kwargs):
"""

super().__init__(*args, **kwargs)
self.start_pos = 0
self.max_seq_length = (
self.model.config.max_seq_length
+ self.speculative_builder_args.speculate_k
@@ -257,20 +258,25 @@ def chunked_completion(self, completion_request: CompletionRequest):
CompletionResponseChunk objects in response to completion_request as tokens are generated.

"""
device_sync(device=self.builder_args.device)

# Initialize counters for chunk responses and encode the prompt.
id = str(uuid.uuid4())

idx = 0
buffer = []
encoded = self.encode_tokens(
completion_request.messages[-1].get("content"),
bos=True,
device=self.builder_args.device,
tokens = self.chat_formatter.encode_dialog_prompt(
dialog=[
{"role": message["role"], "content": message["content"]}
for message in completion_request.messages
]
)

encoded = torch.tensor(tokens, dtype=torch.int, device=self.builder_args.device)
print(self.tokenizer.decode(tokens))
Contributor:

Just checking that this is an intentional print

Contributor (Author):

Yes - this prints the prompt out on the server side so that the full prompt is easy to track from there.

However, this raises a larger issue in the generate/API stack - we need to replace print statements with a logger so that users can choose not to print these debug messages.
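For reference, a minimal sketch of the logger-based approach described above (illustrative only, not part of this PR; the logger name and the exact call site are assumptions):

import logging

logger = logging.getLogger(__name__)

# Inside chunked_completion, instead of printing the decoded prompt directly:
logger.debug("Encoded chat prompt: %s", self.tokenizer.decode(tokens))

With this pattern the server (or the user) controls verbosity through the standard logging configuration, e.g. logging.basicConfig(level=logging.DEBUG).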


start_pos = 0

generator_args = GeneratorArgs(
completion_request.messages[-1].get("content"),
None,
max_new_tokens=(
int(completion_request.max_tokens)
if completion_request.max_tokens
@@ -279,33 +285,39 @@ def chunked_completion(self, completion_request: CompletionRequest):
encoded_prompt=encoded,
temperature=float(completion_request.temperature),
chat_mode=False,
sequential_prefill=True,
)

def callback(x, *, done_generating=False):
return self._callback(
x,
buffer=buffer,
buffer=None,
done_generating=done_generating,
)

device_sync(device=self.builder_args.device)

# Process each token, metrics tuple yielded by Generator.generate.
for y, _ in self.generate(
self.model,
encoded,
generator_args.max_new_tokens,
model=self.model,
prompt=encoded,
max_new_tokens=generator_args.max_new_tokens,
draft_model=self.draft_model,
speculate_k=generator_args.speculate_k,
chat_mode=generator_args.chat_mode,
callback=callback,
temperature=generator_args.temperature,
top_k=generator_args.top_k,
sequential_prefill=generator_args.sequential_prefill,
start_pos=self.start_pos,
start_pos=start_pos,
max_seq_length=self.max_seq_length,
seed=int(completion_request.seed),
):
if y is None:
continue
elif y.item() == self.tokenizer.eos_id:
# Stop generation if the EOS token is generated.
break

# Decode the torch.Tensor token to a string and append to the buffer. Separate the sequences with a period token.
content = "".join(
@@ -330,7 +342,7 @@ def callback(x, *, done_generating=False):
system_fingerprint=self.system_fingerprint,
)
yield chunk_response
self.start_pos += y.size(0)
start_pos += y.size(0)
idx += 1

# Yield an ending chunk indicating the generation has completed.
@@ -369,10 +381,4 @@ def sync_completion(self, request: CompletionRequest):
)

def _callback(self, x, *, buffer, done_generating):
period_id = self.tokenizer.encode(".")[0]
buffer.append(self.tokenizer.decode([period_id] + x.tolist())[1:])
if (
self.is_llama3_model
and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]
):
buffer = buffer[:-1] # drop the eot_id from the output buffer
pass
Contributor:

Why is this a pass again?

Contributor (Author):

The callback function is only used in generate() for the CLI interactive chat to print results to stdout. I copied it over naively from the original generate.py when refactoring into openaiapi, where it isn't used.
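For context, a rough sketch of the stdout-streaming pattern that the CLI callback follows in generate() (illustrative; the function name is made up and the real implementation differs in details):

def _print_callback(self, x, *, done_generating=False):
    # Decode the new token(s); prefixing a period token and stripping it afterwards
    # preserves leading whitespace that decoding a lone token would otherwise drop.
    period_id = self.tokenizer.encode(".")[0]
    text = self.tokenizer.decode([period_id] + x.tolist())[1:]
    print(text, end="", flush=True)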

67 changes: 55 additions & 12 deletions generate.py
@@ -9,6 +9,8 @@
import os
import textwrap
import time

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
@@ -28,24 +30,33 @@
from cli import add_arguments_for_verb, arg_init, check_args
from utils.device_info import get_device_info

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"


class ChatFormat:
class _ChatFormatter(ABC):
def __init__(self, tokenizer):
self.tokenizer = tokenizer

def encode_header(self, message) -> List[int]:
@abstractmethod
def encode_dialog_prompt(self, dialog) -> List[int]:
raise NotImplementedError()


class Llama3ChatFormatter(_ChatFormatter):
"""Format a chat prompt using special tokens to demarcate roles and messages.

Refer to the LLaMA3 documentation for more details https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3

"""

def encode_header(self, role) -> List[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens

def encode_message(self, message) -> List[int]:
tokens = self.encode_header(message)
tokens = self.encode_header(message.role)
tokens.extend(
self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
)
@@ -62,9 +73,37 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
return tokens


B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"


class Llama2ChatFormatter(_ChatFormatter):
def encode_dialog_prompt(self, dialog) -> List[int]:
tokens = self.tokenizer.encode(f"{B_INST} ")
first_message = True # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it.
for message in dialog:
content = message["content"].strip()
if message["role"] == "system":
encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
first_message = False
elif message["role"] == "user":
encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
f"{B_INST if first_message else ''} {content} {E_INST} "
)
first_message = True
elif message["role"] == "assistant":
encoded = self.tokenizer.encode(f"{content}\n\n") + [
self.tokenizer.eos_id()
]
tokens += encoded
return tokens


@dataclass
class GeneratorArgs:
prompt: str = "torchchat is pronounced torch-chat and is so cool because"
prompt: Optional[str] = (
None # When passed into the Generator, this will be used as the system prompt
)
encoded_prompt: Optional[torch.Tensor] = None
chat_mode: bool = False
gui_mode: bool = False
@@ -188,7 +227,7 @@ def __init__(
))
# fmt: on
# raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")

self.system_prompt = generator_args.prompt
self.tokenizer = _initialize_tokenizer(self.tokenizer_args)

# Right now the assumption is only llama3 uses tiktokenizer and it
@@ -200,6 +239,11 @@ def __init__(
logging.debug(
"Llama3 model detected in chat mode. Using updated sentence schemas"
)
self.chat_formatter = (
Llama3ChatFormatter(self.tokenizer)
if self.is_llama3_model
else Llama2ChatFormatter(self.tokenizer)
)

self.builder_args.setup_caches = False
self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
@@ -641,8 +685,7 @@ def chat(
)
if get_system_prompt == "y" or get_system_prompt == "Y":
self.system_prompt = input("What is your system prompt? \n")
if self.is_llama3_model:
self.chat_formatter = ChatFormat(self.tokenizer)

else:
max_seq_length = min(
encoded.size(0) + generator_args.max_new_tokens,
@@ -685,7 +728,7 @@ def chat(
prompt, bos=True, device=self.builder_args.device
)
else:
if self.system_prompt is not None:
if self.system_prompt:
encoded = self.chat_formatter.encode_dialog_prompt(
[
{"role": "system", "content": self.system_prompt},
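For context, a minimal sketch of how the formatter classes introduced above tie into the API change in api.py (illustrative only; the dialog contents and device are made up):

# Select the formatter for the loaded model and encode a multi-turn dialog.
formatter = Llama3ChatFormatter(tokenizer)  # or Llama2ChatFormatter(tokenizer)
tokens = formatter.encode_dialog_prompt(
    dialog=[
        {"role": "system", "content": "You are a helpful assistant."},  # hypothetical
        {"role": "user", "content": "What is torchchat?"},  # hypothetical
    ]
)
# As in chunked_completion, the token list is then turned into a tensor for generation.
encoded = torch.tensor(tokens, dtype=torch.int, device="cpu")  # device is illustrative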
8 changes: 6 additions & 2 deletions server.py
@@ -6,6 +6,10 @@

import json

import logging

logger = logging.getLogger(__name__)

from dataclasses import asdict
from typing import Dict, List, Union

@@ -21,7 +25,7 @@
OPENAI_API_VERSION = "v1"


def create_app(args):
def create_app(args): # noqa: C901
"""
Creates a flask app that can be used to serve the model as a chat API.
"""
@@ -69,7 +73,7 @@ def chunk_processor(chunked_completion_generator):
for chunk in chunked_completion_generator:
if (next_tok := chunk.choices[0].delta.content) is None:
next_tok = ""
print(next_tok, end="")
print(next_tok, end="", flush=True)
yield json.dumps(_del_none(asdict(chunk)))

return Response(