Skip to content

Commit

Permalink
Added Cohere example [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Jul 16, 2024
1 parent 1421ce3 commit f1f650e
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Follow the instructions for your database library:
Or check out some examples:

- [Embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/openai/src/main.rs) with OpenAI
- [Binary embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/cohere/src/main.rs) with Cohere
- [Recommendations](https://github.com/pgvector/pgvector-rust/blob/master/examples/disco/src/main.rs) with Disco
- [Bulk loading](https://github.com/pgvector/pgvector-rust/blob/master/examples/loading/src/main.rs) with `COPY`

Expand Down
11 changes: 11 additions & 0 deletions examples/cohere/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[package]
name = "example"
version = "0.1.0"
edition = "2021"
publish = false

[dependencies]
pgvector = { path = "../..", features = ["postgres"] }
postgres = "0.19"
serde_json = "1"
ureq = { version = "2", features = ["json"] }
65 changes: 65 additions & 0 deletions examples/cohere/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use pgvector::Bit;
use postgres::{Client, NoTls};
use serde_json::Value;
use std::error::Error;

fn main() -> Result<(), Box<dyn Error>> {
let mut client = Client::configure()
.host("localhost")
.dbname("pgvector_example")
.user(std::env::var("USER")?.as_str())
.connect(NoTls)?;

client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
client.execute("DROP TABLE IF EXISTS documents", &[])?;
client.execute("CREATE TABLE documents (id serial PRIMARY KEY, content text, embedding bit(1024))", &[])?;

let input = [
"The dog is barking",
"The cat is purring",
"The bear is growling",
];
let embeddings = fetch_embeddings(&input, "search_document")?;
for (content, embedding) in input.iter().zip(embeddings) {
let embedding = Bit::from_bytes(&embedding);
client.execute("INSERT INTO documents (content, embedding) VALUES ($1, $2)", &[&content, &embedding])?;
}

let query = "forest";
let query_embedding = fetch_embeddings(&[query], "search_query")?;
for row in client.query("SELECT content FROM documents ORDER BY embedding <~> $1 LIMIT 5", &[&Bit::from_bytes(&query_embedding[0])])? {
let content: &str = row.get(0);
println!("{}", content);
}

Ok(())
}

fn fetch_embeddings(texts: &[&str], input_type: &str) -> Result<Vec<Vec<u8>>, Box<dyn Error>> {
let api_key = std::env::var("CO_API_KEY").or(Err("Set CO_API_KEY"))?;

let response: Value = ureq::post("https://api.cohere.com/v1/embed")
.set("Authorization", &format!("Bearer {}", api_key))
.send_json(ureq::json!({
"texts": texts,
"model": "embed-english-v3.0",
"input_type": input_type,
"embedding_types": &["ubinary"],
}))?
.into_json()?;

let embeddings = response["embeddings"]["ubinary"]
.as_array()
.unwrap()
.iter()
.map(|v| {
v.as_array()
.unwrap()
.iter()
.map(|v| v.as_f64().unwrap() as u8)
.collect()
})
.collect();

Ok(embeddings)
}

0 comments on commit f1f650e

Please sign in to comment.