Skip to content

Commit e03512f

Browse files
committed
new: added rayon parallelization for vector search
1 parent 501d09b commit e03512f

File tree

3 files changed

+69
-7
lines changed

3 files changed

+69
-7
lines changed

Cargo.lock

+46
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,17 @@ simple-home-dir = "0.3.5"
3939
tokio = "1.38.0"
4040
xml-rs = "0.8.20"
4141
duration-string = { version = "0.4.0", optional = true }
42+
rayon = { version = "1.10.0", optional = true }
4243
glob = "0.3.1"
4344

4445
[features]
45-
default = ["ollama", "groq", "openai", "fireworks"]
46+
default = ["ollama", "groq", "openai", "fireworks", "rayon"]
4647

4748
ollama = ["dep:ollama-rs"]
4849
groq = ["dep:groq-api-rs", "dep:duration-string"]
4950
openai = ["dep:openai_api_rust"]
5051
fireworks = ["dep:openai_api_rust"]
51-
52+
rayon = ["dep:rayon"]
5253

5354
[profile.release]
5455
lto = true # Enable link-time optimization

src/agent/rag/naive.rs

+20-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use std::{collections::HashMap, time::Instant};
22

3+
#[cfg(feature = "rayon")]
4+
use rayon::prelude::*;
5+
36
use anyhow::Result;
47
use async_trait::async_trait;
58
use colored::Colorize;
@@ -89,13 +92,25 @@ impl VectorStore for NaiveVectorStore {
8992
println!("[{}] {} (top {})", "rag".bold(), query, top_k);
9093

9194
let query_vector = self.embedder.embeddings(query).await?;
92-
let mut distances = vec![];
9395
let mut results = vec![];
9496

95-
// TODO: parallelize?
96-
for (doc_name, doc_embedding) in &self.embeddings {
97-
distances.push((doc_name, metrics::cosine(&query_vector, doc_embedding)));
98-
}
97+
#[cfg(feature = "rayon")]
98+
let mut distances: Vec<(&String, f64)> = self
99+
.embeddings
100+
.par_iter()
101+
.map(|(doc_name, doc_embedding)| {
102+
(doc_name, metrics::cosine(&query_vector, doc_embedding))
103+
})
104+
.collect();
105+
106+
#[cfg(not(feature = "rayon"))]
107+
let mut distances = {
108+
let mut distances = vec![];
109+
for (doc_name, doc_embedding) in &self.embeddings {
110+
distances.push((doc_name, metrics::cosine(&query_vector, doc_embedding)));
111+
}
112+
distances
113+
};
99114

100115
distances.sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap());
101116

0 commit comments

Comments
 (0)