allow chat to halt new token generation on stop_sequence
#364
Conversation
binaries/llm-cli/src/main.rs
Outdated
@@ -264,6 +266,33 @@ fn interactive(
        .unwrap_or(false)
}

fn inference_callback(
this is essentially a copy of the function here https://github.com/rustformers/llm/blob/main/crates/llm/examples/vicuna-chat.rs#L119
binaries/llm-cli/src/main.rs
Outdated
@@ -256,6 +256,12 @@ fn interactive(
    let parameters = generate.inference_parameters(model.eot_token_id());
    let mut rng = generate.rng();

    let stop_sequence = message_prompt_template
I am assuming the `message_prompt_template` is something like:
`User: {{PROMPT}}`
But I'm very open to suggestions here on how to make this more robust.
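To make that assumption concrete, here's a tiny hedged sketch (the helper name is made up and the real template handling in the PR may differ): take whatever precedes the `{{PROMPT}}` placeholder as the stop sequence, so `User: {{PROMPT}}` yields `User:`.

```rust
/// Hypothetical helper (not the PR's code): everything before the
/// `{{PROMPT}}` placeholder doubles as the stop sequence, since the model
/// emitting it again means it has started writing the next user turn.
fn stop_sequence_from_template(template: &str) -> String {
    template
        .split("{{PROMPT}}")
        .next()
        .unwrap_or(template)
        .trim()
        .to_owned()
}

fn main() {
    assert_eq!(stop_sequence_from_template("User: {{PROMPT}}"), "User:");
}
```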
binaries/llm-cli/src/main.rs
Outdated
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, Infallible> + '_ {
    move |resp| match resp {
        InferenceResponse::InferredToken(t) => {
            if chat_mode {
I didn't touch REPL mode - is the desired behavior to continue generating tokens in that mode? Otherwise, happy to change it.
Great work! I was actually thinking about bringing that logic out so that I could use it elsewhere. Do you think you'd be able to move the `stop_sequence` handling out of the CLI and into the `llm` crate itself? That would allow for this logic to be used across both the CLI and the examples, as well as elsewhere (the aforementioned use case). I'd suggest passing the stop sequence in from the CLI (i.e. maybe replace the message prompt template with a message prompt prefix?).
Yup, that's for a back-and-forth where no state is preserved and the model can produce as much output as it wants. I'd suggest splitting the code paths so that they use entirely different inference callbacks - they share the readline, but their inference behaviour is pretty different. Great work once again, let me know if you need a hand with any of this! 🙂
Thanks @philpax! Just pushed an update where I moved the function into the library.
Great work! Looking forward to merging this soon 🚀
I don't think the logic would work if it were a postfix anyway - we should make it clear that you need to pass in a prefix. I'd say you can leave out the
Okay @philpax, I have updated the CLI to take a `--message-prompt-prefix` (or `--message-prompt-prefix-file`).
The previous abstraction made it hard to reason about what each codepath would do. To resolve this, I've split the code up into entirely separate functions that share common code.
Heya! ...apologies for hijacking the PR. I went to test it and all of your changes worked as expected, but I realised that there were quite a few latent bugs with the stuff not covered by your PR and that the whole chat/REPL logic just wasn't working how I wanted it to work. I ended up revising way more than I intended 😅 The upshot is that it should now work consistently, and there shouldn't be any surprise discrepancies between REPL and chat mode. Sorry once again for the complete hijack 😭 Feel free to ask about any of the changes I made! Most of them were unrelated to the code you introduced (I mostly addressed issues that were already present before your changes), but I'm happy to explain them nonetheless. You might be interested in the
            eyre::bail!(
                "Must specify either --message-prompt-prefix or --message-prompt-prefix-file"
            )
        (None, Some(message_prompt_prefix_file)) => {
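For illustration, a rough sketch of how the two flags could be resolved with `eyre` - the function and argument names here are assumptions, not the PR's actual code:

```rust
use std::path::PathBuf;

use eyre::{bail, Result};

/// Illustrative only: pick the prefix from whichever flag was supplied,
/// and bail with a message like the one in the diff above otherwise.
fn resolve_message_prompt_prefix(
    prefix: Option<String>,
    prefix_file: Option<PathBuf>,
) -> Result<String> {
    match (prefix, prefix_file) {
        (Some(prefix), None) => Ok(prefix),
        (None, Some(path)) => Ok(std::fs::read_to_string(path)?.trim_end().to_owned()),
        (Some(_), Some(_)) => bail!(
            "Specify only one of --message-prompt-prefix and --message-prompt-prefix-file"
        ),
        (None, None) => bail!(
            "Must specify either --message-prompt-prefix or --message-prompt-prefix-file"
        ),
    }
}
```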
wow very cool!
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
    let mut stop_sequence_buf = String::new();
interesting, scoping this buffer to the function seems a lot better!
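To show what that buffer is for, here's a hedged sketch of a stop-sequence-aware callback in the spirit of `conversation_inference_callback` - my reading of the approach, not necessarily the exact code that was merged. Tokens that might still be the start of the stop sequence are held in `stop_sequence_buf`; they're discarded when the stop sequence completes (halt) or flushed to the caller once they diverge from it.

```rust
use llm::{InferenceFeedback, InferenceResponse};

/// Sketch of a stop-sequence-aware callback (type names follow the llm crate;
/// the body is illustrative rather than the PR's exact implementation).
fn stop_on_sequence<'a, E: std::error::Error + Send + Sync + 'static>(
    stop_sequence: &'a str,
    mut print_token: impl FnMut(String) + 'a,
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
    let mut stop_sequence_buf = String::new();
    move |resp| match resp {
        InferenceResponse::InferredToken(token) => {
            let mut buf = stop_sequence_buf.clone();
            buf.push_str(&token);

            if buf.starts_with(stop_sequence) {
                // The stop sequence has been generated: halt and discard it.
                stop_sequence_buf.clear();
                Ok(InferenceFeedback::Halt)
            } else if stop_sequence.starts_with(buf.as_str()) {
                // Could still be the start of the stop sequence: keep buffering.
                stop_sequence_buf = buf;
                Ok(InferenceFeedback::Continue)
            } else {
                // Not the stop sequence after all: flush everything buffered so far.
                stop_sequence_buf.clear();
                print_token(buf);
                Ok(InferenceFeedback::Continue)
            }
        }
        InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
        _ => Ok(InferenceFeedback::Continue),
    }
}
```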
@@ -897,8 +897,8 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(

/// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated.
/// This callback is used in [InferenceSession::infer] in chat_mode.
pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
    stop_sequence: String,
pub fn conversation_inference_callback<'a, E: std::error::Error + Send + Sync + 'static>(
what do these do?
In Rust, objects get the `Send` trait if they can be sent across threads, and `Sync` if they can be shared between threads (you can see more details here). I needed to add this because `eyre`, which we use for error reporting in the CLI, expects the error from `infer` to be `Send + Sync`. The error is passed down from `callback` to `infer`, so the trait requirements need to be updated across the library.
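A small hypothetical example of why the bound matters: `eyre::Report` can only be built from errors that are `Send + Sync + 'static`, so the `?` conversion below only compiles with those bounds on `E`.

```rust
use eyre::Result;

/// Hypothetical example (not from the PR): forwarding a callback's error into eyre.
/// Removing `Send + Sync` from the bound makes the `?` conversion fail to compile,
/// which is why the bounds had to be added throughout the library.
fn forward_error<E>(result: std::result::Result<(), E>) -> Result<()>
where
    E: std::error::Error + Send + Sync + 'static,
{
    result?;
    Ok(())
}
```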
Assistant: How may I help you?
I have some prompts I made for Falcon and MPT too since I was testing them. Want me to add them in a follow-up PR?
Sure thing!
    )
}

fn interactive(
ah love that you split these out into separate functions!
Closes #363
Stop token generation after reaching a specified `stop_sequence` in chat mode. I am still new to Rust, so please let me know how I can improve my code!