-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(ai-explain): add ai-explain api
- Loading branch information
Showing
17 changed files
with
612 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
DROP TABLE ai_explain_cache; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
CREATE TABLE ai_explain_cache ( | ||
id BIGSERIAL PRIMARY KEY, | ||
signature bytea NOT NULL, | ||
highlighted_hash bytea NOT NULL, | ||
language VARCHAR(255), | ||
explanation TEXT, | ||
created_at TIMESTAMP NOT NULL DEFAULT now(), | ||
last_used TIMESTAMP NOT NULL DEFAULT now(), | ||
view_count BIGINT NOT NULL DEFAULT 1, | ||
version BIGINT NOT NULL DEFAULT 1, | ||
thumbs_up BIGINT NOT NULL DEFAULT 0, | ||
thumbs_down BIGINT NOT NULL DEFAULT 0, | ||
UNIQUE(signature, highlighted_hash, version) | ||
); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
use async_openai::{ | ||
config::OpenAIConfig, | ||
types::{ | ||
ChatCompletionRequestMessageArgs, CreateChatCompletionRequest, | ||
CreateChatCompletionRequestArgs, CreateModerationRequestArgs, Role, | ||
}, | ||
Client, | ||
}; | ||
use hmac::{Hmac, Mac}; | ||
use serde::{Deserialize, Serialize}; | ||
use serde_with::{base64::Base64, serde_as}; | ||
use sha2::{Digest, Sha256}; | ||
|
||
use crate::{ | ||
ai::{ | ||
constants::{EXPLAIN_SYSTEM_MESSAGE, MODEL}, | ||
error::AIError, | ||
}, | ||
api::error::ApiError, | ||
settings::SETTINGS, | ||
}; | ||
|
||
pub const AI_EXPLAIN_VERSION: i64 = 1; | ||
|
||
pub type HmacSha256 = Hmac<Sha256>; | ||
|
||
#[serde_as] | ||
#[derive(Serialize, Deserialize, Clone)] | ||
pub struct ExplainRequest { | ||
pub language: Option<String>, | ||
pub sample: String, | ||
#[serde_as(as = "Base64")] | ||
pub signature: Vec<u8>, | ||
pub highlighted: Option<String>, | ||
} | ||
|
||
pub fn verify_explain_request(req: &ExplainRequest) -> Result<(), anyhow::Error> { | ||
if let Some(part) = &req.highlighted { | ||
if !req.sample.contains(part) { | ||
return Err(ApiError::Artificial.into()); | ||
} | ||
} | ||
let mut mac = HmacSha256::new_from_slice( | ||
&SETTINGS | ||
.ai | ||
.as_ref() | ||
.map(|ai| ai.explain_sign_key) | ||
.ok_or(ApiError::Artificial)?, | ||
)?; | ||
|
||
mac.update(req.sample.as_bytes()); | ||
|
||
mac.verify_slice(&req.signature)?; | ||
Ok(()) | ||
} | ||
|
||
pub fn hash_highlighted(to_be_hashed: &str) -> Vec<u8> { | ||
let mut hasher = Sha256::new(); | ||
hasher.update(to_be_hashed.as_bytes()); | ||
hasher.finalize().to_vec() | ||
} | ||
|
||
pub fn get_language(language: &Option<String>) -> &'static str { | ||
match language.as_deref() { | ||
Some("js" | "javascript") => "js", | ||
Some("html") => "html", | ||
Some("css") => "css", | ||
_ => "", | ||
} | ||
} | ||
|
||
pub fn filter_language(language: Option<String>) -> Option<String> { | ||
if get_language(&language).is_empty() { | ||
return None; | ||
} | ||
language | ||
} | ||
|
||
pub async fn prepare_explain_req( | ||
q: ExplainRequest, | ||
client: &Client<OpenAIConfig>, | ||
) -> Result<CreateChatCompletionRequest, AIError> { | ||
let ExplainRequest { | ||
language, | ||
sample, | ||
highlighted, | ||
.. | ||
} = q; | ||
let language = get_language(&language); | ||
let user_prompt = if let Some(highlighted) = highlighted { | ||
format!("Explain the following part: ```{language}\n{highlighted}\n```") | ||
} else { | ||
"Explain the example in detail.".to_string() | ||
}; | ||
let context_prompt = format!( | ||
"Given the following code example is the MDN code example:```{language}\n{sample}\n```" | ||
); | ||
let req = CreateModerationRequestArgs::default() | ||
.input(format!("{user_prompt}\n{context_prompt}")) | ||
.build() | ||
.unwrap(); | ||
let moderation = client.moderations().create(req).await?; | ||
|
||
if moderation.results.iter().any(|r| r.flagged) { | ||
return Err(AIError::FlaggedError); | ||
} | ||
let system_message = ChatCompletionRequestMessageArgs::default() | ||
.role(Role::System) | ||
.content(EXPLAIN_SYSTEM_MESSAGE) | ||
.build() | ||
.unwrap(); | ||
let context_message = ChatCompletionRequestMessageArgs::default() | ||
.role(Role::User) | ||
.content(context_prompt) | ||
.build() | ||
.unwrap(); | ||
let user_message = ChatCompletionRequestMessageArgs::default() | ||
.role(Role::User) | ||
.content(user_prompt) | ||
.build() | ||
.unwrap(); | ||
let req = CreateChatCompletionRequestArgs::default() | ||
.model(MODEL) | ||
.messages(vec![system_message, context_message, user_message]) | ||
.temperature(0.0) | ||
.build()?; | ||
Ok(req) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,5 @@ pub mod ask; | |
pub mod constants; | ||
pub mod embeddings; | ||
pub mod error; | ||
pub mod explain; | ||
pub mod helpers; |
Oops, something went wrong.