From 1a92a7725a8af14e61a1f4f0744069c64203f8ce Mon Sep 17 00:00:00 2001 From: Dustin Blackman Date: Wed, 20 Dec 2023 15:20:29 -0500 Subject: [PATCH] feat!: Default to first model in backend if none selected at start --- src/application/cli.rs | 16 +++++++--------- src/domain/services/actions.rs | 21 +++++++++++++++++---- src/domain/services/app_state.rs | 12 +++++++++--- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/application/cli.rs b/src/application/cli.rs index dd4901b..4d55ac1 100644 --- a/src/application/cli.rs +++ b/src/application/cli.rs @@ -149,7 +149,7 @@ fn subcommand_completions() -> Command { fn subcommand_sessions_delete() -> Command { return Command::new("delete") - .about("Delete one or all sessions") + .about("Delete one or all sessions.") .arg( clap::Arg::new("session-id") .short('i') @@ -160,7 +160,7 @@ fn subcommand_sessions_delete() -> Command { .arg( clap::Arg::new("all") .long("all") - .help("Delete all sessions") + .help("Delete all sessions.") .num_args(0), ) .group( @@ -187,7 +187,7 @@ fn arg_backend_health_check_timeout() -> Arg { .env("OATMEAL_BACKEND_HEALTH_CHECK_TIMEOUT") .num_args(1) .help( - "Time to wait in milliseconds before timing out when doing a healthcheck for a backend", + "Time to wait in milliseconds before timing out when doing a healthcheck for a backend.", ) .default_value("1000"); } @@ -198,8 +198,7 @@ fn arg_model() -> Arg { .long("model") .env("OATMEAL_MODEL") .num_args(1) - .help("The initial model on a backend to consume") - .default_value("llama2:latest"); + .help("The initial model on a backend to consume. Defaults to the first model available from the backend if not set."); } fn subcommand_chat() -> Command { @@ -432,10 +431,9 @@ pub async fn parse() -> Result { .get_one::("backend-health-check-timeout") .unwrap(), ); - Config::set( - ConfigKey::Model, - matches.get_one::("model").unwrap(), - ); + if let Some(model) = matches.get_one::("model") { + Config::set(ConfigKey::Model, model); + } } } diff --git a/src/domain/services/actions.rs b/src/domain/services/actions.rs index 7ee45a1..6858d6a 100644 --- a/src/domain/services/actions.rs +++ b/src/domain/services/actions.rs @@ -11,6 +11,7 @@ use crate::domain::models::AcceptType; use crate::domain::models::Action; use crate::domain::models::Author; use crate::domain::models::BackendBox; +use crate::domain::models::BackendPrompt; use crate::domain::models::EditorContext; use crate::domain::models::EditorName; use crate::domain::models::Event; @@ -208,6 +209,21 @@ fn worker_error(err: anyhow::Error, tx: &mpsc::UnboundedSender) -> Result return Ok(()); } +async fn completions( + backend: &BackendBox, + prompt: BackendPrompt, + tx: &mpsc::UnboundedSender, +) -> Result<()> { + if Config::get(ConfigKey::Model).is_empty() { + let models = backend.list_models().await?; + Config::set(ConfigKey::Model, &models[0]); + } + + backend.get_completion(prompt, tx).await?; + + return Ok(()); +} + fn help(tx: &mpsc::UnboundedSender) -> Result<()> { tx.send(Event::BackendMessage(Message::new( Author::Oatmeal, @@ -267,12 +283,9 @@ impl ActionsService { let backend_worker = backend_arc.clone(); worker = tokio::spawn(async move { - let res = backend_worker.get_completion(prompt, &worker_tx).await; - - if let Err(err) = res { + if let Err(err) = completions(&backend_worker, prompt, &worker_tx).await { worker_error(err, &worker_tx)?; } - return Ok(()); }); } diff --git a/src/domain/services/app_state.rs b/src/domain/services/app_state.rs index 8cd76a2..dddc88b 100644 --- a/src/domain/services/app_state.rs +++ b/src/domain/services/app_state.rs @@ -8,6 +8,8 @@ use super::CodeBlocks; use super::Scroll; use super::Sessions; use super::Themes; +use crate::config::Config; +use crate::config::ConfigKey; use crate::domain::models::AcceptType; use crate::domain::models::Action; use crate::domain::models::Author; @@ -56,7 +58,7 @@ impl<'a> AppState<'a> { } async fn init(props: AppStateProps) -> Result> { - let model_name = &props.model_name; + let mut model_name = props.model_name.to_string(); let theme = Themes::get(&props.theme_name, &props.theme_file)?; let mut app_state = AppState { @@ -84,7 +86,11 @@ impl<'a> AppState<'a> { )); } else { let models = props.backend.list_models().await?; - if !models.contains(&model_name.to_string()) { + if model_name.is_empty() { + model_name = models[0].to_string(); + // TODO refactor this out later. + Config::set(ConfigKey::Model, &model_name); + } else if !models.contains(&model_name.to_string()) { app_state .messages .push(Message::new_with_type( @@ -96,7 +102,7 @@ impl<'a> AppState<'a> { } // Fallback to the default intro message when there's no editor context. - if app_state.add_editor_context(props.editor).await.is_err() { + if app_state.add_editor_context(props.editor).await.is_err() && !model_name.is_empty() { app_state.messages.push(Message::new( Author::Model, "Hey there! What can I do for you?",