def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids):
    """Resolve the pad token ID and the EOS token IDs to pass to generation.

    Args:
        tokenizer: Object exposing ``pad_token_id``, ``eos_token_id`` and
            ``convert_tokens_to_ids`` (called with a list of token strings).
        eos_tokens: Comma-separated EOS token strings, or ``None``.
        eos_token_ids: Comma-separated EOS token IDs, or ``None``.

    Returns:
        A ``(pad_token_id, all_eos_token_ids)`` tuple. ``pad_token_id`` is the
        tokenizer's pad token ID, or its EOS token ID when no pad token is
        set. ``all_eos_token_ids`` lists the IDs converted from ``eos_tokens``
        followed by those parsed from ``eos_token_ids``; when neither yields
        an entry it contains only ``tokenizer.eos_token_id``.

    Raises:
        ValueError: If a non-empty entry of ``eos_token_ids`` is not an
            integer literal.
    """
    if tokenizer.pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    else:
        pad_token_id = tokenizer.pad_token_id

    all_eos_token_ids = []

    if eos_tokens is not None:
        # Strip surrounding whitespace and drop empty entries so inputs such
        # as "a, b", "a," or "" do not trigger lookups of bogus tokens.
        tokens = [token.strip() for token in eos_tokens.split(",")]
        tokens = [token for token in tokens if token]
        if tokens:
            all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(tokens))

    if eos_token_ids is not None:
        # Skip blank entries: int("") would raise on "" or a trailing comma.
        all_eos_token_ids.extend(
            int(token_id) for token_id in eos_token_ids.split(",") if token_id.strip()
        )

    # Fall back to the tokenizer's own EOS token when nothing was supplied.
    if not all_eos_token_ids:
        all_eos_token_ids.append(tokenizer.eos_token_id)

    return pad_token_id, all_eos_token_ids
field(default=1.0, metadata={"help": "Value of p for nucleus sampling"}) repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty"}) + eos_tokens: str = field( + default=None, + metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated"}, + ) + eos_token_ids: str = field( + default=None, + metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated"}, + ) # model loading model_revision: str = field( default="main",