Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

allow chat to halt new token generation on stop_sequence #364

Merged
merged 13 commits into the base branch from the feature branch
Jul 12, 2023
28 changes: 14 additions & 14 deletions crates/llm-base/src/inference_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ impl InferenceSession {
}

/// Feed a prompt to the model for this session.
pub fn feed_prompt<'a, E: std::error::Error + 'static, P: Into<Prompt<'a>>>(
pub fn feed_prompt<'a, E: std::error::Error + Send + Sync + 'static, P: Into<Prompt<'a>>>(
&mut self,
model: &dyn Model,
params: &InferenceParameters,
Expand Down Expand Up @@ -407,7 +407,7 @@ impl InferenceSession {
/// generated (specified by [InferenceRequest::maximum_token_count]).
///
/// This is a wrapper around [Self::feed_prompt] and [Self::infer_next_token].
pub fn infer<E: std::error::Error + 'static>(
pub fn infer<E: std::error::Error + Send + Sync + 'static>(
&mut self,
model: &dyn Model,
rng: &mut impl rand::Rng,
Expand Down Expand Up @@ -440,13 +440,13 @@ impl InferenceSession {
// Feed the initial prompt through the transformer, to update its
// context window with new data, if necessary.
if !request.prompt.is_empty() {
self.feed_prompt(
model,
parameters,
request.prompt,
output_request,
feed_prompt_callback(&mut callback),
)?;
self.feed_prompt(
model,
parameters,
request.prompt,
output_request,
feed_prompt_callback(&mut callback),
)?;
}
stats.feed_prompt_duration = start_at.elapsed().unwrap();
stats.prompt_tokens = self.n_past;
Expand Down Expand Up @@ -663,7 +663,7 @@ pub enum InferenceError {
EndOfText,
#[error("the user-specified callback returned an error")]
/// The user-specified callback returned an error.
UserCallback(Box<dyn std::error::Error>),
UserCallback(Box<dyn std::error::Error + Send + Sync>),
}

#[derive(Error, Debug)]
Expand Down Expand Up @@ -885,7 +885,7 @@ pub enum InferenceFeedback {

/// Adapt an [InferenceResponse] callback so that it can be used in a call to
/// [InferenceSession::feed_prompt].
pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
pub fn feed_prompt_callback<'a, E: std::error::Error + Send + Sync + 'static>(
mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a,
) -> impl FnMut(&[u8]) -> Result<InferenceFeedback, E> + 'a {
let mut buffer = TokenUtf8Buffer::new();
Expand All @@ -897,8 +897,8 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(

/// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated.
/// This callback is used in [InferenceSession::infer] in chat_mode.
pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
stop_sequence: String,
pub fn conversation_inference_callback<'a, E: std::error::Error + Send + Sync + 'static>(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do these do?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Rust, a type gets the Send trait if ownership of its values can be transferred to another thread, and Sync if a value can be safely shared by reference (&T) across multiple threads at the same time (you can see more details here).

I needed to add this because eyre, which we use for error reporting in the CLI, expects the error returned from infer to be Send + Sync. The error originates in the callback and propagates up through infer, so the trait requirements need to be updated consistently across the library.

stop_sequence: &'a str,
mut callback: impl FnMut(String) + 'a,
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
let mut stop_sequence_buf = String::new();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting, scoping this buffer to the function seems a lot better!

Expand All @@ -908,7 +908,7 @@ pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
let mut buf = stop_sequence_buf.clone();
buf.push_str(&token);

if buf.starts_with(&stop_sequence) {
if buf.starts_with(stop_sequence) {
// We've generated the stop sequence, so we're done.
// Note that this will contain the extra tokens that were generated after the stop sequence,
// which may affect generation. This is non-ideal, but it's the best we can do without
Expand Down