Skip to content

Commit 054b34e

Browse files
committed
misc: refactored VectorStore trait
1 parent 787bfe4 commit 054b34e

File tree

3 files changed

+38
-29
lines changed

3 files changed

+38
-29
lines changed

src/agent/rag/mod.rs

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use anyhow::Result;
22
use async_trait::async_trait;
3+
use naive::NaiveVectorStore;
34
use serde::{Deserialize, Serialize};
45

56
use super::generator::Client;
@@ -23,10 +24,21 @@ pub struct Document {
2324
#[async_trait]
2425
pub trait VectorStore: Send {
2526
#[allow(clippy::borrowed_box)]
26-
fn new_with_generator(generator: Box<dyn Client>) -> Result<Self>
27+
async fn new(embedder: Box<dyn Client>, config: Configuration) -> Result<Self>
2728
where
2829
Self: Sized;
2930

3031
async fn add(&mut self, document: Document) -> Result<()>;
3132
async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<(Document, f64)>>;
3233
}
34+
35+
pub async fn factory(
36+
flavor: &str,
37+
embedder: Box<dyn Client>,
38+
config: Configuration,
39+
) -> Result<Box<dyn VectorStore>> {
40+
match flavor {
41+
"naive" => Ok(Box::new(NaiveVectorStore::new(embedder, config).await?)),
42+
_ => Err(anyhow!("rag flavor '{flavor} not supported yet")),
43+
}
44+
}

src/agent/rag/naive.rs

+22-24
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,39 @@ use async_trait::async_trait;
88
use colored::Colorize;
99
use glob::glob;
1010

11-
use super::{Document, Embeddings, VectorStore};
11+
use super::{Configuration, Document, Embeddings, VectorStore};
1212
use crate::agent::{generator::Client, rag::metrics};
1313

1414
// TODO: integrate other more efficient vector databases.
1515

1616
pub struct NaiveVectorStore {
17+
config: Configuration,
1718
embedder: Box<dyn Client>,
1819
documents: HashMap<String, Document>,
1920
embeddings: HashMap<String, Embeddings>,
2021
}
2122

22-
impl NaiveVectorStore {
23-
// TODO: add persistency
24-
pub async fn from_indexed_path(generator: Box<dyn Client>, path: &str) -> Result<Self> {
25-
let path = std::fs::canonicalize(path)?.display().to_string();
23+
#[async_trait]
24+
impl VectorStore for NaiveVectorStore {
25+
#[allow(clippy::borrowed_box)]
26+
async fn new(embedder: Box<dyn Client>, config: Configuration) -> Result<Self>
27+
where
28+
Self: Sized,
29+
{
30+
// TODO: add persistency
31+
let documents = HashMap::new();
32+
let embeddings = HashMap::new();
33+
let mut store = Self {
34+
config,
35+
documents,
36+
embeddings,
37+
embedder,
38+
};
39+
40+
let path = std::fs::canonicalize(&store.config.path)?
41+
.display()
42+
.to_string();
2643
let expr = format!("{}/**/*.txt", path);
27-
let mut store = NaiveVectorStore::new_with_generator(generator)?;
2844

2945
for path in (glob(&expr)?).flatten() {
3046
let doc_name = path.display();
@@ -39,24 +55,6 @@ impl NaiveVectorStore {
3955

4056
Ok(store)
4157
}
42-
}
43-
44-
#[async_trait]
45-
impl VectorStore for NaiveVectorStore {
46-
#[allow(clippy::borrowed_box)]
47-
fn new_with_generator(embedder: Box<dyn Client>) -> Result<Self>
48-
where
49-
Self: Sized,
50-
{
51-
let documents = HashMap::new();
52-
let embeddings = HashMap::new();
53-
54-
Ok(Self {
55-
documents,
56-
embeddings,
57-
embedder,
58-
})
59-
}
6058

6159
async fn add(&mut self, document: Document) -> Result<()> {
6260
if self.documents.contains_key(&document.name) {

src/agent/state/mod.rs

+3-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use metrics::Metrics;
77
use super::{
88
generator::{Client, Message},
99
namespaces::{self, Namespace},
10-
rag::{naive::NaiveVectorStore, Document, VectorStore},
10+
rag::{Document, VectorStore},
1111
task::Task,
1212
Invocation,
1313
};
@@ -84,12 +84,11 @@ impl State {
8484

8585
// add RAG namespace
8686
let rag: Option<Box<dyn VectorStore>> = if let Some(config) = task.get_rag_config() {
87-
let v_store: NaiveVectorStore =
88-
NaiveVectorStore::from_indexed_path(embedder, &config.path).await?;
87+
let v_store = super::rag::factory("naive", embedder, config).await?;
8988

9089
namespaces.push(namespaces::NAMESPACES.get("rag").unwrap()());
9190

92-
Some(Box::new(v_store))
91+
Some(v_store)
9392
} else {
9493
None
9594
};

0 commit comments

Comments
 (0)