Skip to content

Commit

Permalink
enhance(ai-help): Don't answer if no refs (#277)
Browse files Browse the repository at this point in the history
* enhance(ai-help): Don't answer if no refs

* fix(ai): reuse specific "Sorry" message

We replace it in the frontend.

---------

Co-authored-by: Claas Augner <[email protected]>
  • Loading branch information
fiji-flo and caugner authored Jul 6, 2023
1 parent c99dfd2 commit 5f9bb64
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 37 deletions.
7 changes: 5 additions & 2 deletions src/ai/ask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub async fn prepare_ask_req(
client: &Client<OpenAIConfig>,
pool: &SupaPool,
messages: Vec<ChatCompletionRequestMessage>,
) -> Result<AskRequest, AIError> {
) -> Result<Option<AskRequest>, AIError> {
let open_ai_messages = sanitize_messages(messages);

// TODO: sign messages os we don't check again
Expand Down Expand Up @@ -91,6 +91,9 @@ pub async fn prepare_ask_req(
});
}
}
if context.is_empty() {
return Ok(None);
}
let context = context.join("\n---\n");
let system_message = ChatCompletionRequestMessageArgs::default()
.role(Role::System)
Expand All @@ -116,5 +119,5 @@ pub async fn prepare_ask_req(
.temperature(0.0)
.build()?;

Ok(AskRequest { req, refs })
Ok(Some(AskRequest { req, refs }))
}
107 changes: 75 additions & 32 deletions src/api/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use async_openai::{
};
use futures_util::{stream, StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_with::{base64::Base64, serde_as};

use crate::{
Expand Down Expand Up @@ -76,18 +75,29 @@ pub struct AskMeta {
pub quota: Option<AskLimit>,
}

#[derive(Serialize)]
pub struct CachedChunkDelta {
#[derive(Serialize, Default)]
pub struct GeneratedChunkDelta {
pub content: String,
}

#[derive(Serialize)]
pub struct CachedChunkChoice {
pub delta: CachedChunkDelta,
#[derive(Serialize, Default)]
pub struct GeneratedChunkChoice {
pub delta: GeneratedChunkDelta,
pub finish_reason: Option<String>,
}
#[derive(Serialize)]
pub struct CachedChunk {
pub choices: Vec<CachedChunkChoice>,
pub struct GeneratedChunk {
pub choices: Vec<GeneratedChunkChoice>,
pub id: i64,
}

impl Default for GeneratedChunk {
fn default() -> Self {
Self {
choices: Default::default(),
id: 1,
}
}
}

#[serde_as]
Expand All @@ -102,14 +112,16 @@ pub struct ExplainInitial {
initial: ExplainInitialData,
}

impl From<&str> for CachedChunk {
impl From<&str> for GeneratedChunk {
fn from(content: &str) -> Self {
CachedChunk {
choices: vec![CachedChunkChoice {
delta: CachedChunkDelta {
GeneratedChunk {
choices: vec![GeneratedChunkChoice {
delta: GeneratedChunkDelta {
content: content.into(),
},
..Default::default()
}],
..Default::default()
}
}
}
Expand All @@ -133,7 +145,7 @@ pub async fn ask(
supabase_pool: Data<Option<SupaPool>>,
diesel_pool: Data<Pool>,
messages: Json<ChatRequestMessages>,
) -> Result<Either<impl Responder, HttpResponse>, ApiError> {
) -> Result<Either<impl Responder, impl Responder>, ApiError> {
let mut conn = diesel_pool.get()?;
let user = get_user(&mut conn, user_id.id().unwrap())?;
let current = if user.is_subscriber() {
Expand All @@ -142,30 +154,61 @@ pub async fn ask(
} else {
let current = create_or_increment_limit(&mut conn, &user)?;
if current.is_none() {
return Ok(Either::Right(HttpResponse::Ok().json(json!(null))));
return Err(ApiError::PaymentRequired);
}
current
};
if let (Some(client), Some(pool)) = (&**openai_client, &**supabase_pool) {
let ask_req = prepare_ask_req(client, pool, messages.into_inner().messages).await?;
// 1. Prepare messages
let stream = client.chat().create_stream(ask_req.req).await.unwrap();
match prepare_ask_req(client, pool, messages.into_inner().messages).await? {
Some(ask_req) => {
// 1. Prepare messages
let stream = client.chat().create_stream(ask_req.req).await.unwrap();

let refs = stream::once(async move {
Ok(sse::Event::Data(
sse::Data::new_json(AskMeta {
typ: MetaType::Metadata,
sources: ask_req.refs,
quota: current.map(AskLimit::from_count),
})
.map_err(OpenAIError::JSONDeserialize)?,
))
});
return Ok(Either::Left(sse::Sse::from_stream(refs.chain(
stream.map_ok(|res| sse::Event::Data(sse::Data::new_json(res).unwrap())),
))));
let refs = stream::once(async move {
Ok(sse::Event::Data(
sse::Data::new_json(AskMeta {
typ: MetaType::Metadata,
sources: ask_req.refs,
quota: current.map(AskLimit::from_count),
})
.map_err(OpenAIError::JSONDeserialize)?,
))
});
let res = sse::Sse::from_stream(refs.chain(
stream.map_ok(|res| sse::Event::Data(sse::Data::new_json(res).unwrap())),
));
return Ok(Either::Left(res));
}
None => {
let parts = vec![
sse::Data::new_json(AskMeta {
typ: MetaType::Metadata,
sources: vec![],
quota: current.map(AskLimit::from_count),
})
.map_err(OpenAIError::JSONDeserialize)?,
sse::Data::new_json(GeneratedChunk::from(
"Sorry, I don't know how to help with that.",
))
.map_err(OpenAIError::JSONDeserialize)?,
sse::Data::new_json(GeneratedChunk {
choices: vec![GeneratedChunkChoice {
finish_reason: Some("stop".to_owned()),
..Default::default()
}],
..Default::default()
})
.map_err(OpenAIError::JSONDeserialize)?,
];
let stream = futures::stream::iter(parts.into_iter());
let res =
sse::Sse::from_stream(stream.map(|r| Ok::<_, ApiError>(sse::Event::Data(r))));

return Ok(Either::Right(res));
}
}
}
Ok(Either::Right(HttpResponse::NotImplemented().finish()))
Err(ApiError::NotImplemented)
}

pub async fn explain_feedback(
Expand Down Expand Up @@ -205,7 +248,7 @@ pub async fn explain(
initial: ExplainInitialData { cached: true, hash },
})
.map_err(OpenAIError::JSONDeserialize)?,
sse::Data::new_json(CachedChunk::from(explanation.as_str()))
sse::Data::new_json(GeneratedChunk::from(explanation.as_str()))
.map_err(OpenAIError::JSONDeserialize)?,
];
let stream = futures::stream::iter(parts.into_iter());
Expand Down
8 changes: 8 additions & 0 deletions src/api/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ pub enum ApiError {
PlaygroundError(#[from] PlaygroundError),
#[error("Unknown error: {0}")]
Generic(String),
#[error("Payment required")]
PaymentRequired,
#[error("Not implemented")]
NotImplemented,
}

impl ApiError {
Expand Down Expand Up @@ -139,6 +143,8 @@ impl ApiError {
Self::LoginRequiredForFeature(_) => "Login Required",
Self::OpenAIError(_) => "Open AI error",
Self::AIError(_) => "AI error",
Self::PaymentRequired => "Payment required",
Self::NotImplemented => "Not implemented",
}
}
}
Expand All @@ -164,6 +170,8 @@ impl ResponseError for ApiError {
Self::ValidationError(_) => StatusCode::BAD_REQUEST,
Self::MultipleCollectionSubscriptionLimitReached => StatusCode::BAD_REQUEST,
Self::LoginRequiredForFeature(_) => StatusCode::UNAUTHORIZED,
Self::PaymentRequired => StatusCode::PAYMENT_REQUIRED,
Self::NotImplemented => StatusCode::NOT_IMPLEMENTED,
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/api/ai_explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async fn test_explain() -> Result<(), Error> {

assert!(explain.status().is_success());

let expected = "data: {\"initial\":{\"cached\":true,\"hash\":\"nW77myAksS9XEAZpmXYHPFbW3WZTQvZLLO1cAwPTKwQ=\"}}\n\ndata: {\"choices\":[{\"delta\":{\"content\":\"Explain this!\"}}]}\n\n";
let expected = "data: {\"initial\":{\"cached\":true,\"hash\":\"nW77myAksS9XEAZpmXYHPFbW3WZTQvZLLO1cAwPTKwQ=\"}}\n\ndata: {\"choices\":[{\"delta\":{\"content\":\"Explain this!\"},\"finish_reason\":null}],\"id\":1}\n\n";
assert_eq!(
expected,
String::from_utf8_lossy(test::read_body(explain).await.as_ref())
Expand Down
4 changes: 2 additions & 2 deletions tests/api/ai_help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async fn test_quota() -> Result<(), Error> {
}))),
)
.await;
assert_ok_with_json_containing(ask, json!(null)).await;
assert_eq!(ask.status(), StatusCode::PAYMENT_REQUIRED);
drop(stubr);
Ok(())
}
Expand Down Expand Up @@ -111,7 +111,7 @@ async fn test_quota_rest() -> Result<(), Error> {
}))),
)
.await;
assert_ok_with_json_containing(ask, json!(null)).await;
assert_eq!(ask.status(), StatusCode::PAYMENT_REQUIRED);

sleep(Duration::from_secs(
SETTINGS
Expand Down

0 comments on commit 5f9bb64

Please sign in to comment.