From 1d52f1bba1eb95228df240d7598607b7414923f2 Mon Sep 17 00:00:00 2001 From: Philip Rideout Date: Sat, 24 Aug 2024 18:26:08 -0700 Subject: [PATCH] Fix assertion failure during chat session. This fixes the following assert that is easy to repro in any chat session: ``` Traceback (most recent call last): File "/home/ubuntu/cali/torchchat/torchchat.py", line 69, in generate_main(args) File "/home/ubuntu/cali/torchchat/generate.py", line 896, in main for _ in gen.chat(generator_args): File "/home/ubuntu/cali/torchchat/generate.py", line 748, in chat self.chat_formatter.encode_header( File "/home/ubuntu/cali/torchchat/generate.py", line 53, in encode_header tokens.extend(self.tokenizer.encode(role, bos=False, eos=False)) File "/home/ubuntu/cali/torchchat/tokenizer/tiktoken.py", line 133, in encode assert type(s) is str ``` I believe this regressed with https://github.com/pytorch/torchchat/pull/1035. --- generate.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/generate.py b/generate.py index 3f0c02b5f..cb0048d10 100644 --- a/generate.py +++ b/generate.py @@ -745,9 +745,7 @@ def chat( {"role": "user", "content": prompt} ) encoded.extend( - self.chat_formatter.encode_header( - {"role": "assistant", "content": ""} - ) + self.chat_formatter.encode_header("assistant") ) encoded = torch.tensor( encoded, dtype=torch.int, device=self.builder_args.device