diff --git a/generate.py b/generate.py index 3f0c02b5f..cb0048d10 100644 --- a/generate.py +++ b/generate.py @@ -745,9 +745,7 @@ def chat( {"role": "user", "content": prompt} ) encoded.extend( - self.chat_formatter.encode_header( - {"role": "assistant", "content": ""} - ) + self.chat_formatter.encode_header("assistant") ) encoded = torch.tensor( encoded, dtype=torch.int, device=self.builder_args.device