diff --git a/Cargo.toml b/Cargo.toml index 156fa711c3..7c73c50117 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = [ "axum", + "axum-handle-error-extract", "examples/*", ] diff --git a/axum-handle-error-extract/CHANGELOG.md b/axum-handle-error-extract/CHANGELOG.md new file mode 100644 index 0000000000..33b41831a2 --- /dev/null +++ b/axum-handle-error-extract/CHANGELOG.md @@ -0,0 +1,14 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +# Unreleased + +- None + +# 0.1.0 (05. November, 2021) + +- Initial release. diff --git a/axum-handle-error-extract/Cargo.toml b/axum-handle-error-extract/Cargo.toml new file mode 100644 index 0000000000..5df894c4cc --- /dev/null +++ b/axum-handle-error-extract/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "axum-handle-error-extract" +version = "0.1.0" +authors = ["David Pedersen "] +categories = ["asynchronous", "network-programming", "web-programming"] +description = "Error handling layer for axum that supports extractors and async functions" +edition = "2018" +homepage = "https://github.com/tokio-rs/axum" +keywords = ["http", "web", "framework"] +license = "MIT" +readme = "README.md" +repository = "https://github.com/tokio-rs/axum" + +[dependencies] +axum = { path = "../axum" } +tower-service = "0.3" +tower-layer = "0.3" +tower = { version = "0.4", features = ["util"] } +pin-project-lite = "0.2" + +[dev-dependencies] +tower = { version = "0.4", features = ["util", "timeout"] } diff --git a/axum-handle-error-extract/LICENSE b/axum-handle-error-extract/LICENSE new file mode 100644 index 0000000000..b980cacc77 --- /dev/null +++ b/axum-handle-error-extract/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2019 Tower Contributors + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/axum-handle-error-extract/README.md b/axum-handle-error-extract/README.md new file mode 100644 index 0000000000..60cee5540e --- /dev/null +++ b/axum-handle-error-extract/README.md @@ -0,0 +1,45 @@ +# axum-handle-error-extract + +Error handling layer for axum that supports extractors and async functions + +[![Build status](https://github.com/tokio-rs/axum-handle-error-extract/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/tokio-rs/axum-handle-error-extract/actions/workflows/CI.yml) +[![Crates.io](https://img.shields.io/crates/v/axum-handle-error-extract)](https://crates.io/crates/axum-handle-error-extract) +[![Documentation](https://docs.rs/axum-handle-error-extract/badge.svg)](https://docs.rs/axum) + +More information about this crate can be found in the [crate documentation][docs]. + +## Safety + +This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in +100% safe Rust. + +## Minimum supported Rust version + +axum-handle-error-extract's MSRV is 1.54. + +## Getting Help + +You're also welcome to ask in the [Discord channel][chat] or open an [issue] +with your question. + +## Contributing + +:balloon: Thanks for your help improving the project! We are so happy to have +you! We have a [contributing guide][contributing] to help you get involved in the +`axum` project. + +## License + +This project is licensed under the [MIT license][license]. + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in `axum` by you, shall be licensed as MIT, without any +additional terms or conditions. + +[docs]: https://docs.rs/axum-handle-error-extract +[contributing]: /CONTRIBUTING.md +[chat]: https://discord.gg/tokio +[issue]: https://github.com/tokio-rs/axum/issues/new +[license]: /axum/LICENSE diff --git a/axum-handle-error-extract/src/lib.rs b/axum-handle-error-extract/src/lib.rs new file mode 100644 index 0000000000..aae7979dec --- /dev/null +++ b/axum-handle-error-extract/src/lib.rs @@ -0,0 +1,389 @@ +//! Error handling layer for axum that supports extractors and async functions. +//! +//! This crate provides [`HandleErrorLayer`] which works similarly to +//! [`axum::error_handling::HandleErrorLayer`] except that it supports +//! extractors and async functions: +//! +//! ```rust +//! use axum::{ +//! Router, +//! BoxError, +//! response::IntoResponse, +//! http::{StatusCode, Method, Uri}, +//! routing::get, +//! }; +//! use tower::{ServiceBuilder, timeout::error::Elapsed}; +//! use std::time::Duration; +//! use axum_handle_error_extract::HandleErrorLayer; +//! +//! let app = Router::new() +//! .route("/", get(|| async {})) +//! .layer( +//! ServiceBuilder::new() +//! // timeouts produces errors, so we handle those with `handle_error` +//! .layer(HandleErrorLayer::new(handle_error)) +//! .timeout(Duration::from_secs(10)) +//! ); +//! +//! // our handler take can 0 to 16 extractors and the final argument must +//! // always be the error produced by the middleware +//! async fn handle_error( +//! method: Method, +//! uri: Uri, +//! error: BoxError, +//! ) -> impl IntoResponse { +//! if error.is::() { +//! ( +//! StatusCode::REQUEST_TIMEOUT, +//! format!("{} {} took too long", method, uri), +//! ) +//! } else { +//! ( +//! StatusCode::INTERNAL_SERVER_ERROR, +//! format!("{} {} failed: {}", method, uri, error), +//! ) +//! } +//! } +//! # async { +//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` +//! +//! Not running any extractors is also supported: +//! +//! ```rust +//! use axum::{ +//! Router, +//! BoxError, +//! response::IntoResponse, +//! http::StatusCode, +//! routing::get, +//! }; +//! use tower::{ServiceBuilder, timeout::error::Elapsed}; +//! use std::time::Duration; +//! use axum_handle_error_extract::HandleErrorLayer; +//! +//! let app = Router::new() +//! .route("/", get(|| async {})) +//! .layer( +//! ServiceBuilder::new() +//! .layer(HandleErrorLayer::new(handle_error)) +//! .timeout(Duration::from_secs(10)) +//! ); +//! +//! // this function just takes the error +//! async fn handle_error(error: BoxError) -> impl IntoResponse { +//! if error.is::() { +//! ( +//! StatusCode::REQUEST_TIMEOUT, +//! "Request timeout".to_string(), +//! ) +//! } else { +//! ( +//! StatusCode::INTERNAL_SERVER_ERROR, +//! format!("Unhandled internal error: {}", error), +//! ) +//! } +//! } +//! # async { +//! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); +//! # }; +//! ``` +//! +//! See [`axum::error_handling`] for more details on axum's error handling model and +//! [`axum::extract`] for more details on extractors. +//! +//! # The future +//! +//! In axum 0.4 this will replace the current [`axum::error_handling::HandleErrorLayer`]. + +#![warn( + clippy::all, + clippy::dbg_macro, + clippy::todo, + clippy::empty_enum, + clippy::enum_glob_use, + clippy::mem_forget, + clippy::unused_self, + clippy::filter_map_next, + clippy::needless_continue, + clippy::needless_borrow, + clippy::match_wildcard_for_single_variants, + clippy::if_let_mutex, + clippy::mismatched_target_os, + clippy::await_holding_lock, + clippy::match_on_vec_items, + clippy::imprecise_flops, + clippy::suboptimal_flops, + clippy::lossy_float_literal, + clippy::rest_pat_in_fully_bound_structs, + clippy::fn_params_excessive_bools, + clippy::exit, + clippy::inefficient_to_string, + clippy::linkedlist, + clippy::macro_use_imports, + clippy::option_option, + clippy::verbose_file_reads, + clippy::unnested_or_patterns, + rust_2018_idioms, + future_incompatible, + nonstandard_style, + missing_debug_implementations, + missing_docs +)] +#![deny(unreachable_pub, private_in_public)] +#![allow(elided_lifetimes_in_paths, clippy::type_complexity)] +#![forbid(unsafe_code)] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(test, allow(clippy::float_cmp))] + +use axum::{ + body::{box_body, BoxBody, Bytes, Full, HttpBody}, + extract::{FromRequest, RequestParts}, + http::{Request, Response, StatusCode}, + response::IntoResponse, + BoxError, +}; +use pin_project_lite::pin_project; +use std::{ + convert::Infallible, + fmt, + future::Future, + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; +use tower::ServiceExt; +use tower_layer::Layer; +use tower_service::Service; + +/// [`Layer`] that applies [`HandleError`] which is a [`Service`] adapter +/// that handles errors by converting them into responses. +/// +/// See [module docs](self) for more details on axum's error handling model. +pub struct HandleErrorLayer { + f: F, + _extractor: PhantomData T>, +} + +impl HandleErrorLayer { + /// Create a new `HandleErrorLayer`. + pub fn new(f: F) -> Self { + Self { + f, + _extractor: PhantomData, + } + } +} + +impl Clone for HandleErrorLayer +where + F: Clone, +{ + fn clone(&self) -> Self { + Self { + f: self.f.clone(), + _extractor: PhantomData, + } + } +} + +impl fmt::Debug for HandleErrorLayer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HandleErrorLayer") + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Layer for HandleErrorLayer +where + F: Clone, +{ + type Service = HandleError; + + fn layer(&self, inner: S) -> Self::Service { + HandleError::new(inner, self.f.clone()) + } +} + +/// A [`Service`] adapter that handles errors by converting them into responses. +/// +/// See [module docs](self) for more details on axum's error handling model. +pub struct HandleError { + inner: S, + f: F, + _extractor: PhantomData T>, +} + +impl HandleError { + /// Create a new `HandleError`. + pub fn new(inner: S, f: F) -> Self { + Self { + inner, + f, + _extractor: PhantomData, + } + } +} + +impl Clone for HandleError +where + S: Clone, + F: Clone, +{ + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + f: self.f.clone(), + _extractor: PhantomData, + } + } +} + +impl fmt::Debug for HandleError +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HandleError") + .field("inner", &self.inner) + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Service> for HandleError +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Error: Send, + S::Future: Send, + F: FnOnce(S::Error) -> Fut + Clone + Send + 'static, + Fut: Future + Send, + Res: IntoResponse, + ReqBody: Send + 'static, + ResBody: HttpBody + Send + 'static, + ResBody::Error: Into, +{ + type Response = Response; + type Error = Infallible; + type Future = ResponseFuture; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let f = self.f.clone(); + + let clone = self.inner.clone(); + let inner = std::mem::replace(&mut self.inner, clone); + + let future = Box::pin(async move { + match inner.oneshot(req).await { + Ok(res) => Ok(res.map(box_body)), + Err(err) => Ok(f(err).await.into_response().map(box_body)), + } + }); + + ResponseFuture { future } + } +} + +#[allow(unused_macros)] +macro_rules! impl_service { + ( $($ty:ident),* $(,)? ) => { + impl Service> + for HandleError + where + S: Service, Response = Response> + Clone + Send + 'static, + S::Error: Send, + S::Future: Send, + F: FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static, + Fut: Future + Send, + Res: IntoResponse, + $( $ty: FromRequest + Send,)* + ReqBody: Send + 'static, + ResBody: HttpBody + Send + 'static, + ResBody::Error: Into, + { + type Response = Response; + type Error = Infallible; + + type Future = ResponseFuture; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[allow(non_snake_case)] + fn call(&mut self, req: Request) -> Self::Future { + let f = self.f.clone(); + + let clone = self.inner.clone(); + let inner = std::mem::replace(&mut self.inner, clone); + + let future = Box::pin(async move { + let mut req = RequestParts::new(req); + + $( + let $ty = match $ty::from_request(&mut req).await { + Ok(value) => value, + Err(rejection) => return Ok(rejection.into_response().map(box_body)), + }; + )* + + let req = match req.try_into_request() { + Ok(req) => req, + Err(err) => { + return Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(box_body(Full::from(err.to_string()))) + .unwrap()); + } + }; + + match inner.oneshot(req).await { + Ok(res) => Ok(res.map(box_body)), + Err(err) => Ok(f($($ty),*, err).await.into_response().map(box_body)), + } + }); + + ResponseFuture { future } + } + } + } +} + +impl_service!(T1); +impl_service!(T1, T2); +impl_service!(T1, T2, T3); +impl_service!(T1, T2, T3, T4); +impl_service!(T1, T2, T3, T4, T5); +impl_service!(T1, T2, T3, T4, T5, T6); +impl_service!(T1, T2, T3, T4, T5, T6, T7); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15); +impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16); + +pin_project! { + /// Response future for [`HandleError`]. + pub struct ResponseFuture { + #[pin] + future: Pin, Infallible>> + Send + 'static>>, + } +} + +impl Future for ResponseFuture { + type Output = Result, Infallible>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().future.poll(cx) + } +}