feature: Add gemini backend #52

Merged: 18 commits, Mar 12, 2024
5 changes: 4 additions & 1 deletion README.md
@@ -173,7 +173,7 @@ Commands:

Options:
-b, --backend <backend>
The initial backend hosting a model to connect to. [default: ollama] [env: OATMEAL_BACKEND=] [possible values: langchain, ollama, openai]
The initial backend hosting a model to connect to. [default: ollama] [env: OATMEAL_BACKEND=] [possible values: langchain, ollama, openai, gemini]
--backend-health-check-timeout <backend-health-check-timeout>
Time to wait in milliseconds before timing out when doing a healthcheck for a backend. [default: 1000] [env: OATMEAL_BACKEND_HEALTH_CHECK_TIMEOUT=]
-m, --model <model>
@@ -194,6 +194,8 @@ Options:
OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy. [default: https://api.openai.com] [env: OATMEAL_OPENAI_URL=]
--open-ai-token <open-ai-token>
OpenAI API token when using the OpenAI backend. [env: OATMEAL_OPENAI_TOKEN=]
--gemini-token <gemini-token>
Gemini API token when using the Gemini backend. [env: OATMEAL_GEMINI_TOKEN=]
-h, --help
Print help
-V, --version
@@ -261,6 +263,7 @@ The following model backends are supported:
- [OpenAI](https://chat.openai.com) (Or any compatible proxy/API)
- [Ollama](https://github.com/jmorganca/ollama)
- [LangChain/LangServe](https://python.langchain.com/docs/langserve) (Experimental)
- [Gemini](https://gemini.google.com) (Experimental)

### Editors

5 changes: 4 additions & 1 deletion config.example.toml
@@ -1,4 +1,4 @@
# The initial backend hosting a model to connect to. [possible values: langchain, ollama, openai]
# The initial backend hosting a model to connect to. [possible values: langchain, ollama, openai, gemini]
backend = "ollama"

# Time to wait in milliseconds before timing out when doing a healthcheck for a backend.
@@ -22,6 +22,9 @@ ollama-url = "http://localhost:11434"
# OpenAI API URL when using the OpenAI backend. Can be swapped to a compatible proxy.
open-ai-url = "https://api.openai.com"

# Gemini API token when using the Gemini backend.
# gemini-token = ""

# Sets code syntax highlighting theme. [possible values: base16-github, base16-monokai, base16-one-light, base16-onedark, base16-seti]
theme = "base16-onedark"

8 changes: 8 additions & 0 deletions src/application/cli.rs
@@ -398,6 +398,14 @@ pub fn build() -> Command {
.num_args(1)
.help("OpenAI API token when using the OpenAI backend.")
.global(true),
)
.arg(
Arg::new(ConfigKey::GeminiToken.to_string())
.long(ConfigKey::GeminiToken.to_string())
.env("OATMEAL_GEMINI_TOKEN")
.num_args(1)
.help("Google Gemini API token when using the Gemini backend.")
.global(true),
);
}

2 changes: 2 additions & 0 deletions src/configuration/config.rs
@@ -33,6 +33,7 @@ pub enum ConfigKey {
OllamaURL,
OpenAiToken,
OpenAiURL,
GeminiToken,
SessionID,
Theme,
ThemeFile,
@@ -99,6 +100,7 @@ impl Config {
ConfigKey::OllamaURL => "http://localhost:11434",
ConfigKey::OpenAiToken => "",
ConfigKey::OpenAiURL => "https://api.openai.com",
ConfigKey::GeminiToken => "",
ConfigKey::Theme => "base16-onedark",
ConfigKey::ThemeFile => "",

1 change: 1 addition & 0 deletions src/domain/models/backend.rs
@@ -19,6 +19,7 @@ pub enum BackendName {
LangChain,
Ollama,
OpenAI,
Gemini,
}

impl BackendName {
249 changes: 249 additions & 0 deletions src/infrastructure/backends/gemini.rs
@@ -0,0 +1,249 @@
#[cfg(test)]
#[path = "gemini_test.rs"]
mod tests;

use std::time::Duration;

use anyhow::bail;
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::TryStreamExt;
use serde::Deserialize;
use serde::Serialize;
use tokio::io::AsyncBufReadExt;
use tokio::sync::mpsc;
use tokio_util::io::StreamReader;

use crate::configuration::Config;
use crate::configuration::ConfigKey;
use crate::domain::models::Author;
use crate::domain::models::Backend;
use crate::domain::models::BackendName;
use crate::domain::models::BackendPrompt;
use crate::domain::models::BackendResponse;
use crate::domain::models::Event;

fn convert_err(err: reqwest::Error) -> std::io::Error {
let err_msg = err.to_string();
return std::io::Error::new(std::io::ErrorKind::Interrupted, err_msg);
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Model {
name: String,
supported_generation_methods: Vec<String>,
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct ModelListResponse {
models: Vec<Model>,
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ContentPartsBlob {
mime_type: String,
data: String,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
enum ContentParts {
Text(String),
InlineData(ContentPartsBlob),
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct Content {
role: String,
parts: Vec<ContentParts>,
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct CompletionRequest {
contents: Vec<Content>,
}

#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GenerateContentResponse {
text: String,
}

pub struct Gemini {
url: String,
token: String,
timeout: String,
}

impl Default for Gemini {
fn default() -> Gemini {
return Gemini {
url: "https://generativelanguage.googleapis.com".to_string(),
token: Config::get(ConfigKey::GeminiToken),
timeout: Config::get(ConfigKey::BackendHealthCheckTimeout),
};
}
}

#[async_trait]
impl Backend for Gemini {
fn name(&self) -> BackendName {
return BackendName::Gemini;
}

#[allow(clippy::implicit_return)]
async fn health_check(&self) -> Result<()> {
if self.url.is_empty() {
bail!("Gemini URL is not defined");
}
if self.token.is_empty() {
bail!("Gemini token is not defined");
}

let url = format!(
"{url}/v1beta/{model}?key={key}",
url = self.url,
model = Config::get(ConfigKey::Model),
key = self.token
);

Review thread on the Config::get(ConfigKey::Model) line:

Owner: 🚓 Health checks are failing due to this. I admit configuration should be injected into the backend interface rather than pulled from a global repository, but I haven't gotten around to doing it. As this is a health check, you could hardcode the default model here and that would be enough. Alternatively, you could call Config::set right before calling the healthcheck function in your test and set it to model-1.

Contributor Author: Is there a possibility to get more info on the test? I am doing the Config::set in the test exactly as you mentioned:
https://github.com/aislasq/oatmeal/blob/5f4fd4f89b0634060e456ef9b3e4144167ddfeee/src/infrastructure/backends/gemini_test.rs#L54

Owner: The GitHub Actions test suite uses nextest, which runs tests in parallel processes, so in-memory state is separate for each test. You have Config::set being used in one test, but it has to be used in all of your healthcheck tests.

Contributor Author: Done, let me check the linter so you can re-run the tests. Thanks!

let res = reqwest::Client::new()
.get(&url)
.timeout(Duration::from_millis(self.timeout.parse::<u64>()?))
.send()
.await;

if res.is_err() {
tracing::error!(error = ?res.unwrap_err(), "Gemini is not reachable");
bail!("Gemini is not reachable");
}

let status = res.unwrap().status().as_u16();
if status >= 400 {
tracing::error!(status = status, "Gemini health check failed");
bail!("Gemini health check failed");
}

return Ok(());
}

#[allow(clippy::implicit_return)]
async fn list_models(&self) -> Result<Vec<String>> {
let res = reqwest::Client::new()
.get(format!(
"{url}/v1beta/models?key={key}",
url = self.url,
key = self.token
))
.send()
.await?
.json::<ModelListResponse>()
.await?;

let mut models: Vec<String> = res
.models
.iter()
.filter(|model| {
model
.supported_generation_methods
.contains(&"generateContent".to_string())
})
.map(|model| {
return model.name.to_string();
})
.collect();

models.sort();

return Ok(models);
}

#[allow(clippy::implicit_return)]
async fn get_completion<'a>(
&self,
prompt: BackendPrompt,
tx: &'a mpsc::UnboundedSender<Event>,
) -> Result<()> {
let mut contents: Vec<Content> = vec![];
if !prompt.backend_context.is_empty() {
contents = serde_json::from_str(&prompt.backend_context)?;
}
contents.push(Content {
role: "user".to_string(),
parts: vec![ContentParts::Text(prompt.text)],
});

let req = CompletionRequest {
contents: contents.clone(),
};

let res = reqwest::Client::new()
.post(format!(
"{url}/v1beta/{model}:streamGenerateContent?key={key}",
url = self.url,
model = Config::get(ConfigKey::Model),
key = self.token,
))
.json(&req)
.send()
.await?;

if !res.status().is_success() {
tracing::error!(
status = res.status().as_u16(),
"Failed to make completion request to Gemini"
);
bail!(format!(
"Failed to make completion request to Gemini, {}",
res.status().as_u16()
));
}
let stream = res.bytes_stream().map_err(convert_err);
let mut lines_reader = StreamReader::new(stream).lines();

let mut last_message = "".to_string();
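// The streamed body is expected to be pretty-printed JSON with each "text"
// field on its own line: scan line by line, keep only the "text": ... lines,
// and re-wrap each one in braces so it deserializes as GenerateContentResponse.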
while let Ok(line) = lines_reader.next_line().await {
if line.is_none() {
break;
}

let cleaned_line = line.unwrap().trim().to_string();
if !cleaned_line.starts_with("\"text\":") {
continue;
}

let ores: GenerateContentResponse =
serde_json::from_str(&format!("{{ {text} }}", text = cleaned_line)).unwrap();

if ores.text.is_empty() || ores.text == "\n" {
break;
}

last_message += &ores.text;
let msg = BackendResponse {
author: Author::Model,
text: ores.text,
done: false,
context: None,
};
tx.send(Event::BackendPromptResponse(msg))?;
}
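// Append the model's complete reply to the conversation history, then send a
// final "done" response whose context carries the serialized conversation so
// it can be replayed as backend_context on the next prompt.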

contents.push(Content {
role: "model".to_string(),
parts: vec![ContentParts::Text(last_message.clone())],
});

let msg = BackendResponse {
author: Author::Model,
text: "".to_string(),
done: true,
context: Some(serde_json::to_string(&contents)?),
};
tx.send(Event::BackendPromptResponse(msg))?;

return Ok(());
}
}
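
Regarding the review thread above: because nextest runs every test in its own process, global config set in one test is invisible to the others, so each health-check test has to call Config::set itself before invoking health_check. Below is a minimal sketch of such a test, assuming a mockito-style mock server and the Config::set helper mentioned in the review; the test name, mocking crate, and exact signatures are assumptions, not the contents of the actual gemini_test.rs.

#[tokio::test]
async fn it_passes_a_health_check() {
    // nextest isolates each test in its own process, so the global config
    // must be populated inside every test that relies on it (signature of
    // Config::set assumed from the review discussion).
    Config::set(ConfigKey::Model, "model-1");

    // Hypothetical mock standing in for GET {url}/v1beta/model-1?key=abc,
    // the request issued by health_check() above.
    let mut server = mockito::Server::new_async().await;
    let mock = server
        .mock("GET", "/v1beta/model-1")
        .match_query(mockito::Matcher::Any)
        .with_status(200)
        .create_async()
        .await;

    let backend = Gemini {
        url: server.url(),
        token: "abc".to_string(),
        timeout: "1000".to_string(),
    };

    backend.health_check().await.unwrap();
    mock.assert();
}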