From 9785eab520301f275e6489fda10d1cd77c40df51 Mon Sep 17 00:00:00 2001
From: Florian Dieminger
Date: Wed, 28 Jun 2023 18:02:20 +0200
Subject: [PATCH] feat(ai-explain): add ai-explain api (#262)

* feat(ai-explain): add ai-explain api
---
 .settings.test.toml                           |   1 +
 Cargo.lock                                    |  12 +-
 Cargo.toml                                    |   8 +-
 .../down.sql                                  |   1 +
 .../2023-06-21-200806_ai-explain-cache/up.sql |  14 ++
 src/ai/constants.rs                           |   9 ++
 src/ai/explain.rs                             | 111 +++++++++++++
 src/ai/mod.rs                                 |   1 +
 src/api/ai.rs                                 | 152 +++++++++++++++++-
 src/api/api_v1.rs                             |  21 ++-
 src/db/ai.rs                                  | 138 +++++++++++++---
 src/db/model.rs                               |  26 +++
 src/db/schema.rs                              |  21 +++
 src/settings.rs                               |   3 +
 tests/api/ai_explain.rs                       | 128 +++++++++++++++
 tests/api/mod.rs                              |   1 +
 16 files changed, 607 insertions(+), 40 deletions(-)
 create mode 100644 migrations/2023-06-21-200806_ai-explain-cache/down.sql
 create mode 100644 migrations/2023-06-21-200806_ai-explain-cache/up.sql
 create mode 100644 src/ai/explain.rs
 create mode 100644 tests/api/ai_explain.rs

diff --git a/.settings.test.toml b/.settings.test.toml
index 15a1f2b7..00b4fb77 100644
--- a/.settings.test.toml
+++ b/.settings.test.toml
@@ -51,3 +51,4 @@ flag_repo = "flags"
 [ai]
 limit_reset_duration_in_sec = 5
 api_key = ""
+explain_sign_key = "kmMAMku9PB/fTtaoLg82KjTvShg8CSZCBUNuJhUz5Pg="
diff --git a/Cargo.lock b/Cargo.lock
index a7bdbd60..2d0a279c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -505,9 +505,9 @@ dependencies = [
 
 [[package]]
 name = "async-openai"
-version = "0.11.1"
+version = "0.11.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6fb81e98a73c697e72e6bd0b92714b00fc0ffa8871beedeb8c14ab4d1e27ff79"
+checksum = "207db2dadafe69aeab087dd54263e04ec9b01836a2d75335707b408eb7ad3f40"
 dependencies = [
  "backoff",
  "base64 0.21.2",
@@ -3560,8 +3560,10 @@ dependencies = [
  "diesel_migrations",
  "elasticsearch",
  "form_urlencoded",
+ "futures",
  "futures-util",
  "harsh",
+ "hmac",
  "hostname",
  "itertools 0.11.0",
  "jsonwebtoken",
@@ -3580,6 +3582,7 @@ dependencies = [
  "serde_path_to_error",
  "serde_urlencoded",
  "serde_with 3.0.0",
+ "sha2",
  "slog",
  "slog-async",
  "slog-envlogger",
@@ -4577,9 +4580,8 @@ dependencies = [
 
 [[package]]
 name = "tiktoken-rs"
-version = "0.4.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52aacc1cff93ba9d5f198c62c49c77fa0355025c729eed3326beaf7f33bc8614"
+version = "0.4.3"
+source = "git+https://github.com/fiji-flo/tiktoken-rs.git#e57d88cedb0b32a0ef570ee13a372a68f3594bbf"
 dependencies = [
  "anyhow",
  "async-openai",
diff --git a/Cargo.toml b/Cargo.toml
index 183909a0..73a9cc0b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -60,6 +60,7 @@ reqwest = { version = "0.11", features = ["blocking", "json"] }
 chrono = "0.4"
 url = "2"
 base64 = "0.21"
+futures = "0.3"
 futures-util = "0.3"
 regex = "1"
 
@@ -73,12 +74,17 @@ sentry-actix = "0.31"
 basket = "0.0.5"
 
 async-openai = "0.11"
-tiktoken-rs = { version = "0.4.5", features = ["async-openai"] }
+tiktoken-rs = { version = "0.4", features = ["async-openai"] }
 octocrab = "0.25"
 
 aes-gcm = { version = "0.10", features = ["default", "std"] }
+hmac = "0.12"
+sha2 = "0.10"
 
 [dev-dependencies]
 stubr = "0.6"
 stubr-attributes = "0.6"
 assert-json-diff = "2"
+
+[patch.crates-io]
+tiktoken-rs = { git = 'https://github.com/fiji-flo/tiktoken-rs.git' }
\ No newline at end of file
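
A note on the configuration and dependency changes above: `hmac` and `sha2` back the
request signing introduced in src/ai/explain.rs below, and `explain_sign_key` is a
Base64-encoded 32-byte HMAC key (the value in .settings.test.toml is a throwaway used
only by the test suite). A deployment would generate its own key from a CSPRNG; a
minimal sketch (the `rand` crate is not part of this change, it is shown purely for
illustration):

    use base64::{engine::general_purpose::STANDARD, Engine as _};

    fn generate_explain_sign_key() -> String {
        // 32 random bytes, Base64-encoded to match the settings format.
        let key: [u8; 32] = rand::random();
        STANDARD.encode(key)
    }
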
diff --git a/migrations/2023-06-21-200806_ai-explain-cache/down.sql b/migrations/2023-06-21-200806_ai-explain-cache/down.sql
new file mode 100644
index 00000000..5d227d21
--- /dev/null
+++ b/migrations/2023-06-21-200806_ai-explain-cache/down.sql
@@ -0,0 +1 @@
+DROP TABLE ai_explain_cache;
diff --git a/migrations/2023-06-21-200806_ai-explain-cache/up.sql b/migrations/2023-06-21-200806_ai-explain-cache/up.sql
new file mode 100644
index 00000000..625e40be
--- /dev/null
+++ b/migrations/2023-06-21-200806_ai-explain-cache/up.sql
@@ -0,0 +1,14 @@
+CREATE TABLE ai_explain_cache (
+    id BIGSERIAL PRIMARY KEY,
+    signature bytea NOT NULL,
+    highlighted_hash bytea NOT NULL,
+    language VARCHAR(255),
+    explanation TEXT,
+    created_at TIMESTAMP NOT NULL DEFAULT now(),
+    last_used TIMESTAMP NOT NULL DEFAULT now(),
+    view_count BIGINT NOT NULL DEFAULT 1,
+    version BIGINT NOT NULL DEFAULT 1,
+    thumbs_up BIGINT NOT NULL DEFAULT 0,
+    thumbs_down BIGINT NOT NULL DEFAULT 0,
+    UNIQUE(signature, highlighted_hash, version)
+);
\ No newline at end of file
diff --git a/src/ai/constants.rs b/src/ai/constants.rs
index ef580dbe..3b434b02 100644
--- a/src/ai/constants.rs
+++ b/src/ai/constants.rs
@@ -1,3 +1,4 @@
+// Whenever changing the model: bump the AI_EXPLAIN_VERSION!
 pub const MODEL: &str = "gpt-3.5-turbo";
 pub const EMBEDDING_MODEL: &str = "text-embedding-ada-002";
 
@@ -23,3 +24,11 @@ out how this AI works on GitHub!
 
 pub const ASK_TOKEN_LIMIT: usize = 4097;
 pub const ASK_MAX_COMPLETION_TOKENS: usize = 1024;
+
+// Whenever changing this message: bump the AI_EXPLAIN_VERSION!
+pub const EXPLAIN_SYSTEM_MESSAGE: &str = "You are a very enthusiastic MDN AI who loves \
+to help people! Given the following code example from MDN, answer the user's question \
+outputted in markdown format.\
+";
+
+pub const AI_EXPLAIN_VERSION: i64 = 1;
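
The two "bump the AI_EXPLAIN_VERSION!" comments are load-bearing: as the rest of the
patch shows, cached explanations are keyed by (signature, highlighted_hash, version)
and every lookup filters on the current version, so changing the model or the system
message only requires bumping the constant to strand stale rows. Conceptually (a
sketch, not a type from this patch):

    // An entry is only reused when all three parts match; bumping
    // AI_EXPLAIN_VERSION makes old rows unreachable without deleting them.
    type ExplainCacheKey = (
        Vec<u8>, // HMAC-SHA-256 signature over (language, sample)
        Vec<u8>, // SHA-256 of the highlighted part (or the whole sample)
        i64,     // AI_EXPLAIN_VERSION at insert time
    );
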
diff --git a/src/ai/explain.rs b/src/ai/explain.rs
new file mode 100644
index 00000000..a1870a53
--- /dev/null
+++ b/src/ai/explain.rs
@@ -0,0 +1,111 @@
+use async_openai::{
+    config::OpenAIConfig,
+    types::{
+        ChatCompletionRequestMessageArgs, CreateChatCompletionRequest,
+        CreateChatCompletionRequestArgs, CreateModerationRequestArgs, Role,
+    },
+    Client,
+};
+use hmac::{Hmac, Mac};
+use serde::{Deserialize, Serialize};
+use serde_with::{base64::Base64, serde_as};
+use sha2::{Digest, Sha256};
+
+use crate::{
+    ai::{
+        constants::{EXPLAIN_SYSTEM_MESSAGE, MODEL},
+        error::AIError,
+    },
+    api::error::ApiError,
+    settings::SETTINGS,
+};
+
+pub type HmacSha256 = Hmac<Sha256>;
+
+#[serde_as]
+#[derive(Serialize, Deserialize, Clone)]
+pub struct ExplainRequest {
+    pub language: Option<String>,
+    pub sample: String,
+    #[serde_as(as = "Base64")]
+    pub signature: Vec<u8>,
+    pub highlighted: Option<String>,
+}
+
+pub fn verify_explain_request(req: &ExplainRequest) -> Result<(), anyhow::Error> {
+    if let Some(part) = &req.highlighted {
+        if !req.sample.contains(part) {
+            return Err(ApiError::Artificial.into());
+        }
+    }
+    let mut mac = HmacSha256::new_from_slice(
+        &SETTINGS
+            .ai
+            .as_ref()
+            .map(|ai| ai.explain_sign_key)
+            .ok_or(ApiError::Artificial)?,
+    )?;
+
+    mac.update(req.language.clone().unwrap_or_default().as_bytes());
+    mac.update(req.sample.as_bytes());
+
+    mac.verify_slice(&req.signature)?;
+    Ok(())
+}
+
+pub fn hash_highlighted(to_be_hashed: &str) -> Vec<u8> {
+    let mut hasher = Sha256::new();
+    hasher.update(to_be_hashed.as_bytes());
+    hasher.finalize().to_vec()
+}
+
+pub async fn prepare_explain_req(
+    q: ExplainRequest,
+    client: &Client<OpenAIConfig>,
+) -> Result<CreateChatCompletionRequest, AIError> {
+    let ExplainRequest {
+        language,
+        sample,
+        highlighted,
+        ..
+    } = q;
+    let language = language.unwrap_or_default();
+    let user_prompt = if let Some(highlighted) = highlighted {
+        format!("Explain the following part: ```{language}\n{highlighted}\n```")
+    } else {
+        "Explain the example in detail.".to_string()
+    };
+    let context_prompt = format!(
+        "Given the following code example is the MDN code example:```{language}\n{sample}\n```"
+    );
+    let req = CreateModerationRequestArgs::default()
+        .input(format!("{user_prompt}\n{context_prompt}"))
+        .build()
+        .unwrap();
+    let moderation = client.moderations().create(req).await?;
+
+    if moderation.results.iter().any(|r| r.flagged) {
+        return Err(AIError::FlaggedError);
+    }
+    let system_message = ChatCompletionRequestMessageArgs::default()
+        .role(Role::System)
+        .content(EXPLAIN_SYSTEM_MESSAGE)
+        .build()
+        .unwrap();
+    let context_message = ChatCompletionRequestMessageArgs::default()
+        .role(Role::User)
+        .content(context_prompt)
+        .build()
+        .unwrap();
+    let user_message = ChatCompletionRequestMessageArgs::default()
+        .role(Role::User)
+        .content(user_prompt)
+        .build()
+        .unwrap();
+    let req = CreateChatCompletionRequestArgs::default()
+        .model(MODEL)
+        .messages(vec![system_message, context_message, user_message])
+        .temperature(0.0)
+        .build()?;
+    Ok(req)
+}
diff --git a/src/ai/mod.rs b/src/ai/mod.rs
index 496ab16d..d55e33b6 100644
--- a/src/ai/mod.rs
+++ b/src/ai/mod.rs
@@ -2,4 +2,5 @@ pub mod ask;
 pub mod constants;
 pub mod embeddings;
 pub mod error;
+pub mod explain;
 pub mod helpers;
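
For callers of the module above: `verify_explain_request` recomputes an HMAC-SHA-256
over the language string followed by the sample, so a valid signature has to be
produced the same way. A caller-side sketch (hypothetical helper, assuming access to
the shared `explain_sign_key`; the test suite's `sign` helper below does the same
thing):

    use base64::{engine::general_purpose::STANDARD, Engine as _};
    use hmac::{Hmac, Mac};
    use sha2::Sha256;

    fn sign_sample(key: &[u8; 32], language: Option<&str>, sample: &str) -> String {
        let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC takes any key length");
        mac.update(language.unwrap_or_default().as_bytes());
        mac.update(sample.as_bytes());
        // ExplainRequest.signature travels as Base64 via #[serde_as(as = "Base64")].
        STANDARD.encode(mac.finalize().into_bytes())
    }
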
diff --git a/src/api/ai.rs b/src/api/ai.rs
index f242dc3d..fc77a95f 100644
--- a/src/api/ai.rs
+++ b/src/api/ai.rs
@@ -3,18 +3,30 @@ use actix_web::{
     web::{Data, Json},
     Either, HttpResponse, Responder,
 };
-use actix_web_lab::sse;
+use actix_web_lab::{__reexports::tokio::sync::mpsc, sse};
 use async_openai::{
-    config::OpenAIConfig, error::OpenAIError, types::ChatCompletionRequestMessage, Client,
+    config::OpenAIConfig,
+    error::OpenAIError,
+    types::{ChatCompletionRequestMessage, CreateChatCompletionStreamResponse},
+    Client,
 };
 use futures_util::{stream, StreamExt, TryStreamExt};
 use serde::{Deserialize, Serialize};
 use serde_json::json;
+use serde_with::{base64::Base64, serde_as};
 
 use crate::{
-    ai::ask::{prepare_ask_req, RefDoc},
+    ai::{
+        ask::{prepare_ask_req, RefDoc},
+        constants::AI_EXPLAIN_VERSION,
+        explain::{hash_highlighted, prepare_explain_req, verify_explain_request, ExplainRequest},
+    },
     db::{
-        ai::{create_or_increment_total, get_count, AI_HELP_LIMIT},
+        ai::{
+            add_explain_answer, create_or_increment_total, explain_from_cache, get_count,
+            set_explain_feedback, ExplainFeedback, AI_HELP_LIMIT,
+        },
+        model::AIExplainCacheInsert,
         SupaPool,
     },
 };
@@ -64,6 +76,44 @@ pub struct AskMeta {
     pub quota: Option<i64>,
 }
 
+#[derive(Serialize)]
+pub struct CachedChunkDelta {
+    pub content: String,
+}
+
+#[derive(Serialize)]
+pub struct CachedChunkChoice {
+    pub delta: CachedChunkDelta,
+}
+#[derive(Serialize)]
+pub struct CachedChunk {
+    pub choices: Vec<CachedChunkChoice>,
+}
+
+#[serde_as]
+#[derive(Serialize)]
+pub struct ExplainInitialData {
+    cached: bool,
+    #[serde_as(as = "Base64")]
+    hash: Vec<u8>,
+}
+#[derive(Serialize)]
+pub struct ExplainInitial {
+    initial: ExplainInitialData,
+}
+
+impl From<&str> for CachedChunk {
+    fn from(content: &str) -> Self {
+        CachedChunk {
+            choices: vec![CachedChunkChoice {
+                delta: CachedChunkDelta {
+                    content: content.into(),
+                },
+            }],
+        }
+    }
+}
+
 pub async fn quota(user_id: Identity, diesel_pool: Data<Pool>) -> Result<HttpResponse, ApiError> {
     let mut conn = diesel_pool.get()?;
     let user = get_user(&mut conn, user_id.id().unwrap())?;
@@ -117,3 +167,97 @@ pub async fn ask(
     }
     Ok(Either::Right(HttpResponse::NotImplemented().finish()))
 }
+
+pub async fn explain_feedback(
+    diesel_pool: Data<Pool>,
+    req: Json<ExplainFeedback>,
+) -> Result<HttpResponse, ApiError> {
+    let mut conn = diesel_pool.get()?;
+    set_explain_feedback(&mut conn, req.into_inner())?;
+    Ok(HttpResponse::Created().finish())
+}
+
+pub async fn explain(
+    openai_client: Data<Option<Client<OpenAIConfig>>>,
+    diesel_pool: Data<Pool>,
+    req: Json<ExplainRequest>,
+) -> Result<Either<impl Responder, impl Responder>, ApiError> {
+    let explain_request = req.into_inner();
+
+    if verify_explain_request(&explain_request).is_err() {
+        return Err(ApiError::Unauthorized);
+    }
+    let signature = explain_request.signature.clone();
+    let to_be_hashed = if let Some(ref highlighted) = explain_request.highlighted {
+        highlighted
+    } else {
+        &explain_request.sample
+    };
+    let highlighted_hash = hash_highlighted(to_be_hashed.as_str());
+    let hash = highlighted_hash.clone();
+    let language = explain_request.language.clone();
+
+    let mut conn = diesel_pool.get()?;
+    if let Some(hit) = explain_from_cache(&mut conn, &signature, &highlighted_hash)? {
+        if let Some(explanation) = hit.explanation {
+            let parts = vec![
+                sse::Data::new_json(ExplainInitial {
+                    initial: ExplainInitialData { cached: true, hash },
+                })
+                .map_err(OpenAIError::JSONDeserialize)?,
+                sse::Data::new_json(CachedChunk::from(explanation.as_str()))
+                    .map_err(OpenAIError::JSONDeserialize)?,
+            ];
+            let stream = futures::stream::iter(parts.into_iter());
+            return Ok(Either::Left(sse::Sse::from_stream(
+                stream.map(|r| Ok::<_, ApiError>(sse::Event::Data(r))),
+            )));
+        }
+    }
+    if let Some(client) = &**openai_client {
+        let explain_req = prepare_explain_req(explain_request, client).await?;
+        let stream = client.chat().create_stream(explain_req).await.unwrap();
+
+        let (tx, mut rx) = mpsc::unbounded_channel::<CreateChatCompletionStreamResponse>();
+
+        actix_web::rt::spawn(async move {
+            let mut answer = vec![];
+            while let Some(mut chunk) = rx.recv().await {
+                if let Some(part) = chunk.choices.pop().and_then(|c| c.delta.content) {
+                    answer.push(part);
+                }
+            }
+            let insert = AIExplainCacheInsert {
+                language,
+                signature,
+                highlighted_hash,
+                explanation: Some(answer.join("")),
+                version: AI_EXPLAIN_VERSION,
+            };
+            if let Err(err) = add_explain_answer(&mut conn, &insert) {
+                error!("AI Explain cache: {err}");
+            }
+        });
+        let initial = stream::once(async move {
+            Ok::<_, OpenAIError>(sse::Event::Data(
+                sse::Data::new_json(ExplainInitial {
+                    initial: ExplainInitialData {
+                        cached: false,
+                        hash,
+                    },
+                })
+                .map_err(OpenAIError::JSONDeserialize)?,
+            ))
+        });
+
+        return Ok(Either::Right(sse::Sse::from_stream(initial.chain(
+            stream.map_ok(move |res| {
+                if let Err(e) = tx.send(res.clone()) {
+                    error!("{e}");
+                }
+                sse::Event::Data(sse::Data::new_json(res).unwrap())
+            }),
+        ))));
+    }
+    Err(ApiError::Artificial)
+}
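
Note on the stream contract implemented above: both branches of `explain` emit the
same SSE shape, an `initial` event that tells the client whether this was a cache hit
(and carries the Base64 content hash it should echo back to the feedback endpoint),
followed by completion chunks. The cached branch produces exactly two events; the
expected bytes from the test below look like:

    data: {"initial":{"cached":true,"hash":"nW77myAksS9XEAZpmXYHPFbW3WZTQvZLLO1cAwPTKwQ="}}

    data: {"choices":[{"delta":{"content":"Explain this!"}}]}

On a miss, `initial` carries "cached":false, the OpenAI chunks are forwarded as-is,
and the spawned task joins the delta contents and writes the finished answer into
ai_explain_cache once the stream closes.
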
.service(web::resource("/quota").route(web::get().to(quota))), + ) + .service( + web::scope("/explain") + .service(web::resource("").route(web::post().to(explain))) + .service( + web::resource("/feedback") + .route(web::post().to(explain_feedback)), + ), + ), ) .service(web::resource("/settings/").route(web::post().to(update_settings))) .service( diff --git a/src/db/ai.rs b/src/db/ai.rs index aca44b19..c1597860 100644 --- a/src/db/ai.rs +++ b/src/db/ai.rs @@ -1,12 +1,15 @@ use chrono::{Duration, Utc}; -use diesel::prelude::*; use diesel::{insert_into, PgConnection}; +use diesel::{prelude::*, update}; use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; +use serde_with::{base64::Base64, serde_as}; +use crate::ai::constants::AI_EXPLAIN_VERSION; use crate::db::error::DbError; -use crate::db::model::{AIHelpLimitInsert, UserQuery}; -use crate::db::schema; -use crate::db::schema::ai_help_limits::*; +use crate::db::model::{AIExplainCacheInsert, AIExplainCacheQuery, AIHelpLimitInsert, UserQuery}; +use crate::db::schema::ai_explain_cache as explain; +use crate::db::schema::ai_help_limits as limits; use crate::settings::SETTINGS; pub const AI_HELP_LIMIT: i64 = 5; @@ -19,11 +22,31 @@ static AI_HELP_RESET_DURATION: Lazy = Lazy::new(|| { ) }); +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FeedbackTyp { + ThumbsDown, + ThumbsUp, +} +#[serde_as] +#[derive(Serialize, Deserialize)] +pub struct ExplainFeedback { + pub typ: FeedbackTyp, + #[serde_as(as = "Base64")] + pub hash: Vec, + #[serde_as(as = "Base64")] + pub signature: Vec, +} + pub fn get_count(conn: &mut PgConnection, user: &UserQuery) -> Result { let some_time_ago = Utc::now().naive_utc() - *AI_HELP_RESET_DURATION; - schema::ai_help_limits::table - .filter(user_id.eq(&user.id).and(latest_start.gt(some_time_ago))) - .select(session_questions) + limits::table + .filter( + limits::user_id + .eq(&user.id) + .and(limits::latest_start.gt(some_time_ago)), + ) + .select(limits::session_questions) .first(conn) .optional() .map(|n| n.unwrap_or(0)) @@ -37,11 +60,11 @@ pub fn create_or_increment_total(conn: &mut PgConnection, user: &UserQuery) -> R session_questions: 0, total_questions: 1, }; - insert_into(schema::ai_help_limits::table) + insert_into(limits::table) .values(&limit) - .on_conflict(schema::ai_help_limits::user_id) + .on_conflict(limits::user_id) .do_update() - .set(((total_questions.eq(total_questions + 1)),)) + .set(((limits::total_questions.eq(limits::total_questions + 1)),)) .execute(conn)?; Ok(()) } @@ -60,19 +83,19 @@ pub fn create_or_increment_limit( let some_time_ago = now - *AI_HELP_RESET_DURATION; // increment num_question if within limit let current = diesel::query_dsl::methods::FilterDsl::filter( - insert_into(schema::ai_help_limits::table) + insert_into(limits::table) .values(&limit) - .on_conflict(schema::ai_help_limits::user_id) + .on_conflict(limits::user_id) .do_update() .set(( - session_questions.eq(session_questions + 1), - (total_questions.eq(total_questions + 1)), + limits::session_questions.eq(limits::session_questions + 1), + (limits::total_questions.eq(limits::total_questions + 1)), )), - session_questions + limits::session_questions .lt(AI_HELP_LIMIT) - .and(latest_start.gt(some_time_ago)), + .and(limits::latest_start.gt(some_time_ago)), ) - .returning(session_questions) + .returning(limits::session_questions) .get_result(conn) .optional()?; if let Some(current) = current { @@ -80,20 +103,87 @@ pub fn create_or_increment_limit( } else { // reset if latest_start is old enough 
diff --git a/src/db/ai.rs b/src/db/ai.rs
index aca44b19..c1597860 100644
--- a/src/db/ai.rs
+++ b/src/db/ai.rs
@@ -1,12 +1,15 @@
 use chrono::{Duration, Utc};
-use diesel::prelude::*;
 use diesel::{insert_into, PgConnection};
+use diesel::{prelude::*, update};
 use once_cell::sync::Lazy;
+use serde::{Deserialize, Serialize};
+use serde_with::{base64::Base64, serde_as};
 
+use crate::ai::constants::AI_EXPLAIN_VERSION;
 use crate::db::error::DbError;
-use crate::db::model::{AIHelpLimitInsert, UserQuery};
-use crate::db::schema;
-use crate::db::schema::ai_help_limits::*;
+use crate::db::model::{AIExplainCacheInsert, AIExplainCacheQuery, AIHelpLimitInsert, UserQuery};
+use crate::db::schema::ai_explain_cache as explain;
+use crate::db::schema::ai_help_limits as limits;
 use crate::settings::SETTINGS;
 
 pub const AI_HELP_LIMIT: i64 = 5;
@@ -19,11 +22,31 @@ static AI_HELP_RESET_DURATION: Lazy<Duration> = Lazy::new(|| {
     )
 });
 
+#[derive(Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum FeedbackTyp {
+    ThumbsDown,
+    ThumbsUp,
+}
+#[serde_as]
+#[derive(Serialize, Deserialize)]
+pub struct ExplainFeedback {
+    pub typ: FeedbackTyp,
+    #[serde_as(as = "Base64")]
+    pub hash: Vec<u8>,
+    #[serde_as(as = "Base64")]
+    pub signature: Vec<u8>,
+}
+
 pub fn get_count(conn: &mut PgConnection, user: &UserQuery) -> Result<i64, DbError> {
     let some_time_ago = Utc::now().naive_utc() - *AI_HELP_RESET_DURATION;
-    schema::ai_help_limits::table
-        .filter(user_id.eq(&user.id).and(latest_start.gt(some_time_ago)))
-        .select(session_questions)
+    limits::table
+        .filter(
+            limits::user_id
+                .eq(&user.id)
+                .and(limits::latest_start.gt(some_time_ago)),
+        )
+        .select(limits::session_questions)
         .first(conn)
         .optional()
         .map(|n| n.unwrap_or(0))
@@ -37,11 +60,11 @@ pub fn create_or_increment_total(conn: &mut PgConnection, user: &UserQuery) -> R
         session_questions: 0,
         total_questions: 1,
     };
-    insert_into(schema::ai_help_limits::table)
+    insert_into(limits::table)
        .values(&limit)
-        .on_conflict(schema::ai_help_limits::user_id)
+        .on_conflict(limits::user_id)
         .do_update()
-        .set(((total_questions.eq(total_questions + 1)),))
+        .set(((limits::total_questions.eq(limits::total_questions + 1)),))
         .execute(conn)?;
     Ok(())
 }
@@ -60,19 +83,19 @@
     let some_time_ago = now - *AI_HELP_RESET_DURATION;
     // increment num_question if within limit
     let current = diesel::query_dsl::methods::FilterDsl::filter(
-        insert_into(schema::ai_help_limits::table)
+        insert_into(limits::table)
             .values(&limit)
-            .on_conflict(schema::ai_help_limits::user_id)
+            .on_conflict(limits::user_id)
             .do_update()
             .set((
-                session_questions.eq(session_questions + 1),
-                (total_questions.eq(total_questions + 1)),
+                limits::session_questions.eq(limits::session_questions + 1),
+                (limits::total_questions.eq(limits::total_questions + 1)),
             )),
-        session_questions
+        limits::session_questions
             .lt(AI_HELP_LIMIT)
-            .and(latest_start.gt(some_time_ago)),
+            .and(limits::latest_start.gt(some_time_ago)),
     )
-    .returning(session_questions)
+    .returning(limits::session_questions)
     .get_result(conn)
     .optional()?;
     if let Some(current) = current {
@@ -80,20 +103,87 @@
     } else {
         // reset if latest_start is old enough
         let current = diesel::query_dsl::methods::FilterDsl::filter(
-            insert_into(schema::ai_help_limits::table)
+            insert_into(limits::table)
                 .values(&limit)
-                .on_conflict(schema::ai_help_limits::user_id)
+                .on_conflict(limits::user_id)
                 .do_update()
                 .set((
-                    session_questions.eq(1),
-                    (latest_start.eq(now)),
-                    (total_questions.eq(total_questions + 1)),
+                    limits::session_questions.eq(1),
+                    (limits::latest_start.eq(now)),
+                    (limits::total_questions.eq(limits::total_questions + 1)),
                 )),
-            latest_start.le(some_time_ago),
+            limits::latest_start.le(some_time_ago),
         )
-        .returning(session_questions)
+        .returning(limits::session_questions)
         .get_result(conn)
         .optional()?;
         Ok(current)
     }
 }
+
+pub fn add_explain_answer(
+    conn: &mut PgConnection,
+    cache: &AIExplainCacheInsert,
+) -> Result<(), DbError> {
+    insert_into(explain::table)
+        .values(cache)
+        .on_conflict_do_nothing()
+        .execute(conn)?;
+    Ok(())
+}
+
+pub fn explain_from_cache(
+    conn: &mut PgConnection,
+    signature: &Vec<u8>,
+    highlighted_hash: &Vec<u8>,
+) -> Result<Option<AIExplainCacheQuery>, DbError> {
+    let hit = update(explain::table)
+        .filter(
+            explain::signature
+                .eq(signature)
+                .and(explain::highlighted_hash.eq(highlighted_hash))
+                .and(explain::version.eq(AI_EXPLAIN_VERSION)),
+        )
+        .set((
+            explain::last_used.eq(Utc::now().naive_utc()),
+            explain::view_count.eq(explain::view_count + 1),
+        ))
+        .returning(explain::all_columns)
+        .get_result(conn)
+        .optional()?;
+    Ok(hit)
+}
+
+pub fn set_explain_feedback(
+    conn: &mut PgConnection,
+    feedback: ExplainFeedback,
+) -> Result<(), DbError> {
+    let ExplainFeedback {
+        typ,
+        hash,
+        signature,
+    } = feedback;
+    match typ {
+        FeedbackTyp::ThumbsDown => update(explain::table)
+            .filter(
+                explain::signature
+                    .eq(signature)
+                    .and(explain::highlighted_hash.eq(hash))
+                    .and(explain::version.eq(AI_EXPLAIN_VERSION)),
+            )
+            .set(explain::thumbs_down.eq(explain::thumbs_down + 1))
+            .execute(conn)
+            .optional()?,
+        FeedbackTyp::ThumbsUp => update(explain::table)
+            .filter(
+                explain::signature
+                    .eq(signature)
+                    .and(explain::highlighted_hash.eq(hash))
+                    .and(explain::version.eq(AI_EXPLAIN_VERSION)),
+            )
+            .set(explain::thumbs_up.eq(explain::thumbs_up + 1))
+            .execute(conn)
+            .optional()?,
+    };
+    Ok(())
+}
diff --git a/src/db/model.rs b/src/db/model.rs
index ee95835f..3493291f 100644
--- a/src/db/model.rs
+++ b/src/db/model.rs
@@ -197,3 +197,29 @@ pub struct AIHelpLimitInsert {
     pub session_questions: i64,
     pub total_questions: i64,
 }
+
+#[derive(Insertable, Serialize, Debug, Default)]
+#[diesel(table_name = ai_explain_cache)]
+pub struct AIExplainCacheInsert {
+    pub language: Option<String>,
+    pub highlighted_hash: Vec<u8>,
+    pub signature: Vec<u8>,
+    pub explanation: Option<String>,
+    pub version: i64,
+}
+
+#[derive(Queryable, Serialize, Debug, Default)]
+#[diesel(table_name = ai_explain_cache)]
+pub struct AIExplainCacheQuery {
+    pub id: i64,
+    pub signature: Vec<u8>,
+    pub highlighted_hash: Vec<u8>,
+    pub language: Option<String>,
+    pub explanation: Option<String>,
+    pub created_at: NaiveDateTime,
+    pub last_used: NaiveDateTime,
+    pub view_count: i64,
+    pub version: i64,
+    pub thumbs_up: i64,
+    pub thumbs_down: i64,
+}
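
Reading the three helpers above together: `explain_from_cache` is a single
UPDATE ... RETURNING, so a hit atomically bumps `view_count` and refreshes `last_used`
in the same round trip, while `add_explain_answer` leans on the
UNIQUE(signature, highlighted_hash, version) index via ON CONFLICT DO NOTHING, so two
requests racing to cache the same answer cannot fail the insert. A read-through usage
sketch (hypothetical helper, assuming the items above plus `hash_highlighted` and
`ExplainRequest` are in scope):

    fn cached_explanation(
        conn: &mut PgConnection,
        req: &ExplainRequest,
    ) -> Result<Option<String>, DbError> {
        let hash = hash_highlighted(req.highlighted.as_deref().unwrap_or(&req.sample));
        // A row may exist with explanation = None; treat that as a miss.
        Ok(explain_from_cache(conn, &req.signature, &hash)?.and_then(|hit| hit.explanation))
    }
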
diff --git a/src/db/schema.rs b/src/db/schema.rs
index 195319df..3545618a 100644
--- a/src/db/schema.rs
+++ b/src/db/schema.rs
@@ -38,6 +38,26 @@
     }
 }
 
+diesel::table! {
+    use diesel::sql_types::*;
+    use crate::db::types::*;
+
+    ai_explain_cache (id) {
+        id -> Int8,
+        signature -> Bytea,
+        highlighted_hash -> Bytea,
+        #[max_length = 255]
+        language -> Nullable<Varchar>,
+        explanation -> Nullable<Text>,
+        created_at -> Timestamp,
+        last_used -> Timestamp,
+        view_count -> Int8,
+        version -> Int8,
+        thumbs_up -> Int8,
+        thumbs_down -> Int8,
+    }
+}
+
 diesel::table! {
     use diesel::sql_types::*;
     use crate::db::types::*;
@@ -256,6 +276,7 @@ diesel::joinable!(settings -> users (user_id));
 
 diesel::allow_tables_to_appear_in_same_query!(
     activity_pings,
+    ai_explain_cache,
     ai_help_limits,
     bcd_features,
     bcd_updates,
diff --git a/src/settings.rs b/src/settings.rs
index 79829f7b..82c57ba0 100644
--- a/src/settings.rs
+++ b/src/settings.rs
@@ -74,10 +74,13 @@ pub struct Basket {
     pub basket_url: Url,
 }
 
+#[serde_as]
 #[derive(Debug, Deserialize)]
 pub struct AI {
     pub api_key: String,
     pub limit_reset_duration_in_sec: i64,
+    #[serde_as(as = "Base64")]
+    pub explain_sign_key: [u8; 32],
 }
 
 #[serde_as]
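
Because `explain_sign_key` deserializes into a fixed [u8; 32], a configured key that
does not Base64-decode to exactly 32 bytes fails settings parsing at startup instead
of surfacing at request time. The equivalent standalone check (a sketch of what
serde_with's Base64 adapter enforces here):

    use base64::{engine::general_purpose::STANDARD, Engine as _};

    fn parse_sign_key(encoded: &str) -> anyhow::Result<[u8; 32]> {
        let bytes = STANDARD.decode(encoded)?;
        <[u8; 32]>::try_from(bytes.as_slice())
            .map_err(|_| anyhow::anyhow!("explain_sign_key must decode to 32 bytes"))
    }
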
diff --git a/tests/api/ai_explain.rs b/tests/api/ai_explain.rs
new file mode 100644
index 00000000..8ef6a812
--- /dev/null
+++ b/tests/api/ai_explain.rs
@@ -0,0 +1,128 @@
+use crate::helpers::app::test_app_with_login;
+use crate::helpers::db::{get_pool, reset};
+use crate::helpers::wait_for_stubr;
+use actix_web::test;
+use anyhow::Error;
+use diesel::{QueryDsl, RunQueryDsl};
+use hmac::Mac;
+use rumba::ai::constants::AI_EXPLAIN_VERSION;
+use rumba::ai::explain::{hash_highlighted, ExplainRequest, HmacSha256};
+use rumba::db::ai::{add_explain_answer, ExplainFeedback, FeedbackTyp};
+use rumba::db::model::{AIExplainCacheInsert, AIExplainCacheQuery};
+use rumba::db::schema::ai_explain_cache;
+use rumba::settings::SETTINGS;
+
+const JS_SAMPLE: &str = "const foo = 1;";
+
+fn sign(language: &str, sample: &str) -> Result<Vec<u8>, Error> {
+    let mut mac = HmacSha256::new_from_slice(
+        &SETTINGS
+            .ai
+            .as_ref()
+            .map(|ai| ai.explain_sign_key)
+            .expect("missing sign_key"),
+    )?;
+
+    mac.update(language.as_bytes());
+    mac.update(sample.as_bytes());
+
+    Ok(mac.finalize().into_bytes().to_vec())
+}
+
+fn add_explain_cache() -> Result<(), Error> {
+    let insert = AIExplainCacheInsert {
+        language: Some("js".to_owned()),
+        signature: sign("js", JS_SAMPLE)?,
+        highlighted_hash: hash_highlighted(JS_SAMPLE),
+        explanation: Some("Explain this!".to_owned()),
+        version: AI_EXPLAIN_VERSION,
+    };
+    let pool = get_pool();
+    let mut conn = pool.get()?;
+    add_explain_answer(&mut conn, &insert)?;
+    Ok(())
+}
+
+#[actix_rt::test]
+#[stubr::mock(port = 4321)]
+async fn test_explain() -> Result<(), Error> {
+    let pool = reset()?;
+    add_explain_cache()?;
+    wait_for_stubr().await?;
+    let app = test_app_with_login(&pool).await.unwrap();
+    let service = test::init_service(app).await;
+    let request = test::TestRequest::post()
+        .uri("/api/v1/plus/ai/explain")
+        .set_json(ExplainRequest {
+            language: Some("js".to_owned()),
+            sample: JS_SAMPLE.to_owned(),
+            signature: sign("js", JS_SAMPLE)?,
+            highlighted: Some(JS_SAMPLE.to_owned()),
+        })
+        .to_request();
+    let explain = test::call_service(&service, request).await;
+
+    assert!(explain.status().is_success());
+
+    let expected = "data: {\"initial\":{\"cached\":true,\"hash\":\"nW77myAksS9XEAZpmXYHPFbW3WZTQvZLLO1cAwPTKwQ=\"}}\n\ndata: {\"choices\":[{\"delta\":{\"content\":\"Explain this!\"}}]}\n\n";
+    assert_eq!(
+        expected,
+        String::from_utf8_lossy(test::read_body(explain).await.as_ref())
+    );
+
+    let request = test::TestRequest::post()
+        .uri("/api/v1/plus/ai/explain/feedback")
+        .set_json(ExplainFeedback {
+            typ: FeedbackTyp::ThumbsUp,
+            signature: sign("js", JS_SAMPLE)?,
+            hash: hash_highlighted(JS_SAMPLE),
+        })
+        .to_request();
+    let feedback = test::call_service(&service, request).await;
+    assert!(feedback.status().is_success());
+
+    let mut conn = pool.get()?;
+    let row: AIExplainCacheQuery = ai_explain_cache::table
+        .select(ai_explain_cache::all_columns)
+        .first(&mut conn)?;
+    assert_eq!(row.thumbs_up, 1);
+    assert_eq!(row.thumbs_down, 0);
+    assert_eq!(row.view_count, 2);
+
+    let request = test::TestRequest::post()
+        .uri("/api/v1/plus/ai/explain/feedback")
+        .set_json(ExplainFeedback {
+            typ: FeedbackTyp::ThumbsDown,
+            signature: sign("js", JS_SAMPLE)?,
+            hash: hash_highlighted(JS_SAMPLE),
+        })
+        .to_request();
+    let feedback = test::call_service(&service, request).await;
+    assert!(feedback.status().is_success());
+
+    let mut conn = pool.get()?;
+    let row: AIExplainCacheQuery = ai_explain_cache::table
+        .select(ai_explain_cache::all_columns)
+        .first(&mut conn)?;
+    assert_eq!(row.thumbs_up, 1);
+    assert_eq!(row.thumbs_down, 1);
+
+    let request = test::TestRequest::post()
+        .uri("/api/v1/plus/ai/explain/feedback")
+        .set_json(ExplainFeedback {
+            typ: FeedbackTyp::ThumbsDown,
+            signature: sign("js", JS_SAMPLE)?,
+            hash: hash_highlighted("foo"),
+        })
+        .to_request();
+    let feedback = test::call_service(&service, request).await;
+    assert!(feedback.status().is_success());
+    let row: AIExplainCacheQuery = ai_explain_cache::table
+        .select(ai_explain_cache::all_columns)
+        .first(&mut conn)?;
+    assert_eq!(row.thumbs_up, 1);
+    assert_eq!(row.thumbs_down, 1);
+
+    drop(stubr);
+    Ok(())
+}
diff --git a/tests/api/mod.rs b/tests/api/mod.rs
index f5b04604..d81f6dd6 100644
--- a/tests/api/mod.rs
+++ b/tests/api/mod.rs
@@ -1,3 +1,4 @@
+mod ai_explain;
 mod ai_help;
 mod auth;
 mod fxa_webhooks;
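
One non-obvious constant in test_explain: the hash inside the expected SSE body is
just the Base64 of hash_highlighted(JS_SAMPLE), i.e. the SHA-256 digest of
"const foo = 1;". A standalone check of that relationship (a sketch reusing the
crate's own helper):

    #[test]
    fn expected_hash_matches_sample_digest() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};
        use rumba::ai::explain::hash_highlighted;

        let hash = STANDARD.encode(hash_highlighted("const foo = 1;"));
        assert_eq!(hash, "nW77myAksS9XEAZpmXYHPFbW3WZTQvZLLO1cAwPTKwQ=");
    }
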