Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SetStatus #248

Merged
merged 2 commits into from
Apr 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tower-http/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Added

- Add `SetStatus` to override status codes.
- **cors**: Added `CorsLayer::very_permissive` which is like
`CorsLayer::permissive` except it (truly) allows credentials. This is made
possible by mirroring the request's origin as well as method and headers
back as CORS-whitelisted ones
* **cors**: Allow customizing the value(s) for the `Vary` header
- **cors**: Allow customizing the value(s) for the `Vary` header

## Changed

Expand Down
2 changes: 2 additions & 0 deletions tower-http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ full = [
"request-id",
"sensitive-headers",
"set-header",
"set-status",
"trace",
"util",
]
Expand All @@ -87,6 +88,7 @@ redirect = []
request-id = []
sensitive-headers = []
set-header = []
set-status = []
trace = ["tracing"]
util = ["tower"]

Expand Down
4 changes: 4 additions & 0 deletions tower-http/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ pub mod request_id;
#[cfg_attr(docsrs, doc(cfg(feature = "catch-panic")))]
pub mod catch_panic;

#[cfg(feature = "set-status")]
#[cfg_attr(docsrs, doc(cfg(feature = "set-status")))]
pub mod set_status;

pub mod classify;
pub mod services;

Expand Down
141 changes: 141 additions & 0 deletions tower-http/src/set_status.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
//! Middleware to override status codes.
//!
//! # Example
//!
//! ```
//! use tower_http::set_status::SetStatusLayer;
//! use http::{Request, Response, StatusCode};
//! use hyper::Body;
//! use std::{iter::once, convert::Infallible};
//! use tower::{ServiceBuilder, Service, ServiceExt};
//!
//! async fn handle(req: Request<Body>) -> Result<Response<Body>, Infallible> {
//! // ...
//! # Ok(Response::new(Body::empty()))
//! }
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! async fn handle(req: Request<Body>) -> Result<Response<Body>, Infallible> {
//! // ...
//! # Ok(Response::new(Body::empty()))
//! }
//!
//! let mut service = ServiceBuilder::new()
//! // change the status to `404 Not Found` regardless what the inner service returns
//! .layer(SetStatusLayer::new(StatusCode::NOT_FOUND))
//! .service_fn(handle);
//!
//! // Call the service.
//! let request = Request::builder().body(Body::empty())?;
//!
//! let response = service.ready().await?.call(request).await?;
//!
//! assert_eq!(response.status(), StatusCode::NOT_FOUND);
//! #
//! # Ok(())
//! # }
//! ```

use http::{Request, Response, StatusCode};
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;

/// Layer that applies [`SetStatus`] which overrides the status codes.
#[derive(Debug, Clone, Copy)]
pub struct SetStatusLayer {
status: StatusCode,
}

impl SetStatusLayer {
/// Create a new [`SetStatusLayer`].
///
/// The response status code will be `status` regardless of what the inner service returns.
pub fn new(status: StatusCode) -> Self {
SetStatusLayer { status }
}
}

impl<S> Layer<S> for SetStatusLayer {
type Service = SetStatus<S>;

fn layer(&self, inner: S) -> Self::Service {
SetStatus::new(inner, self.status)
}
}

/// Middleware to override status codes.
///
/// See the [module docs](self) for more details.
#[derive(Debug, Clone, Copy)]
pub struct SetStatus<S> {
inner: S,
status: StatusCode,
}

impl<S> SetStatus<S> {
/// Create a new [`SetStatus`].
///
/// The response status code will be `status` regardless of what the inner service returns.
pub fn new(inner: S, status: StatusCode) -> Self {
Self { status, inner }
}

define_inner_service_accessors!();

/// Returns a new [`Layer`] that wraps services with a `SetStatus` middleware.
///
/// [`Layer`]: tower_layer::Layer
pub fn layer(status: StatusCode) -> SetStatusLayer {
SetStatusLayer::new(status)
}
}

impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SetStatus<S>
where
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = ResponseFuture<S::Future>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
ResponseFuture {
inner: self.inner.call(req),
status: Some(self.status),
}
}
}

pin_project! {
/// Response future for [`SetStatus`].
pub struct ResponseFuture<F> {
#[pin]
inner: F,
status: Option<StatusCode>,
}
}

impl<F, B, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<B>, E>>,
{
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut response = futures_core::ready!(this.inner.poll(cx)?);
*response.status_mut() = this.status.take().expect("future polled after completion");
Poll::Ready(Ok(response))
}
}