Skip to content

Commit

Permalink
feat!: Default to first model in backend if none selected at start
Browse files Browse the repository at this point in the history
  • Loading branch information
Dustin Blackman committed Dec 20, 2023
1 parent 154811e commit 1a92a77
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 16 deletions.
16 changes: 7 additions & 9 deletions src/application/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ fn subcommand_completions() -> Command {

fn subcommand_sessions_delete() -> Command {
return Command::new("delete")
.about("Delete one or all sessions")
.about("Delete one or all sessions.")
.arg(
clap::Arg::new("session-id")
.short('i')
Expand All @@ -160,7 +160,7 @@ fn subcommand_sessions_delete() -> Command {
.arg(
clap::Arg::new("all")
.long("all")
.help("Delete all sessions")
.help("Delete all sessions.")
.num_args(0),
)
.group(
Expand All @@ -187,7 +187,7 @@ fn arg_backend_health_check_timeout() -> Arg {
.env("OATMEAL_BACKEND_HEALTH_CHECK_TIMEOUT")
.num_args(1)
.help(
"Time to wait in milliseconds before timing out when doing a healthcheck for a backend",
"Time to wait in milliseconds before timing out when doing a healthcheck for a backend.",
)
.default_value("1000");
}
Expand All @@ -198,8 +198,7 @@ fn arg_model() -> Arg {
.long("model")
.env("OATMEAL_MODEL")
.num_args(1)
.help("The initial model on a backend to consume")
.default_value("llama2:latest");
.help("The initial model on a backend to consume. Defaults to the first model available from the backend if not set.");
}

fn subcommand_chat() -> Command {
Expand Down Expand Up @@ -432,10 +431,9 @@ pub async fn parse() -> Result<bool> {
.get_one::<String>("backend-health-check-timeout")
.unwrap(),
);
Config::set(
ConfigKey::Model,
matches.get_one::<String>("model").unwrap(),
);
if let Some(model) = matches.get_one::<String>("model") {
Config::set(ConfigKey::Model, model);
}
}
}

Expand Down
21 changes: 17 additions & 4 deletions src/domain/services/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::domain::models::AcceptType;
use crate::domain::models::Action;
use crate::domain::models::Author;
use crate::domain::models::BackendBox;
use crate::domain::models::BackendPrompt;
use crate::domain::models::EditorContext;
use crate::domain::models::EditorName;
use crate::domain::models::Event;
Expand Down Expand Up @@ -208,6 +209,21 @@ fn worker_error(err: anyhow::Error, tx: &mpsc::UnboundedSender<Event>) -> Result
return Ok(());
}

async fn completions(
backend: &BackendBox,
prompt: BackendPrompt,
tx: &mpsc::UnboundedSender<Event>,
) -> Result<()> {
if Config::get(ConfigKey::Model).is_empty() {
let models = backend.list_models().await?;
Config::set(ConfigKey::Model, &models[0]);
}

backend.get_completion(prompt, tx).await?;

return Ok(());
}

fn help(tx: &mpsc::UnboundedSender<Event>) -> Result<()> {
tx.send(Event::BackendMessage(Message::new(
Author::Oatmeal,
Expand Down Expand Up @@ -267,12 +283,9 @@ impl ActionsService {

let backend_worker = backend_arc.clone();
worker = tokio::spawn(async move {
let res = backend_worker.get_completion(prompt, &worker_tx).await;

if let Err(err) = res {
if let Err(err) = completions(&backend_worker, prompt, &worker_tx).await {
worker_error(err, &worker_tx)?;
}

return Ok(());
});
}
Expand Down
12 changes: 9 additions & 3 deletions src/domain/services/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use super::CodeBlocks;
use super::Scroll;
use super::Sessions;
use super::Themes;
use crate::config::Config;
use crate::config::ConfigKey;
use crate::domain::models::AcceptType;
use crate::domain::models::Action;
use crate::domain::models::Author;
Expand Down Expand Up @@ -56,7 +58,7 @@ impl<'a> AppState<'a> {
}

async fn init(props: AppStateProps) -> Result<AppState<'a>> {
let model_name = &props.model_name;
let mut model_name = props.model_name.to_string();
let theme = Themes::get(&props.theme_name, &props.theme_file)?;

let mut app_state = AppState {
Expand Down Expand Up @@ -84,7 +86,11 @@ impl<'a> AppState<'a> {
));
} else {
let models = props.backend.list_models().await?;
if !models.contains(&model_name.to_string()) {
if model_name.is_empty() {
model_name = models[0].to_string();
// TODO refactor this out later.
Config::set(ConfigKey::Model, &model_name);
} else if !models.contains(&model_name.to_string()) {
app_state
.messages
.push(Message::new_with_type(
Expand All @@ -96,7 +102,7 @@ impl<'a> AppState<'a> {
}

// Fallback to the default intro message when there's no editor context.
if app_state.add_editor_context(props.editor).await.is_err() {
if app_state.add_editor_context(props.editor).await.is_err() && !model_name.is_empty() {
app_state.messages.push(Message::new(
Author::Model,
"Hey there! What can I do for you?",
Expand Down

0 comments on commit 1a92a77

Please sign in to comment.