From 2f4ba2dbd74d7894b7c47be3c351e017b1a4cc5a Mon Sep 17 00:00:00 2001
From: Philip Rideout
Date: Sun, 25 Aug 2024 14:11:37 -0700
Subject: [PATCH] Fix assertion failure during chat session. (#1061)

This fixes the following assert that is easy to repro in any chat session:

```
Traceback (most recent call last):
  File "/home/ubuntu/cali/torchchat/torchchat.py", line 69, in <module>
    generate_main(args)
  File "/home/ubuntu/cali/torchchat/generate.py", line 896, in main
    for _ in gen.chat(generator_args):
  File "/home/ubuntu/cali/torchchat/generate.py", line 748, in chat
    self.chat_formatter.encode_header(
  File "/home/ubuntu/cali/torchchat/generate.py", line 53, in encode_header
    tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
  File "/home/ubuntu/cali/torchchat/tokenizer/tiktoken.py", line 133, in encode
    assert type(s) is str
```

I believe this regressed with https://github.com/pytorch/torchchat/pull/1035.
---
 generate.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/generate.py b/generate.py
index 3f0c02b5f..cb0048d10 100644
--- a/generate.py
+++ b/generate.py
@@ -745,9 +745,7 @@ def chat(
                             {"role": "user", "content": prompt}
                         )
                         encoded.extend(
-                            self.chat_formatter.encode_header(
-                                {"role": "assistant", "content": ""}
-                            )
+                            self.chat_formatter.encode_header("assistant")
                         )
                     encoded = torch.tensor(
                         encoded, dtype=torch.int, device=self.builder_args.device