Skip to content

Commit 950d5db

Browse files
authored
Merge pull request #26 from Nyamort/main
new: mistral integration
2 parents e023ac4 + 066605c commit 950d5db

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

src/agent/generator/mistral.rs

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
use anyhow::Result;
2+
use async_trait::async_trait;
3+
4+
use crate::agent::state::SharedState;
5+
6+
use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
7+
8+
pub struct MistralClient {
9+
client: OpenAIClient,
10+
}
11+
12+
#[async_trait]
13+
impl Client for MistralClient {
14+
fn new(_: &str, _: u16, model_name: &str, _: u32) -> anyhow::Result<Self>
15+
where
16+
Self: Sized,
17+
{
18+
let client = OpenAIClient::custom(model_name, "MISTRAL_API_KEY", "https://api.mistral.ai/v1/")?;
19+
20+
Ok(Self { client })
21+
}
22+
23+
async fn check_native_tools_support(&self) -> Result<bool> {
24+
self.client.check_native_tools_support().await
25+
}
26+
27+
async fn chat(
28+
&self,
29+
state: SharedState,
30+
options: &ChatOptions,
31+
) -> anyhow::Result<ChatResponse> {
32+
let response = self.client.chat(state.clone(), options).await;
33+
34+
if let Err(error) = &response {
35+
if self.check_rate_limit(&error.to_string()).await {
36+
return self.chat(state, options).await;
37+
}
38+
}
39+
40+
response
41+
}
42+
43+
async fn check_rate_limit(&self, error: &str) -> bool {
44+
// if message contains "Requests rate limit exceeded" return true
45+
if error.contains("Requests rate limit exceeded") {
46+
let retry_time = std::time::Duration::from_secs(5);
47+
log::warn!(
48+
"rate limit reached for this model, retrying in {:?} ...",
49+
&retry_time,
50+
);
51+
52+
tokio::time::sleep(retry_time).await;
53+
54+
return true;
55+
}
56+
57+
false
58+
}
59+
}
60+
61+
#[async_trait]
62+
impl mini_rag::Embedder for MistralClient {
63+
async fn embed(&self, text: &str) -> Result<mini_rag::Embeddings> {
64+
self.client.embed(text).await
65+
}
66+
}

src/agent/generator/mod.rs

+7
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ mod ollama;
2020
mod openai;
2121
mod openai_compatible;
2222
mod xai;
23+
mod mistral;
2324

2425
mod options;
2526

@@ -208,6 +209,12 @@ macro_rules! factory_body {
208209
$model_name,
209210
$context_window,
210211
)?)),
212+
"mistral" => Ok(Box::new(mistral::MistralClient::new(
213+
$url,
214+
$port,
215+
$model_name,
216+
$context_window,
217+
)?)),
211218
"http" => Ok(Box::new(openai_compatible::OpenAiCompatibleClient::new(
212219
$url,
213220
$port,

0 commit comments

Comments
 (0)