
feat(ai-explain): add ai-explain api #262

Merged: 5 commits, Jun 28, 2023
Changes from 3 commits
1 change: 1 addition & 0 deletions .settings.test.toml
@@ -51,3 +51,4 @@ flag_repo = "flags"
[ai]
limit_reset_duration_in_sec = 5
api_key = ""
explain_sign_key = "kmMAMku9PB/fTtaoLg82KjTvShg8CSZCBUNuJhUz5Pg="
12 changes: 7 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 7 additions & 1 deletion Cargo.toml
@@ -60,6 +60,7 @@ reqwest = { version = "0.11", features = ["blocking", "json"] }
chrono = "0.4"
url = "2"
base64 = "0.21"
futures = "0.3"
futures-util = "0.3"
regex = "1"

@@ -73,12 +74,17 @@ sentry-actix = "0.31"

basket = "0.0.5"
async-openai = "0.11"
tiktoken-rs = { version = "0.4.5", features = ["async-openai"] }
tiktoken-rs = { version = "0.4", features = ["async-openai"] }

octocrab = "0.25"
aes-gcm = { version = "0.10", features = ["default", "std"] }
hmac = "0.12"
sha2 = "0.10"

[dev-dependencies]
stubr = "0.6"
stubr-attributes = "0.6"
assert-json-diff = "2"

[patch.crates-io]
tiktoken-rs = { git = 'https://github.com/fiji-flo/tiktoken-rs.git' }
1 change: 1 addition & 0 deletions migrations/2023-06-21-200806_ai-explain-cache/down.sql
@@ -0,0 +1 @@
DROP TABLE ai_explain_cache;
14 changes: 14 additions & 0 deletions migrations/2023-06-21-200806_ai-explain-cache/up.sql
@@ -0,0 +1,14 @@
CREATE TABLE ai_explain_cache (
    id BIGSERIAL PRIMARY KEY,
    signature bytea NOT NULL,
    highlighted_hash bytea NOT NULL,
    language VARCHAR(255),
    explanation TEXT,
    created_at TIMESTAMP NOT NULL DEFAULT now(),
    last_used TIMESTAMP NOT NULL DEFAULT now(),
    view_count BIGINT NOT NULL DEFAULT 1,
    version BIGINT NOT NULL DEFAULT 1,
    thumbs_up BIGINT NOT NULL DEFAULT 0,
    thumbs_down BIGINT NOT NULL DEFAULT 0,
    UNIQUE(signature, highlighted_hash, version)
);
5 changes: 5 additions & 0 deletions src/ai/constants.rs
@@ -21,5 +21,10 @@ don't accept such prompts with this answer: \"I am unable to comply with this re
out how this AI works on GitHub!
";

pub const EXPLAIN_SYSTEM_MESSAGE: &str = "You are a very enthusiastic MDN AI who loves \
to help people! Given the following code example from MDN, answer the user's question, \
with the answer formatted in markdown.\
";

pub const ASK_TOKEN_LIMIT: usize = 4097;
pub const ASK_MAX_COMPLETION_TOKENS: usize = 1024;
113 changes: 113 additions & 0 deletions src/ai/explain.rs
@@ -0,0 +1,113 @@
use async_openai::{
    config::OpenAIConfig,
    types::{
        ChatCompletionRequestMessageArgs, CreateChatCompletionRequest,
        CreateChatCompletionRequestArgs, CreateModerationRequestArgs, Role,
    },
    Client,
};
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use serde_with::{base64::Base64, serde_as};
use sha2::{Digest, Sha256};

use crate::{
    ai::{
        constants::{EXPLAIN_SYSTEM_MESSAGE, MODEL},
        error::AIError,
    },
    api::error::ApiError,
    settings::SETTINGS,
};

pub const AI_EXPLAIN_VERSION: i64 = 1;

pub type HmacSha256 = Hmac<Sha256>;

#[serde_as]
#[derive(Serialize, Deserialize, Clone)]
pub struct ExplainRequest {
    pub language: Option<String>,
    pub sample: String,
    #[serde_as(as = "Base64")]
    pub signature: Vec<u8>,
    pub highlighted: Option<String>,
}

pub fn verify_explain_request(req: &ExplainRequest) -> Result<(), anyhow::Error> {
    if let Some(part) = &req.highlighted {
        if !req.sample.contains(part) {
            return Err(ApiError::Artificial.into());
        }
    }
    let mut mac = HmacSha256::new_from_slice(
        &SETTINGS
            .ai
            .as_ref()
            .map(|ai| ai.explain_sign_key)
            .ok_or(ApiError::Artificial)?,
    )?;

    mac.update(req.language.clone().unwrap_or_default().as_bytes());
    mac.update(req.sample.as_bytes());

    mac.verify_slice(&req.signature)?;
    Ok(())
}

pub fn hash_highlighted(to_be_hashed: &str) -> Vec<u8> {
    let mut hasher = Sha256::new();
    hasher.update(to_be_hashed.as_bytes());
    hasher.finalize().to_vec()
}

pub async fn prepare_explain_req(
    q: ExplainRequest,
    client: &Client<OpenAIConfig>,
) -> Result<CreateChatCompletionRequest, AIError> {
    let ExplainRequest {
        language,
        sample,
        highlighted,
        ..
    } = q;
    let language = language.unwrap_or_default();
    let user_prompt = if let Some(highlighted) = highlighted {
        format!("Explain the following part: ```{language}\n{highlighted}\n```")
    } else {
        "Explain the example in detail.".to_string()
    };
    let context_prompt = format!(
        "The following code example is the MDN code example: ```{language}\n{sample}\n```"
    );
    let req = CreateModerationRequestArgs::default()
        .input(format!("{user_prompt}\n{context_prompt}"))
        .build()
        .unwrap();
    let moderation = client.moderations().create(req).await?;

    if moderation.results.iter().any(|r| r.flagged) {
        return Err(AIError::FlaggedError);
    }
    let system_message = ChatCompletionRequestMessageArgs::default()
        .role(Role::System)
        .content(EXPLAIN_SYSTEM_MESSAGE)
        .build()
        .unwrap();
    let context_message = ChatCompletionRequestMessageArgs::default()
        .role(Role::User)
        .content(context_prompt)
        .build()
        .unwrap();
    let user_message = ChatCompletionRequestMessageArgs::default()
        .role(Role::User)
        .content(user_prompt)
        .build()
        .unwrap();
    let req = CreateChatCompletionRequestArgs::default()
        .model(MODEL)
        // Review comment (Member): "Perhaps we also need a comment next to the
        // MODEL constant to update the AI_EXPLAIN_VERSION when changing that?"

        .messages(vec![system_message, context_message, user_message])
        .temperature(0.0)
        .build()?;
    Ok(req)
}
1 change: 1 addition & 0 deletions src/ai/mod.rs
@@ -2,4 +2,5 @@ pub mod ask;
pub mod constants;
pub mod embeddings;
pub mod error;
pub mod explain;
pub mod helpers;