Skip to content

Commit 098e4de

Browse files
committed
misc: small fix or general refactoring i did not bother commenting
1 parent 71a7c96 commit 098e4de

File tree

5 files changed

+86
-104
lines changed

5 files changed

+86
-104
lines changed

src/agent/mod.rs

+79-66
Original file line numberDiff line numberDiff line change
@@ -179,85 +179,97 @@ impl Agent {
179179
}
180180

181181
pub async fn step(&mut self) -> Result<()> {
182-
let mut mut_state = self.state.lock().await;
182+
let (invocations, options) = {
183+
let mut mut_state = self.state.lock().await;
183184

184-
mut_state.on_step()?;
185+
mut_state.on_step()?;
185186

186-
if self.options.with_stats {
187-
println!("\n{}\n", &mut_state.metrics);
188-
}
187+
if self.options.with_stats {
188+
println!("\n{}\n", &mut_state.metrics);
189+
}
189190

190-
let system_prompt = serialization::state_to_system_prompt(&mut_state)?;
191-
let prompt = mut_state.to_prompt()?;
192-
let history = mut_state.to_chat_history(self.max_history as usize)?;
191+
let system_prompt = serialization::state_to_system_prompt(&mut_state)?;
192+
let prompt = mut_state.to_prompt()?;
193+
let history = mut_state.to_chat_history(self.max_history as usize)?;
194+
let options = Options::new(system_prompt, prompt, history);
193195

194-
let options = Options::new(system_prompt, prompt, history);
196+
self.save_if_needed(&options, false).await?;
195197

196-
self.save_if_needed(&options, false).await?;
198+
// run model inference
199+
let response = self.generator.chat(&options).await?.trim().to_string();
197200

198-
// run model inference
199-
let response = self.generator.chat(&options).await?.trim().to_string();
201+
// parse the model response into invocations
202+
let invocations = serialization::xml::parsing::try_parse(&response)?;
200203

201-
// parse the model response into invocations
202-
let invocations = serialization::xml::parsing::try_parse(&response)?;
204+
// nothing parsed, report the problem to the model
205+
if invocations.is_empty() {
206+
if response.is_empty() {
207+
println!(
208+
"{}: agent did not provide valid instructions: empty response",
209+
"WARNING".bold().red(),
210+
);
203211

204-
// nothing parsed, report the problem to the model
205-
if invocations.is_empty() {
206-
if response.is_empty() {
207-
println!(
208-
"{}: agent did not provide valid instructions: empty response",
209-
"WARNING".bold().red(),
210-
);
211-
212-
mut_state.metrics.errors.empty_responses += 1;
213-
mut_state.add_unparsed_response_to_history(
214-
&response,
215-
"Do not return an empty responses.".to_string(),
216-
);
217-
} else {
218-
println!("\n\n{}\n\n", response.dimmed());
219-
220-
mut_state.metrics.errors.unparsed_responses += 1;
221-
mut_state.add_unparsed_response_to_history(
212+
mut_state.metrics.errors.empty_responses += 1;
213+
mut_state.add_unparsed_response_to_history(
214+
&response,
215+
"Do not return an empty responses.".to_string(),
216+
);
217+
} else {
218+
println!(
219+
"{}: agent did not provide valid instructions: \n\n{}\n\n",
220+
"WARNING".bold().red(),
221+
response.dimmed()
222+
);
223+
224+
mut_state.metrics.errors.unparsed_responses += 1;
225+
mut_state.add_unparsed_response_to_history(
222226
&response,
223227
"I could not parse any valid actions from your response, please correct it according to the instructions.".to_string(),
224228
);
229+
}
230+
} else {
231+
mut_state.metrics.valid_responses += 1;
225232
}
226-
} else {
227-
mut_state.metrics.valid_responses += 1;
228-
}
229233

230-
// to avoid dead locks, is this needed?
231-
drop(mut_state);
234+
(invocations, options)
235+
};
232236

233237
// for each parsed invocation
238+
// NOTE: the MutexGuard is purposedly captured in its own scope in order to avoid
239+
// deadlocks and make its lifespan clearer.
234240
for inv in invocations {
235241
// lookup action
236-
let mut mut_state = self.state.lock().await;
237-
let action = mut_state.get_action(&inv.action);
238-
242+
let action = self.state.lock().await.get_action(&inv.action);
239243
if action.is_none() {
240-
mut_state.metrics.errors.unknown_actions += 1;
241-
// tell the model that the action name is wrong
242-
mut_state.add_error_to_history(
243-
inv.clone(),
244-
format!("'{}' is not a valid action name", inv.action),
245-
);
246-
drop(mut_state);
244+
{
245+
let mut mut_state = self.state.lock().await;
246+
mut_state.metrics.errors.unknown_actions += 1;
247+
// tell the model that the action name is wrong
248+
mut_state.add_error_to_history(
249+
inv.clone(),
250+
format!("'{}' is not a valid action name", inv.action),
251+
);
252+
}
247253
} else {
248254
let action = action.unwrap();
249255
// validate prerequisites
250-
if let Err(err) = self.validate(&inv, &action) {
251-
mut_state.metrics.errors.invalid_actions += 1;
252-
mut_state.add_error_to_history(inv.clone(), err.to_string());
253-
drop(mut_state);
254-
} else {
255-
mut_state.metrics.valid_actions += 1;
256-
drop(mut_state);
256+
let do_exec = {
257+
let mut mut_state = self.state.lock().await;
258+
259+
if let Err(err) = self.validate(&inv, &action) {
260+
mut_state.metrics.errors.invalid_actions += 1;
261+
mut_state.add_error_to_history(inv.clone(), err.to_string());
262+
false
263+
} else {
264+
mut_state.metrics.valid_actions += 1;
265+
true
266+
}
267+
};
257268

258-
// TODO: timeout logic
269+
// TODO: timeout logic
259270

260-
// execute
271+
// execute
272+
if do_exec {
261273
let ret = action
262274
.run(
263275
self.state.clone(),
@@ -266,17 +278,18 @@ impl Agent {
266278
)
267279
.await;
268280

269-
let mut mut_state = self.state.lock().await;
270-
if let Err(error) = ret {
271-
mut_state.metrics.errors.errored_actions += 1;
272-
// tell the model about the error
273-
mut_state.add_error_to_history(inv, error.to_string());
274-
} else {
275-
mut_state.metrics.success_actions += 1;
276-
// tell the model about the output
277-
mut_state.add_success_to_history(inv, ret.unwrap());
281+
{
282+
let mut mut_state = self.state.lock().await;
283+
if let Err(error) = ret {
284+
mut_state.metrics.errors.errored_actions += 1;
285+
// tell the model about the error
286+
mut_state.add_error_to_history(inv, error.to_string());
287+
} else {
288+
mut_state.metrics.success_actions += 1;
289+
// tell the model about the output
290+
mut_state.add_success_to_history(inv, ret.unwrap());
291+
}
278292
}
279-
drop(mut_state);
280293
}
281294
}
282295

src/agent/namespaces/rag/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ impl Action for Search {
4141
for (doc, score) in &docs {
4242
println!(" * {} ({})", &doc.name, score);
4343
}
44-
println!("");
44+
println!();
4545

4646
Ok(Some(format!(
4747
"Here is some supporting information:\n\n{}",

src/agent/rag/mod.rs

-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,5 @@ pub trait VectorStore: Send {
2828
Self: Sized;
2929

3030
async fn add(&mut self, document: Document) -> Result<()>;
31-
async fn delete(&mut self, doc_name: &str) -> Result<()>;
32-
async fn clear(&mut self) -> Result<()>;
3331
async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<(Document, f64)>>;
3432
}

src/agent/rag/naive.rs

-20
Original file line numberDiff line numberDiff line change
@@ -82,26 +82,6 @@ impl VectorStore for NaiveVectorStore {
8282
Ok(())
8383
}
8484

85-
async fn delete(&mut self, doc_name: &str) -> Result<()> {
86-
if self.documents.remove(doc_name).is_some() {
87-
self.embeddings.remove(doc_name);
88-
println!("[rag] removed document '{}'", doc_name);
89-
Ok(())
90-
} else {
91-
Err(anyhow!(
92-
"document with name '{}' not found in the index",
93-
doc_name
94-
))
95-
}
96-
}
97-
98-
async fn clear(&mut self) -> Result<()> {
99-
self.embeddings.clear();
100-
self.documents.clear();
101-
println!("[rag] index cleared");
102-
Ok(())
103-
}
104-
10585
async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<(Document, f64)>> {
10686
println!("[{}] {} (top {})", "rag".bold(), query, top_k);
10787

src/agent/state/mod.rs

+6-15
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use metrics::Metrics;
77
use super::{
88
generator::{Client, Message},
99
namespaces::{self, Namespace},
10-
rag::{self, naive::NaiveVectorStore, Document, VectorStore},
10+
rag::{naive::NaiveVectorStore, Document, VectorStore},
1111
task::Task,
1212
Invocation,
1313
};
@@ -18,12 +18,6 @@ mod history;
1818
mod metrics;
1919
pub(crate) mod storage;
2020

21-
#[allow(clippy::upper_case_acronyms)]
22-
struct RAG {
23-
config: rag::Configuration,
24-
store: Box<dyn VectorStore>,
25-
}
26-
2721
pub struct State {
2822
// the task
2923
task: Box<dyn Task>,
@@ -34,7 +28,7 @@ pub struct State {
3428
// list of executed actions
3529
history: History,
3630
// optional rag engine
37-
rag: Option<RAG>,
31+
rag: Option<Box<dyn VectorStore>>,
3832
// set to true when task is complete
3933
complete: bool,
4034
// runtime metrics
@@ -89,16 +83,13 @@ impl State {
8983
}
9084

9185
// add RAG namespace
92-
let rag = if let Some(config) = task.get_rag_config() {
93-
let v_store =
86+
let rag: Option<Box<dyn VectorStore>> = if let Some(config) = task.get_rag_config() {
87+
let v_store: NaiveVectorStore =
9488
NaiveVectorStore::from_indexed_path(generator.copy()?, &config.path).await?;
9589

9690
namespaces.push(namespaces::NAMESPACES.get("rag").unwrap()());
9791

98-
Some(RAG {
99-
config: config.clone(),
100-
store: Box::new(v_store),
101-
})
92+
Some(Box::new(v_store))
10293
} else {
10394
None
10495
};
@@ -156,7 +147,7 @@ impl State {
156147

157148
pub async fn rag_query(&mut self, query: &str, top_k: usize) -> Result<Vec<(Document, f64)>> {
158149
if let Some(rag) = &self.rag {
159-
rag.store.retrieve(query, top_k).await
150+
rag.retrieve(query, top_k).await
160151
} else {
161152
Err(anyhow!("no RAG engine has been configured"))
162153
}

0 commit comments

Comments
 (0)