diff --git a/src/ai/error.rs b/src/ai/error.rs index 909c805b..a1e6d50b 100644 --- a/src/ai/error.rs +++ b/src/ai/error.rs @@ -1,6 +1,10 @@ +use actix_http::StatusCode; +use actix_web::{HttpResponse, ResponseError}; use async_openai::error::OpenAIError; use thiserror::Error; +use crate::error::ErrorResponse; + #[derive(Error, Debug)] pub enum AIError { #[error("OpenAI error: {0}")] @@ -16,3 +20,26 @@ pub enum AIError { #[error("Tiktoken Error: {0}")] TiktokenError(#[from] anyhow::Error), } + +impl ResponseError for AIError { + fn status_code(&self) -> StatusCode { + match &self { + AIError::OpenAIError(_) | AIError::SqlXError(_) | AIError::TiktokenError(_) => { + StatusCode::INTERNAL_SERVER_ERROR + } + AIError::FlaggedError | AIError::NoUserPrompt | AIError::TokenLimit => { + StatusCode::BAD_REQUEST + } + } + } + + fn error_response(&self) -> HttpResponse { + let status_code = self.status_code(); + let mut builder = HttpResponse::build(status_code); + builder.json(ErrorResponse { + code: status_code.as_u16(), + message: status_code.canonical_reason().unwrap_or("Unknown"), + error: "AI Error", + }) + } +} diff --git a/src/api/error.rs b/src/api/error.rs index 3ef1ccab..143e4b9d 100644 --- a/src/api/error.rs +++ b/src/api/error.rs @@ -2,6 +2,7 @@ use std::string::FromUtf8Error; use crate::ai::error::AIError; use crate::db::error::DbError; +use crate::error::ErrorResponse; use actix_http::header::HeaderValue; use actix_web::http::header::HeaderName; @@ -10,7 +11,6 @@ use actix_web::middleware::{ErrorHandlerResponse, ErrorHandlers}; use actix_web::{HttpResponse, ResponseError}; use async_openai::error::OpenAIError; use basket::BasketError; -use serde::Serialize; use serde_json::json; use thiserror::Error; use uuid::Uuid; @@ -63,6 +63,30 @@ pub enum PlaygroundError { SettingsError, } +impl ResponseError for PlaygroundError { + fn status_code(&self) -> StatusCode { + match &self { + PlaygroundError::CryptError(_) + | PlaygroundError::DecodeError(_) + | PlaygroundError::NoNonceError + | PlaygroundError::UtfDecodeError(_) => StatusCode::BAD_REQUEST, + PlaygroundError::OctocrabError(_) | PlaygroundError::SettingsError => { + StatusCode::INTERNAL_SERVER_ERROR + } + } + } + + fn error_response(&self) -> HttpResponse { + let status_code = self.status_code(); + let mut builder = HttpResponse::build(status_code); + builder.json(ErrorResponse { + code: status_code.as_u16(), + message: status_code.canonical_reason().unwrap_or("Unknown"), + error: "Playground Error", + }) + } +} + #[derive(Error, Debug)] pub enum ApiError { #[error("Artificial error")] @@ -149,13 +173,6 @@ impl ApiError { } } -#[derive(Serialize)] -struct ErrorResponse<'a> { - code: u16, - error: &'a str, - message: &'a str, -} - impl ResponseError for ApiError { fn status_code(&self) -> StatusCode { match *self { @@ -172,6 +189,8 @@ impl ResponseError for ApiError { Self::LoginRequiredForFeature(_) => StatusCode::UNAUTHORIZED, Self::PaymentRequired => StatusCode::PAYMENT_REQUIRED, Self::NotImplemented => StatusCode::NOT_IMPLEMENTED, + Self::PlaygroundError(ref e) => e.status_code(), + Self::AIError(ref e) => e.status_code(), _ => StatusCode::INTERNAL_SERVER_ERROR, } } @@ -210,6 +229,8 @@ impl ResponseError for ApiError { message: format!("Please login to use feature: {0}", feature).as_str(), error: self.name(), }), + ApiError::PlaygroundError(error) => error.error_response(), + ApiError::AIError(error) => error.error_response(), _ if status_code == StatusCode::INTERNAL_SERVER_ERROR => builder.json(ErrorResponse { code: status_code.as_u16(), message: "internal server error", diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 00000000..4ed1b2ac --- /dev/null +++ b/src/error.rs @@ -0,0 +1,8 @@ +use serde::Serialize; + +#[derive(Serialize)] +pub struct ErrorResponse<'a> { + pub code: u16, + pub error: &'a str, + pub message: &'a str, +} diff --git a/src/lib.rs b/src/lib.rs index a182a3fc..97dce123 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ use actix_web::{ pub mod ai; pub mod api; pub mod db; +pub mod error; pub mod fxa; mod helpers; pub mod ids; diff --git a/tests/api/play.rs b/tests/api/play.rs index 86e955d3..74f936f7 100644 --- a/tests/api/play.rs +++ b/tests/api/play.rs @@ -74,7 +74,7 @@ async fn test_invalid_id() -> Result<(), Error> { let service = test::init_service(app).await; let mut client = TestHttpClient::new(service).await; let res = client.get("/api/v1/play/sssieddidxsx", None).await; - // This used to panic, now it should just 500 - assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + // This used to panic, now it should just 400 + assert_eq!(res.status(), StatusCode::BAD_REQUEST); Ok(()) }