Skip to content

Commit 6edf82c

Browse files
committed
fix: decoupled generator options from CLI
1 parent 381d2d2 commit 6edf82c

File tree

11 files changed

+164
-158
lines changed

11 files changed

+164
-158
lines changed

src/agent/events/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ mod channel;
55
pub(crate) use channel::*;
66

77
use super::{
8-
generator::Options,
8+
generator::ChatOptions,
99
state::{metrics::Metrics, storage::StorageType},
1010
Invocation,
1111
};
@@ -20,7 +20,7 @@ pub(crate) enum Event {
2020
prev: Option<String>,
2121
new: Option<String>,
2222
},
23-
StateUpdate(Options),
23+
StateUpdate(ChatOptions),
2424
EmptyResponse,
2525
InvalidResponse(String),
2626
InvalidAction {

src/agent/generator/fireworks.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use async_trait::async_trait;
33

44
use crate::agent::{state::SharedState, Invocation};
55

6-
use super::{openai::OpenAIClient, Client, Options};
6+
use super::{openai::OpenAIClient, Client, ChatOptions};
77

88
pub struct FireworksClient {
99
client: OpenAIClient,
@@ -27,7 +27,7 @@ impl Client for FireworksClient {
2727
async fn chat(
2828
&self,
2929
state: SharedState,
30-
options: &Options,
30+
options: &ChatOptions,
3131
) -> anyhow::Result<(String, Vec<Invocation>)> {
3232
self.client.chat(state, options).await
3333
}

src/agent/generator/groq.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
1313

1414
use crate::agent::{generator::Message, state::SharedState, Invocation};
1515

16-
use super::{Client, Options};
16+
use super::{Client, ChatOptions};
1717

1818
lazy_static! {
1919
static ref RETRY_TIME_PARSER: Regex =
@@ -108,7 +108,7 @@ impl Client for GroqClient {
108108
async fn chat(
109109
&self,
110110
state: SharedState,
111-
options: &Options,
111+
options: &ChatOptions,
112112
) -> anyhow::Result<(String, Vec<Invocation>)> {
113113
let mut chat_history = vec![
114114
groq_api_rs::completion::message::Message::SystemMessage {

src/agent/generator/huggingface.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use async_trait::async_trait;
33

44
use crate::agent::{state::SharedState, Invocation};
55

6-
use super::{openai::OpenAIClient, Client, Options};
6+
use super::{openai::OpenAIClient, ChatOptions, Client};
77

88
pub struct HuggingfaceMessageClient {
99
client: OpenAIClient,
@@ -26,7 +26,7 @@ impl Client for HuggingfaceMessageClient {
2626
async fn chat(
2727
&self,
2828
state: SharedState,
29-
options: &Options,
29+
options: &ChatOptions,
3030
) -> anyhow::Result<(String, Vec<Invocation>)> {
3131
self.client.chat(state, options).await
3232
}

src/agent/generator/mod.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,24 @@ mod ollama;
2020
#[cfg(feature = "openai")]
2121
mod openai;
2222

23+
mod options;
24+
25+
pub use options::*;
26+
2327
lazy_static! {
2428
static ref RETRY_TIME_PARSER: Regex =
2529
Regex::new(r"(?m)^.+try again in (.+)\. Visit.*").unwrap();
2630
static ref CONN_RESET_PARSER: Regex = Regex::new(r"(?m)^.+onnection reset by peer.*").unwrap();
2731
}
2832

2933
#[derive(Clone, Debug, Serialize, Deserialize)]
30-
pub struct Options {
34+
pub struct ChatOptions {
3135
pub system_prompt: String,
3236
pub prompt: String,
3337
pub history: Vec<Message>,
3438
}
3539

36-
impl Options {
40+
impl ChatOptions {
3741
pub fn new(system_prompt: String, prompt: String, history: Vec<Message>) -> Self {
3842
Self {
3943
system_prompt,
@@ -71,7 +75,7 @@ pub trait Client: mini_rag::Embedder + Send + Sync {
7175
async fn chat(
7276
&self,
7377
state: SharedState,
74-
options: &Options,
78+
options: &ChatOptions,
7579
) -> Result<(String, Vec<Invocation>)>;
7680

7781
async fn check_tools_support(&self) -> Result<bool> {

src/agent/generator/ollama.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use ollama_rs::{
1818

1919
use crate::agent::{state::SharedState, Invocation};
2020

21-
use super::{Client, Message, Options};
21+
use super::{ChatOptions, Client, Message};
2222

2323
pub struct OllamaClient {
2424
model: String,
@@ -91,7 +91,7 @@ impl Client for OllamaClient {
9191
async fn chat(
9292
&self,
9393
state: SharedState,
94-
options: &Options,
94+
options: &ChatOptions,
9595
) -> anyhow::Result<(String, Vec<Invocation>)> {
9696
// TODO: images for multimodal (see todo for screenshot action)
9797

src/agent/generator/openai.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};
99

1010
use crate::agent::{state::SharedState, Invocation};
1111

12-
use super::{Client, Message, Options};
12+
use super::{Client, Message, ChatOptions};
1313

1414
#[derive(Debug, Clone, Serialize, Deserialize)]
1515
pub struct OpenAiToolFunctionParameterProperty {
@@ -113,7 +113,7 @@ impl Client for OpenAIClient {
113113
async fn chat(
114114
&self,
115115
state: SharedState,
116-
options: &Options,
116+
options: &ChatOptions,
117117
) -> anyhow::Result<(String, Vec<Invocation>)> {
118118
let mut chat_history = vec![
119119
openai_api_rust::Message {

src/agent/generator/options.rs

+135
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
use anyhow::Result;
2+
use lazy_static::lazy_static;
3+
use regex::Regex;
4+
5+
lazy_static! {
6+
static ref PUBLIC_GENERATOR_PARSER: Regex = Regex::new(r"(?m)^(.+)://(.+)$").unwrap();
7+
static ref LOCAL_GENERATOR_PARSER: Regex =
8+
Regex::new(r"(?m)^(.+)://(.+)@([^:]+):?(\d+)?$").unwrap();
9+
}
10+
11+
#[derive(Default)]
12+
pub struct Options {
13+
pub type_name: String,
14+
pub model_name: String,
15+
pub context_window: u32,
16+
pub host: String,
17+
pub port: u16,
18+
}
19+
20+
impl Options {
21+
pub fn parse(raw: &str, context_window: u32) -> Result<Self> {
22+
let raw = raw.trim().trim_matches(|c| c == '"' || c == '\'');
23+
if raw.is_empty() {
24+
return Err(anyhow!("generator string can't be empty"));
25+
}
26+
27+
let mut generator = Options {
28+
context_window,
29+
..Default::default()
30+
};
31+
32+
if raw.contains('@') {
33+
let caps = if let Some(caps) = LOCAL_GENERATOR_PARSER.captures_iter(raw).next() {
34+
caps
35+
} else {
36+
return Err(anyhow!("can't parse '{raw}' generator string"));
37+
};
38+
39+
if caps.len() != 5 {
40+
return Err(anyhow!(
41+
"can't parse {raw} generator string ({} captures instead of 5): {:?}",
42+
caps.len(),
43+
caps,
44+
));
45+
}
46+
47+
caps.get(1)
48+
.unwrap()
49+
.as_str()
50+
.clone_into(&mut generator.type_name);
51+
caps.get(2)
52+
.unwrap()
53+
.as_str()
54+
.clone_into(&mut generator.model_name);
55+
caps.get(3)
56+
.unwrap()
57+
.as_str()
58+
.clone_into(&mut generator.host);
59+
generator.port = if let Some(port) = caps.get(4) {
60+
port.as_str().parse::<u16>().unwrap()
61+
} else {
62+
0
63+
};
64+
} else {
65+
let caps = if let Some(caps) = PUBLIC_GENERATOR_PARSER.captures_iter(raw).next() {
66+
caps
67+
} else {
68+
return Err(anyhow!(
69+
"can't parse {raw} generator string, invalid expression"
70+
));
71+
};
72+
73+
if caps.len() != 3 {
74+
return Err(anyhow!(
75+
"can't parse {raw} generator string, expected 3 captures, got {}",
76+
caps.len()
77+
));
78+
}
79+
80+
caps.get(1)
81+
.unwrap()
82+
.as_str()
83+
.clone_into(&mut generator.type_name);
84+
caps.get(2)
85+
.unwrap()
86+
.as_str()
87+
.clone_into(&mut generator.model_name);
88+
}
89+
90+
Ok(generator)
91+
}
92+
}
93+
94+
#[cfg(test)]
95+
mod tests {
96+
use super::Options;
97+
98+
#[test]
99+
fn test_wont_parse_invalid_generator() {
100+
assert!(Options::parse("not a valid generator", 123).is_err());
101+
}
102+
103+
#[test]
104+
fn test_parse_local_generator_full() {
105+
let ret = Options::parse("ollama://llama3@localhost:11434", 123).unwrap();
106+
107+
assert_eq!(ret.type_name, "ollama");
108+
assert_eq!(ret.model_name, "llama3");
109+
assert_eq!(ret.host, "localhost");
110+
assert_eq!(ret.port, 11434);
111+
assert_eq!(ret.context_window, 123);
112+
}
113+
114+
#[test]
115+
fn test_parse_local_generator_without_port() {
116+
let ret = Options::parse("ollama://llama3@localhost", 123).unwrap();
117+
118+
assert_eq!(ret.type_name, "ollama");
119+
assert_eq!(ret.model_name, "llama3");
120+
assert_eq!(ret.host, "localhost");
121+
assert_eq!(ret.port, 0);
122+
assert_eq!(ret.context_window, 123);
123+
}
124+
125+
#[test]
126+
fn test_parse_public_generator() {
127+
let ret = Options::parse("groq://llama3", 123).unwrap();
128+
129+
assert_eq!(ret.type_name, "groq");
130+
assert_eq!(ret.model_name, "llama3");
131+
assert_eq!(ret.host, "");
132+
assert_eq!(ret.port, 0);
133+
assert_eq!(ret.context_window, 123);
134+
}
135+
}

src/agent/mod.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use mini_rag::Embedder;
55
use serde::{Deserialize, Serialize};
66

77
use events::Event;
8-
use generator::{Client, Options};
8+
use generator::{Client, ChatOptions};
99
use namespaces::Action;
1010
use serialization::xml::serialize;
1111
use state::{SharedState, State};
@@ -199,7 +199,7 @@ impl Agent {
199199
self.state.lock().await.is_complete()
200200
}
201201

202-
async fn on_state_update(&self, options: &Options, refresh: bool) -> Result<()> {
202+
async fn on_state_update(&self, options: &ChatOptions, refresh: bool) -> Result<()> {
203203
let mut opts = options.clone();
204204
if refresh {
205205
opts.system_prompt = serialization::state_to_system_prompt(&*self.state.lock().await)?;
@@ -310,7 +310,7 @@ impl Agent {
310310
self.state.lock().await.metrics.clone()
311311
}
312312

313-
async fn prepare_step(&mut self) -> Result<Options> {
313+
async fn prepare_step(&mut self) -> Result<ChatOptions> {
314314
let mut mut_state = self.state.lock().await;
315315

316316
mut_state.on_step()?;
@@ -320,7 +320,7 @@ impl Agent {
320320
let system_prompt = serialization::state_to_system_prompt(&mut_state)?;
321321
let prompt = mut_state.to_prompt()?;
322322
let history = mut_state.to_chat_history(self.max_history as usize)?;
323-
let options = Options::new(system_prompt, prompt, history);
323+
let options = ChatOptions::new(system_prompt, prompt, history);
324324

325325
Ok(options)
326326
}

0 commit comments

Comments
 (0)