feat(ai-explain): add ai-explain api
fiji-flo committed Jun 28, 2023
1 parent b878a50 commit af5147d
Showing 17 changed files with 612 additions and 42 deletions.
1 change: 1 addition & 0 deletions .settings.test.toml
@@ -51,3 +51,4 @@ flag_repo = "flags"
[ai]
limit_reset_duration_in_sec = 5
api_key = ""
explain_sign_key = "kmMAMku9PB/fTtaoLg82KjTvShg8CSZCBUNuJhUz5Pg="
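
The value above is a base64-encoded 32-byte test secret; src/ai/explain.rs below uses it as an HMAC-SHA256 signing key. A minimal sketch for generating a fresh key, assuming the rand crate is available alongside the base64 crate from Cargo.toml:

use base64::{engine::general_purpose::STANDARD, Engine as _};
use rand::RngCore;

// Print a random 32-byte key, base64-encoded, suitable as `explain_sign_key`.
fn main() {
    let mut key = [0u8; 32];
    rand::thread_rng().fill_bytes(&mut key);
    println!("{}", STANDARD.encode(key));
}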
12 changes: 7 additions & 5 deletions Cargo.lock

(Generated file; diff not rendered.)

8 changes: 7 additions & 1 deletion Cargo.toml
@@ -60,6 +60,7 @@ reqwest = { version = "0.11", features = ["blocking", "json"] }
chrono = "0.4"
url = "2"
base64 = "0.21"
futures = "0.3"
futures-util = "0.3"
regex = "1"

@@ -73,12 +74,17 @@ sentry-actix = "0.31"

basket = "0.0.5"
async-openai = "0.11"
-tiktoken-rs = { version = "0.4.5", features = ["async-openai"] }
+tiktoken-rs = { version = "0.4", features = ["async-openai"] }

octocrab = "0.25"
aes-gcm = { version = "0.10", features = ["default", "std"] }
hmac = "0.12"
sha2 = "0.10"

[dev-dependencies]
stubr = "0.6"
stubr-attributes = "0.6"
assert-json-diff = "2"

[patch.crates-io]
tiktoken-rs = { git = 'https://github.com/fiji-flo/tiktoken-rs.git' }
1 change: 1 addition & 0 deletions migrations/2023-06-21-200806_ai-explain-cache/down.sql
@@ -0,0 +1 @@
DROP TABLE ai_explain_cache;
14 changes: 14 additions & 0 deletions migrations/2023-06-21-200806_ai-explain-cache/up.sql
@@ -0,0 +1,14 @@
CREATE TABLE ai_explain_cache (
id BIGSERIAL PRIMARY KEY,
signature bytea NOT NULL,
highlighted_hash bytea NOT NULL,
language VARCHAR(255),
explanation TEXT,
created_at TIMESTAMP NOT NULL DEFAULT now(),
last_used TIMESTAMP NOT NULL DEFAULT now(),
view_count BIGINT NOT NULL DEFAULT 1,
version BIGINT NOT NULL DEFAULT 1,
thumbs_up BIGINT NOT NULL DEFAULT 0,
thumbs_down BIGINT NOT NULL DEFAULT 0,
UNIQUE(signature, highlighted_hash, version)
);
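
For reference, a sketch of the diesel table! entry this migration implies (hypothetical; the generated schema.rs is among the changed files but not shown on this page):

diesel::table! {
    ai_explain_cache (id) {
        id -> Int8,
        signature -> Bytea,
        highlighted_hash -> Bytea,
        language -> Nullable<Varchar>,
        explanation -> Nullable<Text>,
        created_at -> Timestamp,
        last_used -> Timestamp,
        view_count -> Int8,
        version -> Int8,
        thumbs_up -> Int8,
        thumbs_down -> Int8,
    }
}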
7 changes: 6 additions & 1 deletion src/ai/constants.rs
@@ -1,4 +1,4 @@
-pub const MODEL: &str = "gpt-3.5-turbo";
+pub const MODEL: &str = "gpt-3.5-turbo-0613";
pub const EMBEDDING_MODEL: &str = "text-embedding-ada-002";

pub const ASK_SYSTEM_MESSAGE: &str = "You are a very enthusiastic MDN AI who loves \
@@ -21,5 +21,10 @@ don't accept such prompts with this answer: \"I am unable to comply with this re
out how this AI works on GitHub!
";

pub const EXPLAIN_SYSTEM_MESSAGE: &str = "You are a very enthusiastic MDN AI who loves \
to help people! Given the following code example from MDN, answer the user's question \
outputted in markdown format.\
";

pub const ASK_TOKEN_LIMIT: usize = 4097;
pub const ASK_MAX_COMPLETION_TOKENS: usize = 1024;
2 changes: 1 addition & 1 deletion src/ai/embeddings.rs
@@ -7,7 +7,7 @@ use crate::{

const EMB_DISTANCE: f64 = 0.78;
const EMB_SEC_MIN_LENGTH: i64 = 50;
-const EMB_DOC_LIMIT: i64 = 5;
+const EMB_DOC_LIMIT: i64 = 3;

#[derive(sqlx::FromRow)]
pub struct RelatedDoc {
128 changes: 128 additions & 0 deletions src/ai/explain.rs
@@ -0,0 +1,128 @@
use async_openai::{
    config::OpenAIConfig,
    types::{
        ChatCompletionRequestMessageArgs, CreateChatCompletionRequest,
        CreateChatCompletionRequestArgs, CreateModerationRequestArgs, Role,
    },
    Client,
};
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use serde_with::{base64::Base64, serde_as};
use sha2::{Digest, Sha256};

use crate::{
    ai::{
        constants::{EXPLAIN_SYSTEM_MESSAGE, MODEL},
        error::AIError,
    },
    api::error::ApiError,
    settings::SETTINGS,
};

pub const AI_EXPLAIN_VERSION: i64 = 1;

pub type HmacSha256 = Hmac<Sha256>;

#[serde_as]
#[derive(Serialize, Deserialize, Clone)]
pub struct ExplainRequest {
    pub language: Option<String>,
    pub sample: String,
    #[serde_as(as = "Base64")]
    pub signature: Vec<u8>,
    pub highlighted: Option<String>,
}
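
// An illustrative request body these derives accept (hypothetical values;
// `signature` travels base64-encoded via #[serde_as(as = "Base64")]):
//
//   {
//     "language": "js",
//     "sample": "console.log('hello');",
//     "signature": "<base64 HMAC-SHA256 of `sample`>",
//     "highlighted": "console.log"
//   }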

pub fn verify_explain_request(req: &ExplainRequest) -> Result<(), anyhow::Error> {
    if let Some(part) = &req.highlighted {
        if !req.sample.contains(part) {
            return Err(ApiError::Artificial.into());
        }
    }
    let mut mac = HmacSha256::new_from_slice(
        &SETTINGS
            .ai
            .as_ref()
            .map(|ai| ai.explain_sign_key)
            .ok_or(ApiError::Artificial)?,
    )?;

    mac.update(req.sample.as_bytes());

    mac.verify_slice(&req.signature)?;
    Ok(())
}

pub fn hash_highlighted(to_be_hashed: &str) -> Vec<u8> {
    let mut hasher = Sha256::new();
    hasher.update(to_be_hashed.as_bytes());
    hasher.finalize().to_vec()
}

pub fn get_language(language: &Option<String>) -> &'static str {
    match language.as_deref() {
        Some("js" | "javascript") => "js",
        Some("html") => "html",
        Some("css") => "css",
        _ => "",
    }
}

pub fn filter_language(language: Option<String>) -> Option<String> {
    if get_language(&language).is_empty() {
        return None;
    }
    language
}
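
// Illustrative behavior of the two helpers above (hypothetical inputs):
//
//   get_language(&Some("javascript".into()))  == "js"
//   get_language(&Some("rust".into()))        == ""    // unsupported
//   filter_language(Some("css".into()))       == Some("css")
//   filter_language(Some("rust".into()))      == None  // dropped before prompting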

pub async fn prepare_explain_req(
    q: ExplainRequest,
    client: &Client<OpenAIConfig>,
) -> Result<CreateChatCompletionRequest, AIError> {
    let ExplainRequest {
        language,
        sample,
        highlighted,
        ..
    } = q;
    let language = get_language(&language);
    let user_prompt = if let Some(highlighted) = highlighted {
        format!("Explain the following part: ```{language}\n{highlighted}\n```")
    } else {
        "Explain the example in detail.".to_string()
    };
    let context_prompt = format!(
        "Given the following code example is the MDN code example:```{language}\n{sample}\n```"
    );
    let req = CreateModerationRequestArgs::default()
        .input(format!("{user_prompt}\n{context_prompt}"))
        .build()
        .unwrap();
    let moderation = client.moderations().create(req).await?;

    if moderation.results.iter().any(|r| r.flagged) {
        return Err(AIError::FlaggedError);
    }
    let system_message = ChatCompletionRequestMessageArgs::default()
        .role(Role::System)
        .content(EXPLAIN_SYSTEM_MESSAGE)
        .build()
        .unwrap();
    let context_message = ChatCompletionRequestMessageArgs::default()
        .role(Role::User)
        .content(context_prompt)
        .build()
        .unwrap();
    let user_message = ChatCompletionRequestMessageArgs::default()
        .role(Role::User)
        .content(user_prompt)
        .build()
        .unwrap();
    let req = CreateChatCompletionRequestArgs::default()
        .model(MODEL)
        .messages(vec![system_message, context_message, user_message])
        .temperature(0.0)
        .build()?;
    Ok(req)
}
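
A minimal sketch of the signing handshake this module implements, written as a hypothetical test inside src/ai/explain.rs (so use super::* resolves HmacSha256 and ExplainRequest; key and sample values are made up):

#[cfg(test)]
mod sketch {
    use super::*;
    use hmac::Mac;

    #[test]
    fn sign_and_build_request() {
        // Server side: sign the code sample with the shared explain_sign_key.
        let key = b"example signing key";
        let sample = "console.log('hello');".to_string();
        let mut mac = HmacSha256::new_from_slice(key).unwrap();
        mac.update(sample.as_bytes());
        let signature = mac.finalize().into_bytes().to_vec();

        // Client side: echo the sample and its signature back for explanation.
        let req = ExplainRequest {
            language: Some("js".to_string()),
            sample,
            signature,
            highlighted: None,
        };

        // verify_explain_request(&req) recomputes the HMAC over `sample` with
        // the configured key and rejects the request unless it matches.
        assert_eq!(req.signature.len(), 32);
    }
}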
1 change: 1 addition & 0 deletions src/ai/mod.rs
@@ -2,4 +2,5 @@ pub mod ask;
pub mod constants;
pub mod embeddings;
pub mod error;
pub mod explain;
pub mod helpers;
(Remaining changed files not rendered.)