From a847974573126e5ead3c223ad69a5bc8cf87ca06 Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Mon, 26 Feb 2024 08:44:40 -0600 Subject: [PATCH 01/16] feat: Add gemini enums --- src/configuration/config.rs | 4 ++++ src/domain/models/backend.rs | 1 + 2 files changed, 5 insertions(+) diff --git a/src/configuration/config.rs b/src/configuration/config.rs index a245eb3..431563c 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -33,6 +33,8 @@ pub enum ConfigKey { OllamaURL, OpenAiToken, OpenAiURL, + GeminiToken, + GeminiURL, SessionID, Theme, ThemeFile, @@ -82,6 +84,8 @@ impl Config { ConfigKey::OllamaURL => "http://localhost:11434", ConfigKey::OpenAiToken => "", ConfigKey::OpenAiURL => "https://api.openai.com", + ConfigKey::GeminiToken => "", + ConfigKey::GeminiURL => "https://generativelanguage.googleapis.com", ConfigKey::Theme => "base16-onedark", ConfigKey::ThemeFile => "", diff --git a/src/domain/models/backend.rs b/src/domain/models/backend.rs index 1f7b0f6..6c4fc2f 100644 --- a/src/domain/models/backend.rs +++ b/src/domain/models/backend.rs @@ -19,6 +19,7 @@ pub enum BackendName { LangChain, Ollama, OpenAI, + Gemini, } impl BackendName { From 2053e058cc017a56deef48dcc973891c28746c0d Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Mon, 26 Feb 2024 08:45:02 -0600 Subject: [PATCH 02/16] feat: Add cli arguments for gemini --- src/application/cli.rs | 118 +++++++++++++++++++++++------------------ 1 file changed, 65 insertions(+), 53 deletions(-) diff --git a/src/application/cli.rs b/src/application/cli.rs index 1228422..dcebed3 100644 --- a/src/application/cli.rs +++ b/src/application/cli.rs @@ -397,6 +397,22 @@ pub fn build() -> Command { .num_args(1) .help("OpenAI API token when using the OpenAI backend.") .global(true), + ) + .arg( + Arg::new(ConfigKey::GeminiURL.to_string()) + .long(ConfigKey::GeminiURL.to_string()) + .env("OATMEAL_GEMINI_URL") + .num_args(1) + .help(format!("Google Gemini API URL when using the Gemini backend. [default: {}]", Config::default(ConfigKey::OpenAiURL))) + .global(true), + ) + .arg( + Arg::new(ConfigKey::GeminiToken.to_string()) + .long(ConfigKey::GeminiToken.to_string()) + .env("OATMEAL_GEMINI_TOKEN") + .num_args(1) + .help("Google Gemini API token when using the Gemini backend.") + .global(true), ); } @@ -441,67 +457,63 @@ pub async fn parse() -> Result { print_completions(completions, &mut app); } } - Some(("config", subcmd_matches)) => { - match subcmd_matches.subcommand() { - Some(("create", _)) => { - create_config_file().await?; - return Ok(false); - } - Some(("default", _)) => { - println!("{}", Config::serialize_default(build())); - return Ok(false); - } - Some(("path", _)) => { - println!("{}", Config::default(ConfigKey::ConfigFile)); - return Ok(false); - } - _ => { - subcommand_config().print_long_help()?; - return Ok(false); - } + Some(("config", subcmd_matches)) => match subcmd_matches.subcommand() { + Some(("create", _)) => { + create_config_file().await?; + return Ok(false); } - } + Some(("default", _)) => { + println!("{}", Config::serialize_default(build())); + return Ok(false); + } + Some(("path", _)) => { + println!("{}", Config::default(ConfigKey::ConfigFile)); + return Ok(false); + } + _ => { + subcommand_config().print_long_help()?; + return Ok(false); + } + }, Some(("manpages", _)) => { clap_mangen::Man::new(build()).render(&mut io::stdout())?; return Ok(false); } - Some(("sessions", subcmd_matches)) => { - match subcmd_matches.subcommand() { - Some(("dir", _)) => { - let dir = Sessions::default().cache_dir.to_string_lossy().to_string(); - println!("{dir}"); - return Ok(false); - } - Some(("list", _)) => { - print_sessions_list().await?; - return Ok(false); - } - Some(("open", open_matches)) => { - Config::load(build(), vec![&matches, open_matches]).await?; - if let Some(session_id) = open_matches.get_one::("session-id") { - load_config_from_session(session_id).await?; - } else { - load_config_from_session_interactive().await?; - } - } - Some(("delete", delete_matches)) => { - if let Some(session_id) = delete_matches.get_one::("session-id") { - Sessions::default().delete(session_id).await?; - println!("Deleted session {session_id}"); - } else if delete_matches.get_one::("all").is_some() { - Sessions::default().delete_all().await?; - println!("Deleted all sessions"); - } else { - subcommand_sessions_delete().print_long_help()?; - } - return Ok(false); + Some(("sessions", subcmd_matches)) => match subcmd_matches.subcommand() { + Some(("dir", _)) => { + let dir = Sessions::default().cache_dir.to_string_lossy().to_string(); + println!("{dir}"); + return Ok(false); + } + Some(("list", _)) => { + print_sessions_list().await?; + return Ok(false); + } + Some(("open", open_matches)) => { + Config::load(build(), vec![&matches, open_matches]).await?; + if let Some(session_id) = open_matches.get_one::("session-id") { + load_config_from_session(session_id).await?; + } else { + load_config_from_session_interactive().await?; } - _ => { - subcommand_sessions().print_long_help()?; - return Ok(false); + } + Some(("delete", delete_matches)) => { + if let Some(session_id) = delete_matches.get_one::("session-id") { + Sessions::default().delete(session_id).await?; + println!("Deleted session {session_id}"); + } else if delete_matches.get_one::("all").is_some() { + Sessions::default().delete_all().await?; + println!("Deleted all sessions"); + } else { + subcommand_sessions_delete().print_long_help()?; } + return Ok(false); } - } + _ => { + subcommand_sessions().print_long_help()?; + return Ok(false); + } + }, _ => { Config::load(build(), vec![&matches]).await?; } From 132064bfdbe893332734595430d333081e35232d Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Mon, 26 Feb 2024 10:22:34 -0600 Subject: [PATCH 03/16] feat: Add first version gemini --- src/infrastructure/backends/gemini.rs | 252 ++++++++++++++++++++++++++ src/infrastructure/backends/mod.rs | 5 + 2 files changed, 257 insertions(+) create mode 100644 src/infrastructure/backends/gemini.rs diff --git a/src/infrastructure/backends/gemini.rs b/src/infrastructure/backends/gemini.rs new file mode 100644 index 0000000..9a04053 --- /dev/null +++ b/src/infrastructure/backends/gemini.rs @@ -0,0 +1,252 @@ +#[cfg(test)] +#[path = "gemini_test.rs"] +mod tests; + +use std::time::Duration; + +use anyhow::bail; +use anyhow::Result; +use async_trait::async_trait; +use futures::stream::TryStreamExt; +use serde::Deserialize; +use serde::Serialize; +use tokio::io::AsyncBufReadExt; +use tokio::sync::mpsc; +use tokio_util::io::StreamReader; + +use crate::configuration::Config; +use crate::configuration::ConfigKey; +use crate::domain::models::Author; +use crate::domain::models::Backend; +use crate::domain::models::BackendName; +use crate::domain::models::BackendPrompt; +use crate::domain::models::BackendResponse; +use crate::domain::models::Event; + +fn convert_err(err: reqwest::Error) -> std::io::Error { + let err_msg = err.to_string(); + return std::io::Error::new(std::io::ErrorKind::Interrupted, err_msg); +} + +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct Model { + name: String, + version: String, + display_name: String, + description: String, + input_token_limit: u32, + output_token_limit: u32, + supported_generation_methods: Vec, + // temperature: d32, + // top_p: f32, + // top_k: u32, +} + +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +struct ModelListResponse { + models: Vec, +} + +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ContentPartsBlob { + mime_type: String, + data: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +enum ContentParts { + Text(String), + InlineData(ContentPartsBlob), +} + +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +struct Content { + role: String, + parts: Vec, +} + +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +struct CompletionRequest { + contents: Vec, +} + +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct GenerateContentResponse { + text: String, +} + +pub struct Gemini { + url: String, + token: String, + timeout: String, +} + +impl Default for Gemini { + fn default() -> Gemini { + return Gemini { + url: Config::get(ConfigKey::GeminiURL), + token: Config::get(ConfigKey::GeminiToken), + timeout: Config::get(ConfigKey::BackendHealthCheckTimeout), + }; + } +} + +#[async_trait] +impl Backend for Gemini { + fn name(&self) -> BackendName { + return BackendName::Gemini; + } + + #[allow(clippy::implicit_return)] + async fn health_check(&self) -> Result<()> { + if self.url.is_empty() { + bail!("Gemini URL is not defined"); + } + if self.token.is_empty() { + bail!("Gemini token is not defined"); + } + + let url = format!( + "{url}/v1beta/{model}?key={key}", + url = self.url, + model = Config::get(ConfigKey::Model), + key = self.token + ); + + let res = reqwest::Client::new() + .get(&url) + .timeout(Duration::from_millis(self.timeout.parse::()?)) + .send() + .await; + + if res.is_err() { + tracing::error!(error = ?res.unwrap_err(), "Gemini is not reachable"); + bail!("Gemini is not reachable"); + } + + let status = res.unwrap().status().as_u16(); + if status >= 400 { + tracing::error!(status = status, "Gemini health check failed"); + bail!("Gemini health check failed"); + } + + return Ok(()); + } + + #[allow(clippy::implicit_return)] + async fn list_models(&self) -> Result> { + let res = reqwest::Client::new() + .get(format!( + "{url}/v1beta/models?key={key}", + url = self.url, + key = self.token + )) + .send() + .await? + .json::() + .await?; + + let mut models: Vec = res + .models + .iter() + .filter(|model| { + model + .supported_generation_methods + .contains(&"generateContent".to_string()) + }) + .map(|model| { + return model.name.to_string(); + }) + .collect(); + + models.sort(); + + return Ok(models); + } + + #[allow(clippy::implicit_return)] + async fn get_completion<'a>( + &self, + prompt: BackendPrompt, + tx: &'a mpsc::UnboundedSender, + ) -> Result<()> { + let mut contents: Vec = vec![]; + if !prompt.backend_context.is_empty() { + contents = serde_json::from_str(&prompt.backend_context)?; + } + contents.push(Content { + role: "user".to_string(), + parts: vec![ContentParts::Text(prompt.text)], + }); + + let req = CompletionRequest { + contents: contents.clone(), + }; + + let res = reqwest::Client::new() + .post(format!( + "{url}/v1beta/{model}:streamGenerateContent?key={key}", + url = self.url, + model = Config::get(ConfigKey::Model), + key = self.token, + )) + .json(&req) + .send() + .await?; + + if !res.status().is_success() { + tracing::error!( + status = res.status().as_u16(), + "Failed to make completion request to Gemini" + ); + bail!(format!( + "Failed to make completion request to Gemini, {}", + res.status().as_u16() + )); + } + let stream = res.bytes_stream().map_err(convert_err); + let mut lines_reader = StreamReader::new(stream).lines(); + + let mut last_message = "".to_string(); + while let Ok(line) = lines_reader.next_line().await { + if line.is_none() { + break; + } + + let cleaned_line = line.unwrap().trim().to_string(); + if !cleaned_line.starts_with("\"text\":") { + continue; + } + + let ores: GenerateContentResponse = + serde_json::from_str(&format!("{{ {text} }}", text = cleaned_line)).unwrap(); + last_message += &ores.text; + let msg = BackendResponse { + author: Author::Model, + text: ores.text, + done: false, + context: None, + }; + tx.send(Event::BackendPromptResponse(msg))?; + } + + contents.push(Content { + role: "model".to_string(), + parts: vec![ContentParts::Text(last_message.clone())], + }); + + let msg = BackendResponse { + author: Author::Model, + text: "".to_string(), + done: true, + context: Some(serde_json::to_string(&contents)?), + }; + tx.send(Event::BackendPromptResponse(msg))?; + + return Ok(()); + } +} diff --git a/src/infrastructure/backends/mod.rs b/src/infrastructure/backends/mod.rs index 8f992e8..ee2fb68 100644 --- a/src/infrastructure/backends/mod.rs +++ b/src/infrastructure/backends/mod.rs @@ -1,3 +1,4 @@ +pub mod gemini; pub mod langchain; pub mod ollama; pub mod openai; @@ -23,6 +24,10 @@ impl BackendManager { return Ok(Box::::default()); } + if name == BackendName::Gemini { + return Ok(Box::::default()); + } + bail!(format!("No backend implemented for {name}")) } } From f3aa847bc7d002d8d5744a78fa9d5b2d58155d85 Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Mon, 26 Feb 2024 11:02:38 -0600 Subject: [PATCH 04/16] fix: Add break on empty text --- src/infrastructure/backends/gemini.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/infrastructure/backends/gemini.rs b/src/infrastructure/backends/gemini.rs index 9a04053..8d6d81a 100644 --- a/src/infrastructure/backends/gemini.rs +++ b/src/infrastructure/backends/gemini.rs @@ -224,6 +224,11 @@ impl Backend for Gemini { let ores: GenerateContentResponse = serde_json::from_str(&format!("{{ {text} }}", text = cleaned_line)).unwrap(); + + if ores.text.is_empty() || ores.text == "" || ores.text == "\n" { + break; + } + last_message += &ores.text; let msg = BackendResponse { author: Author::Model, From aa96b1aba5c3648ab6db3661583684c6d294c417 Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Mon, 26 Feb 2024 11:55:29 -0600 Subject: [PATCH 05/16] feat: Added tests and snapshots --- config.example.toml | 8 +- src/infrastructure/backends/gemini_test.rs | 195 ++++++++++++++++++ ...g__tests__it_serializes_to_valid_toml.snap | 8 +- ...s__gemini__tests__it_gets_completions.snap | 5 + 4 files changed, 214 insertions(+), 2 deletions(-) create mode 100644 src/infrastructure/backends/gemini_test.rs create mode 100644 test/snapshots/oatmeal__infrastructure__backends__gemini__tests__it_gets_completions.snap diff --git a/config.example.toml b/config.example.toml index 9939e60..12de7bb 100644 --- a/config.example.toml +++ b/config.example.toml @@ -1,4 +1,4 @@ -# The initial backend hosting a model to connect to. [possible values: langchain, ollama, openai] +# The initial backend hosting a model to connect to. [possible values: langchain, ollama, openai, gemini] backend = "ollama" # Time to wait in milliseconds before timing out when doing a healthcheck for a backend. @@ -22,6 +22,12 @@ ollama-url = "http://localhost:11434" # OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy. open-ai-url = "https://api.openai.com" +# OpenAI API token when using the OpenAI backend. +# gemini-token = "" + +# OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy. +gemini-url = "https://generativelanguage.googleapis.com" + # Sets code syntax highlighting theme. [possible values: base16-github, base16-monokai, base16-one-light, base16-onedark, base16-seti] theme = "base16-onedark" diff --git a/src/infrastructure/backends/gemini_test.rs b/src/infrastructure/backends/gemini_test.rs new file mode 100644 index 0000000..8d5cabf --- /dev/null +++ b/src/infrastructure/backends/gemini_test.rs @@ -0,0 +1,195 @@ +use anyhow::bail; +use anyhow::Result; +use test_utils::insta_snapshot; +use tokio::sync::mpsc; + +use super::Config; +use super::Content; +use super::ContentParts; +use super::Gemini; +use super::Model; +use super::ModelListResponse; +use crate::configuration::ConfigKey; +use crate::domain::models::Author; +use crate::domain::models::Backend; +use crate::domain::models::BackendPrompt; +use crate::domain::models::BackendResponse; +use crate::domain::models::Event; + +impl Gemini { + fn with_url(url: String) -> Gemini { + return Gemini { + url, + token: "abc".to_string(), + timeout: "200".to_string(), + }; + } +} + +fn to_res(action: Option) -> Result { + let act = match action.unwrap() { + Event::BackendPromptResponse(res) => res, + _ => bail!("Wrong type from recv"), + }; + + return Ok(act); +} + +#[tokio::test] +async fn it_successfully_health_checks() { + let mut server = mockito::Server::new(); + let mock = server + .mock("GET", "/v1beta/model-1?key=abc") + .with_status(200) + .create(); + + let backend = Gemini::with_url(server.url()); + let res = backend.health_check().await; + + assert!(res.is_ok()); + mock.assert(); +} + +#[tokio::test] +async fn it_successfully_health_checks_with_official_api() { + Config::set(ConfigKey::Model, "models/gemini-pro"); + let token = match std::env::var("OATMEAL_GEMINI_TOKEN") { + Ok(token) => token, + Err(_) => { + println!("There is no token in environment defined, skipping test"); + return; + } + }; + let backend = Gemini { + url: "https://generativelanguage.googleapis.com".to_string(), + token, + timeout: "500".to_string(), + }; + + let res = backend.health_check().await; + assert!(res.is_ok()); +} + +#[tokio::test] +async fn it_fails_health_checks() { + let mut server = mockito::Server::new(); + let mock = server + .mock("GET", "/v1beta/model-1?key=abc") + .with_status(500) + .create(); + + let backend = Gemini::with_url(server.url()); + let res = backend.health_check().await; + + assert!(res.is_err()); + mock.assert(); +} + +#[tokio::test] +async fn it_lists_models() -> Result<()> { + let body = serde_json::to_string(&ModelListResponse { + models: vec![ + Model { + name: "first".to_string(), + description: "First model".to_string(), + display_name: "First model".to_string(), + input_token_limit: 2048, + version: "1.0".to_string(), + supported_generation_methods: vec!["generateContent".to_string()], + output_token_limit: 2048, + }, + Model { + name: "second".to_string(), + description: "Second model".to_string(), + display_name: "Second model".to_string(), + input_token_limit: 2048, + version: "1.0".to_string(), + supported_generation_methods: vec!["generateContent".to_string()], + output_token_limit: 2048, + }, + ], + })?; + + let mut server = mockito::Server::new(); + let mock = server + .mock("GET", "/v1beta/models?key=abc") + .with_status(200) + .with_body(body) + .create(); + + let backend = Gemini::with_url(server.url()); + let res = backend.list_models().await?; + mock.assert(); + + assert_eq!(res, vec!["first".to_string(), "second".to_string()]); + + return Ok(()); +} + +#[tokio::test] +async fn it_gets_completions() -> Result<()> { + let body = [ + "[", + "\"contents\": [{", + "\"parts\": [{", + "\"text\": \"Hello \"", + "}]", + "},", + "{", + "\"parts\": [{", + "\"text\": \"World\"", + "}]", + "},", + "{", + "\"parts\": [{", + "\"text\": \"\"", + "}]", + "}]", + "]", + ] + .join("\n"); + let prompt = BackendPrompt { + text: "Say hi to the world".to_string(), + backend_context: serde_json::to_string(&vec![Content { + role: "model".to_string(), + parts: vec![ContentParts::Text("Hello".to_string())], + }])?, + }; + + let mut server = mockito::Server::new(); + let mock = server + .mock("POST", "/v1beta/model-1:streamGenerateContent?key=abc") + .with_status(200) + .with_body(body) + .create(); + + let (tx, mut rx) = mpsc::unbounded_channel::(); + + let backend = Gemini::with_url(server.url()); + backend.get_completion(prompt, &tx).await?; + + mock.assert(); + + let first_recv = to_res(rx.recv().await)?; + let second_recv = to_res(rx.recv().await)?; + let third_recv = to_res(rx.recv().await)?; + + assert_eq!(first_recv.author, Author::Model); + assert_eq!(first_recv.text, "Hello ".to_string()); + assert!(!first_recv.done); + assert_eq!(first_recv.context, None); + + assert_eq!(second_recv.author, Author::Model); + assert_eq!(second_recv.text, "World".to_string()); + assert!(!second_recv.done); + assert_eq!(second_recv.context, None); + + assert_eq!(third_recv.author, Author::Model); + assert_eq!(third_recv.text, "".to_string()); + assert!(third_recv.done); + insta_snapshot(|| { + insta::assert_toml_snapshot!(third_recv.context); + }); + + return Ok(()); +} diff --git a/test/snapshots/oatmeal__configuration__config__tests__it_serializes_to_valid_toml.snap b/test/snapshots/oatmeal__configuration__config__tests__it_serializes_to_valid_toml.snap index 37fa6b8..18c6338 100644 --- a/test/snapshots/oatmeal__configuration__config__tests__it_serializes_to_valid_toml.snap +++ b/test/snapshots/oatmeal__configuration__config__tests__it_serializes_to_valid_toml.snap @@ -3,7 +3,7 @@ source: src/configuration/config_test.rs expression: res --- ''' -# The initial backend hosting a model to connect to. [possible values: langchain, ollama, openai] +# The initial backend hosting a model to connect to. [possible values: langchain, ollama, openai, gemini] backend = "ollama" # Time to wait in milliseconds before timing out when doing a healthcheck for a backend. @@ -27,6 +27,12 @@ ollama-url = "http://localhost:11434" # OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy. open-ai-url = "https://api.openai.com" +# Google Gemini API token when using the Gemini backend. +# gemini-token = "" + +# Google Gemini API URL when using the Gemini backend. +gemini-url = "https://generativelanguage.googleapis.com" + # Sets code syntax highlighting theme. [possible values: base16-github, base16-monokai, base16-one-light, base16-onedark, base16-seti] theme = "base16-onedark" diff --git a/test/snapshots/oatmeal__infrastructure__backends__gemini__tests__it_gets_completions.snap b/test/snapshots/oatmeal__infrastructure__backends__gemini__tests__it_gets_completions.snap new file mode 100644 index 0000000..78aa7ac --- /dev/null +++ b/test/snapshots/oatmeal__infrastructure__backends__gemini__tests__it_gets_completions.snap @@ -0,0 +1,5 @@ +--- +source: src/infrastructure/backends/gemini_test.rs +expression: third_recv.context +--- +'[{"role":"model","parts":[{"text":"Hello"}]},{"role":"user","parts":[{"text":"Say hi to the world"}]},{"role":"model","parts":[{"text":"Hello World"}]}]' From da3260522e1cf27874763b9a6cfc23ccf93dd63e Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Mon, 26 Feb 2024 12:00:59 -0600 Subject: [PATCH 06/16] feat: Update readme --- README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c4a0ebc..4ef13b0 100644 --- a/README.md +++ b/README.md @@ -173,7 +173,7 @@ Commands: Options: -b, --backend - The initial backend hosting a model to connect to. [default: ollama] [env: OATMEAL_BACKEND=] [possible values: langchain, ollama, openai] + The initial backend hosting a model to connect to. [default: ollama] [env: OATMEAL_BACKEND=] [possible values: langchain, ollama, openai, gemini] --backend-health-check-timeout Time to wait in milliseconds before timing out when doing a healthcheck for a backend. [default: 1000] [env: OATMEAL_BACKEND_HEALTH_CHECK_TIMEOUT=] -m, --model @@ -194,6 +194,10 @@ Options: OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy. [default: https://api.openai.com] [env: OATMEAL_OPENAI_URL=] --open-ai-token OpenAI API token when using the OpenAI backend. [env: OATMEAL_OPENAI_TOKEN=] + --gemini-url + Gemini API URL when using the Gemini backend. [default: https://generativelanguage.googleapis.com] [env: OATMEAL_GEMINI_URL=] + --gemini-token + Gemini API token when using the Gemini backend. [env: OATMEAL_GEMINI_TOKEN=] -h, --help Print help -V, --version @@ -261,6 +265,7 @@ The following model backends are supported: - [OpenAI](https://chat.openai.com) (Or any compatible proxy/API) - [Ollama](https://github.com/jmorganca/ollama) - [LangChain/LangServe](https://python.langchain.com/docs/langserve) (Experimental) +- [Gemini](https://gemini.google.com) (Experimental) ### Editors From 356ccc8c1c9d9f941dbfc5502e722a27f3ece1fa Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Mon, 26 Feb 2024 12:11:13 -0600 Subject: [PATCH 07/16] fix: Commented out unused model attributes --- src/infrastructure/backends/gemini.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/infrastructure/backends/gemini.rs b/src/infrastructure/backends/gemini.rs index 8d6d81a..5b3d180 100644 --- a/src/infrastructure/backends/gemini.rs +++ b/src/infrastructure/backends/gemini.rs @@ -32,14 +32,14 @@ fn convert_err(err: reqwest::Error) -> std::io::Error { #[serde(rename_all = "camelCase")] struct Model { name: String, - version: String, - display_name: String, - description: String, - input_token_limit: u32, - output_token_limit: u32, + // version: String, + // display_name: String, + // description: String, + // input_token_limit: u32, + // output_token_limit: u32, supported_generation_methods: Vec, - // temperature: d32, - // top_p: f32, + // temperature: f64, + // top_p: f64, // top_k: u32, } From 8e23fd4d1ca07c4f177e6d2d5af76cd47e01db37 Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Tue, 27 Feb 2024 15:58:17 -0600 Subject: [PATCH 08/16] fix: Config URL for Gemini cannot be different --- src/application/cli.rs | 8 -------- src/configuration/config.rs | 2 -- src/infrastructure/backends/gemini.rs | 2 +- 3 files changed, 1 insertion(+), 11 deletions(-) diff --git a/src/application/cli.rs b/src/application/cli.rs index dcebed3..5566ce4 100644 --- a/src/application/cli.rs +++ b/src/application/cli.rs @@ -398,14 +398,6 @@ pub fn build() -> Command { .help("OpenAI API token when using the OpenAI backend.") .global(true), ) - .arg( - Arg::new(ConfigKey::GeminiURL.to_string()) - .long(ConfigKey::GeminiURL.to_string()) - .env("OATMEAL_GEMINI_URL") - .num_args(1) - .help(format!("Google Gemini API URL when using the Gemini backend. [default: {}]", Config::default(ConfigKey::OpenAiURL))) - .global(true), - ) .arg( Arg::new(ConfigKey::GeminiToken.to_string()) .long(ConfigKey::GeminiToken.to_string()) diff --git a/src/configuration/config.rs b/src/configuration/config.rs index 431563c..00ee2b1 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -34,7 +34,6 @@ pub enum ConfigKey { OpenAiToken, OpenAiURL, GeminiToken, - GeminiURL, SessionID, Theme, ThemeFile, @@ -85,7 +84,6 @@ impl Config { ConfigKey::OpenAiToken => "", ConfigKey::OpenAiURL => "https://api.openai.com", ConfigKey::GeminiToken => "", - ConfigKey::GeminiURL => "https://generativelanguage.googleapis.com", ConfigKey::Theme => "base16-onedark", ConfigKey::ThemeFile => "", diff --git a/src/infrastructure/backends/gemini.rs b/src/infrastructure/backends/gemini.rs index 5b3d180..108950a 100644 --- a/src/infrastructure/backends/gemini.rs +++ b/src/infrastructure/backends/gemini.rs @@ -88,7 +88,7 @@ pub struct Gemini { impl Default for Gemini { fn default() -> Gemini { return Gemini { - url: Config::get(ConfigKey::GeminiURL), + url: "https://generativelanguage.googleapis.com".to_string(), token: Config::get(ConfigKey::GeminiToken), timeout: Config::get(ConfigKey::BackendHealthCheckTimeout), }; From b805eb26ceb017d4d00454814e18b814d713a5ec Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Wed, 28 Feb 2024 09:54:34 -0600 Subject: [PATCH 09/16] docs: Removed gemini-url from docs --- README.md | 2 -- config.example.toml | 5 +---- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/README.md b/README.md index 4ef13b0..f16b229 100644 --- a/README.md +++ b/README.md @@ -194,8 +194,6 @@ Options: OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy. [default: https://api.openai.com] [env: OATMEAL_OPENAI_URL=] --open-ai-token OpenAI API token when using the OpenAI backend. [env: OATMEAL_OPENAI_TOKEN=] - --gemini-url - Gemini API URL when using the Gemini backend. [default: https://generativelanguage.googleapis.com] [env: OATMEAL_GEMINI_URL=] --gemini-token Gemini API token when using the Gemini backend. [env: OATMEAL_GEMINI_TOKEN=] -h, --help diff --git a/config.example.toml b/config.example.toml index 12de7bb..376227e 100644 --- a/config.example.toml +++ b/config.example.toml @@ -22,12 +22,9 @@ ollama-url = "http://localhost:11434" # OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy. open-ai-url = "https://api.openai.com" -# OpenAI API token when using the OpenAI backend. +# Gemini API token when using the OpenAI backend. # gemini-token = "" -# OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy. -gemini-url = "https://generativelanguage.googleapis.com" - # Sets code syntax highlighting theme. [possible values: base16-github, base16-monokai, base16-one-light, base16-onedark, base16-seti] theme = "base16-onedark" From 07a42fee4ed4c5855fc2ad3ede887562eb84953f Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Sat, 2 Mar 2024 17:51:59 -0600 Subject: [PATCH 10/16] test: Removed unused vars in model from tests --- src/infrastructure/backends/gemini_test.rs | 21 ++++++++++--------- ...g__tests__it_serializes_to_valid_toml.snap | 3 --- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/infrastructure/backends/gemini_test.rs b/src/infrastructure/backends/gemini_test.rs index 8d5cabf..412708b 100644 --- a/src/infrastructure/backends/gemini_test.rs +++ b/src/infrastructure/backends/gemini_test.rs @@ -1,5 +1,6 @@ use anyhow::bail; use anyhow::Result; +use ratatui::backend; use test_utils::insta_snapshot; use tokio::sync::mpsc; @@ -91,21 +92,21 @@ async fn it_lists_models() -> Result<()> { models: vec![ Model { name: "first".to_string(), - description: "First model".to_string(), - display_name: "First model".to_string(), - input_token_limit: 2048, - version: "1.0".to_string(), + // description: "First model".to_string(), + // display_name: "First model".to_string(), + // input_token_limit: 2048, + // version: "1.0".to_string(), supported_generation_methods: vec!["generateContent".to_string()], - output_token_limit: 2048, + // output_token_limit: 2048, }, Model { name: "second".to_string(), - description: "Second model".to_string(), - display_name: "Second model".to_string(), - input_token_limit: 2048, - version: "1.0".to_string(), + // description: "Second model".to_string(), + // display_name: "Second model".to_string(), + // input_token_limit: 2048, + // version: "1.0".to_string(), supported_generation_methods: vec!["generateContent".to_string()], - output_token_limit: 2048, + // output_token_limit: 2048, }, ], })?; diff --git a/test/snapshots/oatmeal__configuration__config__tests__it_serializes_to_valid_toml.snap b/test/snapshots/oatmeal__configuration__config__tests__it_serializes_to_valid_toml.snap index 18c6338..8746b7f 100644 --- a/test/snapshots/oatmeal__configuration__config__tests__it_serializes_to_valid_toml.snap +++ b/test/snapshots/oatmeal__configuration__config__tests__it_serializes_to_valid_toml.snap @@ -30,9 +30,6 @@ open-ai-url = "https://api.openai.com" # Google Gemini API token when using the Gemini backend. # gemini-token = "" -# Google Gemini API URL when using the Gemini backend. -gemini-url = "https://generativelanguage.googleapis.com" - # Sets code syntax highlighting theme. [possible values: base16-github, base16-monokai, base16-one-light, base16-onedark, base16-seti] theme = "base16-onedark" From 5f4fd4f89b0634060e456ef9b3e4144167ddfeee Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Sun, 3 Mar 2024 02:11:56 -0600 Subject: [PATCH 11/16] refactor: Clean imports in test --- src/infrastructure/backends/gemini_test.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/infrastructure/backends/gemini_test.rs b/src/infrastructure/backends/gemini_test.rs index 412708b..efc5f14 100644 --- a/src/infrastructure/backends/gemini_test.rs +++ b/src/infrastructure/backends/gemini_test.rs @@ -1,6 +1,5 @@ use anyhow::bail; use anyhow::Result; -use ratatui::backend; use test_utils::insta_snapshot; use tokio::sync::mpsc; From 591290524c016237c28dc9161c74542cc10f30ee Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Sun, 10 Mar 2024 16:15:37 -0600 Subject: [PATCH 12/16] docs: Fix config example, removed unused variables --- config.example.toml | 2 +- src/infrastructure/backends/gemini.rs | 8 -------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/config.example.toml b/config.example.toml index 376227e..f79a209 100644 --- a/config.example.toml +++ b/config.example.toml @@ -22,7 +22,7 @@ ollama-url = "http://localhost:11434" # OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy. open-ai-url = "https://api.openai.com" -# Gemini API token when using the OpenAI backend. +# Gemini API token when using the Gemini backend. # gemini-token = "" # Sets code syntax highlighting theme. [possible values: base16-github, base16-monokai, base16-one-light, base16-onedark, base16-seti] diff --git a/src/infrastructure/backends/gemini.rs b/src/infrastructure/backends/gemini.rs index 108950a..05de532 100644 --- a/src/infrastructure/backends/gemini.rs +++ b/src/infrastructure/backends/gemini.rs @@ -32,15 +32,7 @@ fn convert_err(err: reqwest::Error) -> std::io::Error { #[serde(rename_all = "camelCase")] struct Model { name: String, - // version: String, - // display_name: String, - // description: String, - // input_token_limit: u32, - // output_token_limit: u32, supported_generation_methods: Vec, - // temperature: f64, - // top_p: f64, - // top_k: u32, } #[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] From c2e55400c5df383cd27bbab85ec0851afe06662b Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Sun, 10 Mar 2024 16:16:04 -0600 Subject: [PATCH 13/16] fix: Add config set to gemini tests --- src/infrastructure/backends/gemini_test.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/infrastructure/backends/gemini_test.rs b/src/infrastructure/backends/gemini_test.rs index efc5f14..c29e058 100644 --- a/src/infrastructure/backends/gemini_test.rs +++ b/src/infrastructure/backends/gemini_test.rs @@ -37,6 +37,7 @@ fn to_res(action: Option) -> Result { #[tokio::test] async fn it_successfully_health_checks() { + Config::set(ConfigKey::Model, "model-1"); let mut server = mockito::Server::new(); let mock = server .mock("GET", "/v1beta/model-1?key=abc") @@ -72,6 +73,7 @@ async fn it_successfully_health_checks_with_official_api() { #[tokio::test] async fn it_fails_health_checks() { + Config::set(ConfigKey::Model, "model-1"); let mut server = mockito::Server::new(); let mock = server .mock("GET", "/v1beta/model-1?key=abc") @@ -91,21 +93,11 @@ async fn it_lists_models() -> Result<()> { models: vec![ Model { name: "first".to_string(), - // description: "First model".to_string(), - // display_name: "First model".to_string(), - // input_token_limit: 2048, - // version: "1.0".to_string(), supported_generation_methods: vec!["generateContent".to_string()], - // output_token_limit: 2048, }, Model { name: "second".to_string(), - // description: "Second model".to_string(), - // display_name: "Second model".to_string(), - // input_token_limit: 2048, - // version: "1.0".to_string(), supported_generation_methods: vec!["generateContent".to_string()], - // output_token_limit: 2048, }, ], })?; From 1b635504f8ff725be7ba483b68315583f106ed82 Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Sun, 10 Mar 2024 16:24:21 -0600 Subject: [PATCH 14/16] fix: Added config set to gemini test completions --- src/infrastructure/backends/gemini_test.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/infrastructure/backends/gemini_test.rs b/src/infrastructure/backends/gemini_test.rs index c29e058..364ea41 100644 --- a/src/infrastructure/backends/gemini_test.rs +++ b/src/infrastructure/backends/gemini_test.rs @@ -120,6 +120,7 @@ async fn it_lists_models() -> Result<()> { #[tokio::test] async fn it_gets_completions() -> Result<()> { + Config::set(ConfigKey::Model, "model-1"); let body = [ "[", "\"contents\": [{", From bd82547f5a2c5920de7305fc03e548e002da5368 Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Sun, 10 Mar 2024 20:17:10 -0600 Subject: [PATCH 15/16] refactor: Change completion test body to raw string --- src/infrastructure/backends/gemini_test.rs | 47 +++++++++++++--------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/src/infrastructure/backends/gemini_test.rs b/src/infrastructure/backends/gemini_test.rs index 364ea41..afc6ee6 100644 --- a/src/infrastructure/backends/gemini_test.rs +++ b/src/infrastructure/backends/gemini_test.rs @@ -121,26 +121,33 @@ async fn it_lists_models() -> Result<()> { #[tokio::test] async fn it_gets_completions() -> Result<()> { Config::set(ConfigKey::Model, "model-1"); - let body = [ - "[", - "\"contents\": [{", - "\"parts\": [{", - "\"text\": \"Hello \"", - "}]", - "},", - "{", - "\"parts\": [{", - "\"text\": \"World\"", - "}]", - "},", - "{", - "\"parts\": [{", - "\"text\": \"\"", - "}]", - "}]", - "]", - ] - .join("\n"); + let body = r#" +{ + "contents": [ + { + "parts": [ + { + "text": "Hello " + } + ] + }, + { + "parts": [ + { + "text": "World" + } + ] + }, + { + "parts": [ + { + "text": "" + } + ] + } + ] +} + "#; let prompt = BackendPrompt { text: "Say hi to the world".to_string(), backend_context: serde_json::to_string(&vec![Content { From a2b3a83e36bd16c69017911cf5d58bdd2177df5c Mon Sep 17 00:00:00 2001 From: Andres <9813380+aislasq@users.noreply.github.com> Date: Mon, 11 Mar 2024 00:07:55 -0600 Subject: [PATCH 16/16] chore: Run fmt lint --- src/application/cli.rs | 102 +++++++++++++------------- src/infrastructure/backends/gemini.rs | 2 +- 2 files changed, 54 insertions(+), 50 deletions(-) diff --git a/src/application/cli.rs b/src/application/cli.rs index 19d4df7..1414f2b 100644 --- a/src/application/cli.rs +++ b/src/application/cli.rs @@ -450,63 +450,67 @@ pub async fn parse() -> Result { print_completions(completions, &mut app); } } - Some(("config", subcmd_matches)) => match subcmd_matches.subcommand() { - Some(("create", _)) => { - create_config_file().await?; - return Ok(false); - } - Some(("default", _)) => { - println!("{}", Config::serialize_default(build())); - return Ok(false); - } - Some(("path", _)) => { - println!("{}", Config::default(ConfigKey::ConfigFile)); - return Ok(false); - } - _ => { - subcommand_config().print_long_help()?; - return Ok(false); + Some(("config", subcmd_matches)) => { + match subcmd_matches.subcommand() { + Some(("create", _)) => { + create_config_file().await?; + return Ok(false); + } + Some(("default", _)) => { + println!("{}", Config::serialize_default(build())); + return Ok(false); + } + Some(("path", _)) => { + println!("{}", Config::default(ConfigKey::ConfigFile)); + return Ok(false); + } + _ => { + subcommand_config().print_long_help()?; + return Ok(false); + } } - }, + } Some(("manpages", _)) => { clap_mangen::Man::new(build()).render(&mut io::stdout())?; return Ok(false); } - Some(("sessions", subcmd_matches)) => match subcmd_matches.subcommand() { - Some(("dir", _)) => { - let dir = Sessions::default().cache_dir.to_string_lossy().to_string(); - println!("{dir}"); - return Ok(false); - } - Some(("list", _)) => { - print_sessions_list().await?; - return Ok(false); - } - Some(("open", open_matches)) => { - Config::load(build(), vec![&matches, open_matches]).await?; - if let Some(session_id) = open_matches.get_one::("session-id") { - load_config_from_session(session_id).await?; - } else { - load_config_from_session_interactive().await?; + Some(("sessions", subcmd_matches)) => { + match subcmd_matches.subcommand() { + Some(("dir", _)) => { + let dir = Sessions::default().cache_dir.to_string_lossy().to_string(); + println!("{dir}"); + return Ok(false); } - } - Some(("delete", delete_matches)) => { - if let Some(session_id) = delete_matches.get_one::("session-id") { - Sessions::default().delete(session_id).await?; - println!("Deleted session {session_id}"); - } else if delete_matches.get_one::("all").is_some() { - Sessions::default().delete_all().await?; - println!("Deleted all sessions"); - } else { - subcommand_sessions_delete().print_long_help()?; + Some(("list", _)) => { + print_sessions_list().await?; + return Ok(false); + } + Some(("open", open_matches)) => { + Config::load(build(), vec![&matches, open_matches]).await?; + if let Some(session_id) = open_matches.get_one::("session-id") { + load_config_from_session(session_id).await?; + } else { + load_config_from_session_interactive().await?; + } + } + Some(("delete", delete_matches)) => { + if let Some(session_id) = delete_matches.get_one::("session-id") { + Sessions::default().delete(session_id).await?; + println!("Deleted session {session_id}"); + } else if delete_matches.get_one::("all").is_some() { + Sessions::default().delete_all().await?; + println!("Deleted all sessions"); + } else { + subcommand_sessions_delete().print_long_help()?; + } + return Ok(false); + } + _ => { + subcommand_sessions().print_long_help()?; + return Ok(false); } - return Ok(false); - } - _ => { - subcommand_sessions().print_long_help()?; - return Ok(false); } - }, + } _ => { Config::load(build(), vec![&matches]).await?; } diff --git a/src/infrastructure/backends/gemini.rs b/src/infrastructure/backends/gemini.rs index 05de532..166bc25 100644 --- a/src/infrastructure/backends/gemini.rs +++ b/src/infrastructure/backends/gemini.rs @@ -217,7 +217,7 @@ impl Backend for Gemini { let ores: GenerateContentResponse = serde_json::from_str(&format!("{{ {text} }}", text = cleaned_line)).unwrap(); - if ores.text.is_empty() || ores.text == "" || ores.text == "\n" { + if ores.text.is_empty() || ores.text.is_empty() || ores.text == "\n" { break; }