Skip to content

Commit

Permalink
Add SetStatus (#248)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn authored Apr 24, 2022
1 parent 9299ba5 commit fd13e2a
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tower-http/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Added

- Add `SetStatus` to override status codes.
- **cors**: Added `CorsLayer::very_permissive` which is like
`CorsLayer::permissive` except it (truly) allows credentials. This is made
possible by mirroring the request's origin as well as method and headers
back as CORS-whitelisted ones
* **cors**: Allow customizing the value(s) for the `Vary` header
- **cors**: Allow customizing the value(s) for the `Vary` header

## Changed

Expand Down
2 changes: 2 additions & 0 deletions tower-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ full = [
"request-id",
"sensitive-headers",
"set-header",
"set-status",
"trace",
"util",
]
Expand All @@ -87,6 +88,7 @@ redirect = []
request-id = []
sensitive-headers = []
set-header = []
set-status = []
trace = ["tracing"]
util = ["tower"]

Expand Down
4 changes: 4 additions & 0 deletions tower-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ pub mod request_id;
#[cfg_attr(docsrs, doc(cfg(feature = "catch-panic")))]
pub mod catch_panic;

#[cfg(feature = "set-status")]
#[cfg_attr(docsrs, doc(cfg(feature = "set-status")))]
pub mod set_status;

pub mod classify;
pub mod services;

Expand Down
141 changes: 141 additions & 0 deletions tower-http/src/set_status.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
//! Middleware to override status codes.
//!
//! # Example
//!
//! ```
//! use tower_http::set_status::SetStatusLayer;
//! use http::{Request, Response, StatusCode};
//! use hyper::Body;
//! use std::{iter::once, convert::Infallible};
//! use tower::{ServiceBuilder, Service, ServiceExt};
//!
//! async fn handle(req: Request<Body>) -> Result<Response<Body>, Infallible> {
//! // ...
//! # Ok(Response::new(Body::empty()))
//! }
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! async fn handle(req: Request<Body>) -> Result<Response<Body>, Infallible> {
//! // ...
//! # Ok(Response::new(Body::empty()))
//! }
//!
//! let mut service = ServiceBuilder::new()
//! // change the status to `404 Not Found` regardless what the inner service returns
//! .layer(SetStatusLayer::new(StatusCode::NOT_FOUND))
//! .service_fn(handle);
//!
//! // Call the service.
//! let request = Request::builder().body(Body::empty())?;
//!
//! let response = service.ready().await?.call(request).await?;
//!
//! assert_eq!(response.status(), StatusCode::NOT_FOUND);
//! #
//! # Ok(())
//! # }
//! ```
use http::{Request, Response, StatusCode};
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;

/// Layer that applies [`SetStatus`] which overrides the status codes.
#[derive(Debug, Clone, Copy)]
pub struct SetStatusLayer {
status: StatusCode,
}

impl SetStatusLayer {
/// Create a new [`SetStatusLayer`].
///
/// The response status code will be `status` regardless of what the inner service returns.
pub fn new(status: StatusCode) -> Self {
SetStatusLayer { status }
}
}

impl<S> Layer<S> for SetStatusLayer {
type Service = SetStatus<S>;

fn layer(&self, inner: S) -> Self::Service {
SetStatus::new(inner, self.status)
}
}

/// Middleware to override status codes.
///
/// See the [module docs](self) for more details.
#[derive(Debug, Clone, Copy)]
pub struct SetStatus<S> {
inner: S,
status: StatusCode,
}

impl<S> SetStatus<S> {
/// Create a new [`SetStatus`].
///
/// The response status code will be `status` regardless of what the inner service returns.
pub fn new(inner: S, status: StatusCode) -> Self {
Self { status, inner }
}

define_inner_service_accessors!();

/// Returns a new [`Layer`] that wraps services with a `SetStatus` middleware.
///
/// [`Layer`]: tower_layer::Layer
pub fn layer(status: StatusCode) -> SetStatusLayer {
SetStatusLayer::new(status)
}
}

impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SetStatus<S>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = ResponseFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
ResponseFuture {
inner: self.inner.call(req),
status: Some(self.status),
}
}
}

pin_project! {
/// Response future for [`SetStatus`].
pub struct ResponseFuture<F> {
#[pin]
inner: F,
status: Option<StatusCode>,
}
}

impl<F, B, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<B>, E>>,
{
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut response = futures_core::ready!(this.inner.poll(cx)?);
*response.status_mut() = this.status.take().expect("future polled after completion");
Poll::Ready(Ok(response))
}
}

0 comments on commit fd13e2a

Please sign in to comment.