From 76b8a5aadc2974e98d27a239c16b99f878edbc83 Mon Sep 17 00:00:00 2001
From: vmpuri <45368418+vmpuri@users.noreply.github.com>
Date: Wed, 7 Aug 2024 19:13:10 -0700
Subject: [PATCH] Fix tokenization of chat interfaces

---
 api/api.py  | 48 +++++++++++++++++++++-----------------
 generate.py | 67 +++++++++++++++++++++++++++++++++++++++++++----------
 server.py   |  8 +++++--
 3 files changed, 88 insertions(+), 35 deletions(-)

diff --git a/api/api.py b/api/api.py
index 63135133b..e46c6a33e 100644
--- a/api/api.py
+++ b/api/api.py
@@ -10,6 +10,8 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
 
+import torch
+
 from build.utils import device_sync
 from generate import Generator, GeneratorArgs
 
@@ -222,7 +224,6 @@ def __init__(self, *args, **kwargs):
         """
         super().__init__(*args, **kwargs)
 
-        self.start_pos = 0
         self.max_seq_length = (
             self.model.config.max_seq_length
             + self.speculative_builder_args.speculate_k
@@ -257,20 +258,25 @@ def chunked_completion(self, completion_request: CompletionRequest):
             CompletionResponseChunk objects in response to completion_request as tokens are generated.
 
         """
-        device_sync(device=self.builder_args.device)
 
         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
         idx = 0
-        buffer = []
-        encoded = self.encode_tokens(
-            completion_request.messages[-1].get("content"),
-            bos=True,
-            device=self.builder_args.device,
+        tokens = self.chat_formatter.encode_dialog_prompt(
+            dialog=[
+                {"role": message["role"], "content": message["content"]}
+                for message in completion_request.messages
+            ]
         )
+
+        encoded = torch.tensor(tokens, dtype=torch.int, device=self.builder_args.device)
+        print(self.tokenizer.decode(tokens))
+
+        start_pos = 0
+
 
         generator_args = GeneratorArgs(
-            completion_request.messages[-1].get("content"),
+            None,
             max_new_tokens=(
                 int(completion_request.max_tokens)
                 if completion_request.max_tokens
@@ -279,20 +285,23 @@ def chunked_completion(self, completion_request: CompletionRequest):
             encoded_prompt=encoded,
             temperature=float(completion_request.temperature),
             chat_mode=False,
+            sequential_prefill=True,
         )
 
         def callback(x, *, done_generating=False):
             return self._callback(
                 x,
-                buffer=buffer,
+                buffer=None,
                 done_generating=done_generating,
             )
 
+        device_sync(device=self.builder_args.device)
+
         # Process each token, metrics tuple yielded by Generator.generate.
         for y, _ in self.generate(
-            self.model,
-            encoded,
-            generator_args.max_new_tokens,
+            model=self.model,
+            prompt=encoded,
+            max_new_tokens=generator_args.max_new_tokens,
             draft_model=self.draft_model,
             speculate_k=generator_args.speculate_k,
             chat_mode=generator_args.chat_mode,
@@ -300,12 +309,15 @@ def callback(x, *, done_generating=False):
             temperature=generator_args.temperature,
             top_k=generator_args.top_k,
             sequential_prefill=generator_args.sequential_prefill,
-            start_pos=self.start_pos,
+            start_pos=start_pos,
             max_seq_length=self.max_seq_length,
             seed=int(completion_request.seed),
         ):
             if y is None:
                 continue
+            elif y.item() == self.tokenizer.eos_id:
+                # Stop generation if the EOS token is generated.
+                break
 
             # Decode the torch.Tensor token to a string and append to the buffer. Separate the sequences with a period token.
             content = "".join(
@@ -330,7 +342,7 @@ def callback(x, *, done_generating=False):
                 system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
-            self.start_pos += y.size(0)
+            start_pos += y.size(0)
             idx += 1
 
         # Yield an ending chunk indicating the generation has completed.
@@ -369,10 +381,4 @@ def sync_completion(self, request: CompletionRequest):
         )
 
     def _callback(self, x, *, buffer, done_generating):
-        period_id = self.tokenizer.encode(".")[0]
-        buffer.append(self.tokenizer.decode([period_id] + x.tolist())[1:])
-        if (
-            self.is_llama3_model
-            and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]
-        ):
-            buffer = buffer[:-1]  # drop the eot_id from the output buffer
+        pass
diff --git a/generate.py b/generate.py
index 48a77ba29..a0153291d 100644
--- a/generate.py
+++ b/generate.py
@@ -9,6 +9,8 @@
 import os
 import textwrap
 import time
+
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -28,24 +30,33 @@
 from cli import add_arguments_for_verb, arg_init, check_args
 from utils.device_info import get_device_info
 
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
-
 
-class ChatFormat:
+class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
 
-    def encode_header(self, message) -> List[int]:
+    @abstractmethod
+    def encode_dialog_prompt(self, dialog) -> List[int]:
+        raise NotImplementedError()
+
+
+class Llama3ChatFormatter(_ChatFormatter):
+    """Format a chat prompt using special tokens to demarcate roles and messages.
+
+    Refer to the LLaMA3 documentation for more details https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
+
+    """
+
+    def encode_header(self, role) -> List[int]:
         tokens = []
         tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
-        tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
+        tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
         tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
         tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
         return tokens
 
     def encode_message(self, message) -> List[int]:
-        tokens = self.encode_header(message)
+        tokens = self.encode_header(message.role)
         tokens.extend(
             self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
         )
@@ -62,9 +73,37 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
         return tokens
 
 
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
+
+
+class Llama2ChatFormatter(_ChatFormatter):
+    def encode_dialog_prompt(self, dialog) -> List[int]:
+        tokens = self.tokenizer.encode(f"{B_INST} ")
+        first_message = True  # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it.
+        for message in dialog:
+            content = message["content"].strip()
+            if message["role"] == "system":
+                encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
+                first_message = False
+            elif message["role"] == "user":
+                encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
+                    f"{B_INST if first_message else ''} {content} {E_INST} "
+                )
+                first_message = True
+            elif message["role"] == "assistant":
+                encoded = self.tokenizer.encode(f"{content}\n\n") + [
+                    self.tokenizer.eos_id()
+                ]
+            tokens += encoded
+        return tokens
+
+
 @dataclass
 class GeneratorArgs:
-    prompt: str = "torchchat is pronounced torch-chat and is so cool because"
+    prompt: Optional[str] = (
+        None  # When passed into the Generator, this will be used as the system prompt
+    )
     encoded_prompt: Optional[torch.Tensor] = None
     chat_mode: bool = False
     gui_mode: bool = False
@@ -188,7 +227,7 @@ def __init__(
                 ))
                 # fmt: on
             # raise RuntimeError("You need to use --is-chat-model to indicate model has chat support.")
-
+        self.system_prompt = generator_args.prompt
         self.tokenizer = _initialize_tokenizer(self.tokenizer_args)
 
         # Right now the assumption is only llama3 uses tiktokenizer and it
@@ -200,6 +239,11 @@ def __init__(
             logging.debug(
                 "Llama3 model detected in chat mode. Using updated sentence schemas"
             )
+        self.chat_formatter = (
+            Llama3ChatFormatter(self.tokenizer)
+            if self.is_llama3_model
+            else Llama2ChatFormatter(self.tokenizer)
+        )
 
         self.builder_args.setup_caches = False
         self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
@@ -641,8 +685,7 @@ def chat(
             )
             if get_system_prompt == "y" or get_system_prompt == "Y":
                 self.system_prompt = input("What is your system prompt? \n")
-            if self.is_llama3_model:
-                self.chat_formatter = ChatFormat(self.tokenizer)
+
         else:
             max_seq_length = min(
                 encoded.size(0) + generator_args.max_new_tokens,
@@ -685,7 +728,7 @@ def chat(
                         prompt, bos=True, device=self.builder_args.device
                     )
                 else:
-                    if self.system_prompt is not None:
+                    if self.system_prompt:
                         encoded = self.chat_formatter.encode_dialog_prompt(
                             [
                                 {"role": "system", "content": self.system_prompt},
diff --git a/server.py b/server.py
index 106c30134..074df6646 100644
--- a/server.py
+++ b/server.py
@@ -6,6 +6,10 @@
 
 import json
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 from dataclasses import asdict
 from typing import Dict, List, Union
 
@@ -21,7 +25,7 @@
 OPENAI_API_VERSION = "v1"
 
 
-def create_app(args):
+def create_app(args):  # noqa: C901
     """
     Creates a flask app that can be used to serve the model as a chat API.
     """
@@ -69,7 +73,7 @@ def chunk_processor(chunked_completion_generator):
             for chunk in chunked_completion_generator:
                 if (next_tok := chunk.choices[0].delta.content) is None:
                     next_tok = ""
-                print(next_tok, end="")
+                print(next_tok, end="", flush=True)
                 yield json.dumps(_del_none(asdict(chunk)))
 
         return Response(
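
For reference, below is a minimal usage sketch of the Llama2ChatFormatter introduced by this patch. It assumes the patched generate.py is importable (for example, the script is run from a torchchat checkout); the ByteTokenizer stub is purely hypothetical and only stands in for the real SentencePiece/tiktoken wrappers so the rendered prompt layout can be inspected without loading a model.

# Usage sketch (assumption: run from a torchchat checkout with this patch applied).
from generate import Llama2ChatFormatter


class ByteTokenizer:
    """Hypothetical stand-in tokenizer: encodes text as raw UTF-8 bytes."""

    def bos_id(self) -> int:
        return 256  # sentinel id outside the byte range

    def eos_id(self) -> int:
        return 257

    def encode(self, text, bos=False, eos=False):
        tokens = list(text.encode("utf-8"))
        if bos:
            tokens = [self.bos_id()] + tokens
        if eos:
            tokens = tokens + [self.eos_id()]
        return tokens

    def decode(self, tokens):
        # Drop the sentinel bos/eos ids before decoding back to text.
        return bytes(t for t in tokens if t < 256).decode("utf-8")


formatter = Llama2ChatFormatter(ByteTokenizer())
tokens = formatter.encode_dialog_prompt(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is torchchat?"},
    ]
)
# Roughly: '[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>> What is torchchat? [/INST] '
print(repr(formatter.tokenizer.decode(tokens)))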