Skip to content

Commit

Permalink
fix: Clear backend context when switching models
Browse files Browse the repository at this point in the history
  • Loading branch information
Dustin Blackman committed Nov 21, 2023
1 parent 6785c53 commit 3b8f7c8
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 12 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ dev = '''set -e
'''

dev-install = '''set -e
cargo build
cargo build --release
sudo rm -f /usr/local/bin/oatmeal
sudo mv ./target/debug/oatmeal /usr/local/bin/
sudo mv ./target/release/oatmeal /usr/local/bin/
'''

goreleaser = '''set -e
Expand Down
11 changes: 2 additions & 9 deletions src/application/ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,8 @@ async fn start_loop<B: Backend>(
let mut prompt =
BackendPrompt::new(input_str.to_string(), app_state.backend_context.clone());

let user_messages_length = app_state
.messages
.iter()
.filter(|m| {
return m.author == Author::User && SlashCommand::parse(&m.text).is_none();
})
.collect::<Vec<_>>()
.len();
if user_messages_length == 1 {
if app_state.backend_context.is_empty() && SlashCommand::parse(&input_str).is_none()
{
prompt.append_system_prompt(&app_state.editor_context);
}

Expand Down
2 changes: 1 addition & 1 deletion src/domain/models/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub trait Backend {
/// Upon receiving all results, a final `done` boolean
/// is provided as the last message to the channel.
///
/// In order for a backend to maintain history, a context array is usually
/// In order for a backend to maintain history, a context array must be
/// provided by the backend. This can be passed alongside the `done`
/// boolean, and will be provided on the next prompt to the backend.
async fn get_completion<'a>(
Expand Down
14 changes: 14 additions & 0 deletions src/domain/services/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,15 @@ impl<'a> AppState<'a> {
self.backend_context = ctx;
}

if self.backend_context.is_empty() {
self.add_message(Message::new_with_type(
Author::Oatmeal,
MessageType::Error,
"Error: No context was provided by the backend upon completion. Please report this bug on Github."
));
self.sync_dependants();
}

self.codeblocks.replace_from_messages(&self.messages);
}
}
Expand Down Expand Up @@ -209,6 +218,11 @@ impl<'a> AppState<'a> {
tx.send(Action::CopyMessages(self.messages.clone()))?;
self.waiting_for_backend = true;
}

// Reset backend context on model switch.
if command.is_model_set() {
self.backend_context = "".to_string();
}
}

return Ok((should_break, should_continue));
Expand Down
43 changes: 43 additions & 0 deletions src/domain/services/app_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use super::AppState;
use crate::domain::models::AcceptType;
use crate::domain::models::Action;
use crate::domain::models::Author;
use crate::domain::models::BackendResponse;
use crate::domain::models::Message;
use crate::domain::models::MessageType;
use crate::domain::services::BubbleList;
Expand Down Expand Up @@ -182,3 +183,45 @@ mod handle_slash_commands {
return Ok(());
}
}

mod handle_backend_response {
use super::*;

#[test]
fn it_handles_new_backend_response() {
let mut app_state = AppState::default();
app_state
.messages
.push(Message::new(Author::User, "Do something for me!"));
let backend_response = BackendResponse {
author: Author::Model,
text: "All done!".to_string(),
done: true,
context: Some("icanrememberthingsnow".to_string()),
};
app_state.handle_backend_response(backend_response);

assert_eq!(app_state.messages.len(), 2);
}

#[test]
fn it_handles_bad_backend_response() {
let mut app_state = AppState::default();
app_state
.messages
.push(Message::new(Author::User, "Do something for me!"));
let backend_response = BackendResponse {
author: Author::Model,
text: "All done!".to_string(),
done: true,
context: Some("".to_string()),
};
app_state.handle_backend_response(backend_response);

assert_eq!(app_state.messages.len(), 3);
assert_eq!(
app_state.messages.last().unwrap().message_type(),
MessageType::Error
);
}
}

0 comments on commit 3b8f7c8

Please sign in to comment.