enhance(ai-help): record embedding duration and model separately (#458)
* enhance(ai-help): record embedding duration and model separately

* refactor(ai-help): extract default_meta_duration

* Apply suggestions from code review

Co-authored-by: Andi Pieper <[email protected]>

---------

Co-authored-by: Andi Pieper <[email protected]>
caugner and argl authored Apr 4, 2024
1 parent 706ce23 commit 20760b2
Showing 7 changed files with 71 additions and 17 deletions.
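Throughout the diff, the change applies one pattern: note the embedding model up front, wrap each timed step in a std::time::Instant measurement, and record the elapsed Duration into a per-request meta struct. A minimal, self-contained sketch of that pattern (the struct is reduced for illustration; field names follow the diff below, the model string is a placeholder):

use std::time::{Duration, Instant};

#[derive(Default)]
struct RequestMeta {
    embedding_model: Option<&'static str>,
    embedding_duration: Option<Duration>,
    search_duration: Option<Duration>,
}

fn main() {
    let mut meta = RequestMeta::default();
    meta.embedding_model = Some("example-embedding-model"); // placeholder, not the real constant

    let start = Instant::now();
    // ... await the embedding API call here ...
    meta.embedding_duration = Some(start.elapsed());

    let start = Instant::now();
    // ... run the vector-similarity query here ...
    meta.search_duration = Some(start.elapsed());

    println!("{:?} {:?}", meta.embedding_duration, meta.search_duration);
}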
New migration file (name not shown) — the down step, dropping the new columns:
@@ -0,0 +1,3 @@
+ALTER TABLE ai_help_message_meta
+DROP COLUMN embedding_duration,
+DROP COLUMN embedding_model;
New migration file (name not shown) — the up step, adding the columns:
@@ -0,0 +1,3 @@
+ALTER TABLE ai_help_message_meta
+ADD COLUMN embedding_duration BIGINT DEFAULT NULL,
+ADD COLUMN embedding_model TEXT NOT NULL DEFAULT '';
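Note the asymmetry in the defaults: embedding_duration is nullable (BIGINT DEFAULT NULL), while embedding_model is non-nullable with an empty-string default, so a missing model is stored as '' rather than NULL. On the Rust side that maps roughly to the following (sketch only; the real definitions appear in src/db/model.rs and src/db/schema.rs later in this diff):

// Rough mapping sketch — the actual structs appear later in this diff.
struct AiHelpMessageMetaColumns {
    embedding_duration: Option<i64>, // BIGINT DEFAULT NULL -> Nullable<Int8>
    embedding_model: String,         // TEXT NOT NULL DEFAULT '' -> Text
}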
27 changes: 26 additions & 1 deletion src/ai/embeddings.rs
@@ -1,8 +1,10 @@
+use std::time::Instant;
+
 use async_openai::{config::OpenAIConfig, types::CreateEmbeddingRequestArgs, Client};
 use itertools::Itertools;
 
 use crate::{
-    ai::{constants::EMBEDDING_MODEL, error::AIError},
+    ai::{constants::EMBEDDING_MODEL, error::AIError, help::AIHelpRequestMeta},
     db::SupaPool,
 };
 
@@ -69,23 +71,30 @@ pub async fn get_related_macro_docs(
     client: &Client<OpenAIConfig>,
     pool: &SupaPool,
     prompt: String,
+    request_meta: &mut AIHelpRequestMeta,
 ) -> Result<Vec<RelatedDoc>, AIError> {
+    request_meta.embedding_model = Some(EMBEDDING_MODEL);
+
     let embedding_req = CreateEmbeddingRequestArgs::default()
         .model(EMBEDDING_MODEL)
         .input(prompt)
         .build()?;
+    let start = Instant::now();
     let embedding_res = client.embeddings().create(embedding_req).await?;
+    request_meta.embedding_duration = Some(start.elapsed());
 
     let embedding =
         pgvector::Vector::from(embedding_res.data.into_iter().next().unwrap().embedding);
 
+    let start = Instant::now();
     let mut docs: Vec<RelatedDoc> = sqlx::query_as(MACRO_DOCS_QUERY)
         .bind(embedding)
         .bind(MACRO_EMB_DISTANCE)
         .bind(MACRO_EMB_DOC_LIMIT)
         .bind(MACRO_EMB_SEC_MIN_LENGTH)
         .fetch_all(pool)
         .await?;
+    request_meta.search_duration = Some(start.elapsed());
 
     let duplicate_titles: Vec<String> = docs
         .iter()
@@ -108,44 +117,60 @@ pub async fn get_related_full_docs(
     client: &Client<OpenAIConfig>,
     pool: &SupaPool,
     prompt: String,
+    request_meta: &mut AIHelpRequestMeta,
 ) -> Result<Vec<RelatedDoc>, AIError> {
+    request_meta.embedding_model = Some(EMBEDDING_MODEL);
+
     let embedding_req = CreateEmbeddingRequestArgs::default()
         .model(EMBEDDING_MODEL)
         .input(prompt)
         .build()?;
+    let start = Instant::now();
     let embedding_res = client.embeddings().create(embedding_req).await?;
+    request_meta.embedding_duration = Some(start.elapsed());
+
     let embedding =
         pgvector::Vector::from(embedding_res.data.into_iter().next().unwrap().embedding);
+    let start = Instant::now();
     let docs: Vec<RelatedDoc> = sqlx::query_as(FULL_DOCS_QUERY)
         .bind(embedding)
         .bind(FULL_EMB_DISTANCE)
         .bind(FULL_EMB_DOC_LIMIT)
         .bind(FULL_EMB_SEC_MIN_LENGTH)
         .fetch_all(pool)
         .await?;
+    request_meta.search_duration = Some(start.elapsed());
 
     Ok(docs)
 }
 
 pub async fn get_related_docs(
     client: &Client<OpenAIConfig>,
     pool: &SupaPool,
     prompt: String,
+    request_meta: &mut AIHelpRequestMeta,
 ) -> Result<Vec<RelatedDoc>, AIError> {
+    request_meta.embedding_model = Some(EMBEDDING_MODEL);
+
     let embedding_req = CreateEmbeddingRequestArgs::default()
         .model(EMBEDDING_MODEL)
         .input(prompt)
         .build()?;
+    let start = Instant::now();
     let embedding_res = client.embeddings().create(embedding_req).await?;
+    request_meta.embedding_duration = Some(start.elapsed());
+
     let embedding =
         pgvector::Vector::from(embedding_res.data.into_iter().next().unwrap().embedding);
+    let start = Instant::now();
     let docs: Vec<RelatedDoc> = sqlx::query_as(DEFAULT_QUERY)
         .bind(embedding)
         .bind(DEFAULT_EMB_DISTANCE)
         .bind(DEFAULT_EMB_DOC_LIMIT)
         .bind(DEFAULT_EMB_SEC_MIN_LENGTH)
         .fetch_all(pool)
         .await?;
+    request_meta.search_duration = Some(start.elapsed());
 
     Ok(docs)
 }
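The start/elapsed bookkeeping above now repeats six times across the three functions. A possible follow-up refactor (not part of this commit) would be a small helper that times any future:

use std::future::Future;
use std::time::{Duration, Instant};

// Hypothetical helper, not in this commit: await a future and return its result
// together with how long it took.
async fn timed<T>(fut: impl Future<Output = T>) -> (T, Duration) {
    let start = Instant::now();
    let out = fut.await;
    (out, start.elapsed())
}

A call site would then read, e.g., let (embedding_res, duration) = timed(client.embeddings().create(embedding_req)).await; followed by request_meta.embedding_duration = Some(duration);, with the ? applied to embedding_res afterwards.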
22 changes: 17 additions & 5 deletions src/ai/help.rs
@@ -1,4 +1,4 @@
-use std::time::{Duration, Instant};
+use std::time::Duration;
 
 use async_openai::{
     config::OpenAIConfig,
@@ -39,7 +39,9 @@ pub struct AIHelpRequest {
 pub struct AIHelpRequestMeta {
     pub query_len: Option<usize>,
     pub context_len: Option<usize>,
+    pub embedding_duration: Option<Duration>,
     pub search_duration: Option<Duration>,
+    pub embedding_model: Option<&'static str>,
     pub model: Option<&'static str>,
     pub sources: Option<Vec<RefDoc>>,
 }
@@ -95,13 +97,23 @@ pub async fn prepare_ai_help_req(
         .ok_or(AIError::NoUserPrompt)?;
     request_meta.query_len = Some(last_user_message.len());
 
-    let start = Instant::now();
     let related_docs = if config.full_doc {
-        get_related_macro_docs(client, pool, last_user_message.replace('\n', " ")).await?
+        get_related_macro_docs(
+            client,
+            pool,
+            last_user_message.replace('\n', " "),
+            request_meta,
+        )
+        .await?
     } else {
-        get_related_docs(client, pool, last_user_message.replace('\n', " ")).await?
+        get_related_docs(
+            client,
+            pool,
+            last_user_message.replace('\n', " "),
+            request_meta,
+        )
+        .await?
     };
-    request_meta.search_duration = Some(start.elapsed());
 
     let mut context = vec![];
     let mut refs = vec![];
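With this change, prepare_ai_help_req no longer times the search itself; it hands its &mut AIHelpRequestMeta down to the embeddings layer, which fills in embedding_model, embedding_duration, and search_duration. A rough caller-side sketch (assuming AIHelpRequestMeta implements Default, which this diff does not show):

// Sketch under the assumption that AIHelpRequestMeta: Default.
let mut request_meta = AIHelpRequestMeta::default();
let related_docs = get_related_docs(
    client,
    pool,
    last_user_message.replace('\n', " "),
    &mut request_meta,
)
.await?;
// request_meta.embedding_model, .embedding_duration and .search_duration
// are now populated by the callee.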
25 changes: 15 additions & 10 deletions src/api/ai_help.rs
@@ -1,4 +1,7 @@
-use std::{future, time::Instant};
+use std::{
+    future,
+    time::{Duration, Instant},
+};
 
 use actix_identity::Identity;
 use actix_web::{
@@ -503,15 +506,13 @@ pub async fn ai_help(
             message_id,
             parent_id,
             created_at: Some(created_at.naive_utc()),
-            search_duration: default_meta_big_int(
-                ai_help_req_meta.search_duration.map(|d| d.as_millis()),
-            ),
-            response_duration: default_meta_big_int(Some(
-                response_duration.as_millis(),
-            )),
+            embedding_duration: default_meta_duration(ai_help_req_meta.embedding_duration),
+            search_duration: default_meta_duration(ai_help_req_meta.search_duration),
+            response_duration: default_meta_duration(Some(response_duration)),
             query_len: default_meta_big_int(ai_help_req_meta.query_len),
             context_len: default_meta_big_int(ai_help_req_meta.context_len),
             response_len: default_meta_big_int(Some(context.len)),
+            embedding_model: ai_help_req_meta.embedding_model.unwrap_or_default(),
             model: ai_help_req_meta.model.unwrap_or(""),
             status,
             sources: ai_help_req_meta.sources.as_ref().map(|sources| {
@@ -537,11 +538,11 @@ pub async fn ai_help(
                 chat_id,
                 message_id,
                 parent_id,
-                search_duration: default_meta_big_int(
-                    ai_help_req_meta.search_duration.map(|d| d.as_millis()),
-                ),
+                embedding_duration: default_meta_duration(ai_help_req_meta.embedding_duration),
+                search_duration: default_meta_duration(ai_help_req_meta.search_duration),
                 query_len: default_meta_big_int(ai_help_req_meta.query_len),
                 context_len: default_meta_big_int(ai_help_req_meta.context_len),
+                embedding_model: ai_help_req_meta.embedding_model.unwrap_or_default(),
                 model: ai_help_req_meta.model.unwrap_or(""),
                 status: (&e).into(),
                 sources: ai_help_req_meta
@@ -716,3 +717,7 @@ fn qa_check_for_error_trigger(
 fn default_meta_big_int(value: Option<impl TryInto<i64>>) -> Option<i64> {
     value.and_then(|v| v.try_into().ok())
 }
+
+fn default_meta_duration(duration: Option<Duration>) -> Option<i64> {
+    default_meta_big_int(duration.map(|d| d.as_millis()))
+}
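Since default_meta_duration funnels through default_meta_big_int, durations are persisted as whole milliseconds, and a value that cannot fit an i64 quietly becomes None (i.e., NULL in the table) instead of panicking. A small illustration:

use std::time::Duration;

fn default_meta_big_int(value: Option<impl TryInto<i64>>) -> Option<i64> {
    value.and_then(|v| v.try_into().ok())
}

fn default_meta_duration(duration: Option<Duration>) -> Option<i64> {
    default_meta_big_int(duration.map(|d| d.as_millis()))
}

fn main() {
    // Sub-millisecond precision is truncated: 1500 µs -> 1 ms.
    assert_eq!(default_meta_duration(Some(Duration::from_micros(1500))), Some(1));
    // Absent measurements stay NULL.
    assert_eq!(default_meta_duration(None), None);
}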
6 changes: 5 additions & 1 deletion src/db/model.rs
@@ -327,7 +327,9 @@ pub struct AiHelpMessageMetaInsert<'a> {
     pub parent_id: Option<Uuid>,
     /// Timestamp at which the message failed or finished.
     pub created_at: Option<NaiveDateTime>,
-    /// Time it took to search related content in milliseconds.
+    /// Time it took to generate the embedding in milliseconds.
+    pub embedding_duration: Option<i64>,
+    /// Time it took to search using the embedding in milliseconds.
     pub search_duration: Option<i64>,
     /// Time it took to generate the answer in milliseconds.
     pub response_duration: Option<i64>,
@@ -337,6 +339,8 @@ pub struct AiHelpMessageMetaInsert<'a> {
     pub context_len: Option<i64>,
     /// Length of LLM's reply in bytes.
     pub response_len: Option<i64>,
+    /// Model used to generate the embedding.
+    pub embedding_model: &'a str,
     /// Model used to generate the answer.
     pub model: &'a str,
     /// Status of the message.
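The insert struct stores milliseconds as Option<i64> and the model as a plain &str, so the in-memory meta (Option<Duration>, Option<&'static str>) has to be converted at the call site, exactly as the ai_help.rs hunks above do:

// Illustration of the two conversions (mirrors the ai_help.rs call sites above).
let embedding_duration: Option<i64> =
    default_meta_duration(ai_help_req_meta.embedding_duration); // Duration -> whole milliseconds
let embedding_model: &str =
    ai_help_req_meta.embedding_model.unwrap_or_default(); // None -> "" (the column default)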
2 changes: 2 additions & 0 deletions src/db/schema.rs
@@ -126,6 +126,8 @@ diesel::table! {
         model -> Text,
         status -> AiHelpMessageStatus,
         sources -> Jsonb,
+        embedding_duration -> Nullable<Int8>,
+        embedding_model -> Text,
     }
 }
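In the diesel::table! block the new columns are appended after sources, matching the order in which the migration adds them. A hedged sketch of reading them back with Diesel (the helper below is hypothetical; only the schema itself is part of this commit, and a synchronous PgConnection is assumed):

use diesel::prelude::*;

use crate::db::schema::ai_help_message_meta;

// Hypothetical helper, not in this commit: load the two new columns.
fn load_embedding_stats(conn: &mut PgConnection) -> QueryResult<Vec<(Option<i64>, String)>> {
    ai_help_message_meta::table
        .select((
            ai_help_message_meta::embedding_duration,
            ai_help_message_meta::embedding_model,
        ))
        .load(conn)
}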
