diff --git a/client/http-client/src/client.rs b/client/http-client/src/client.rs index d4aa1b3e2f..e20e571fcd 100644 --- a/client/http-client/src/client.rs +++ b/client/http-client/src/client.rs @@ -359,7 +359,7 @@ where let mut failed_calls = 0; for _ in 0..json_rps.len() { - responses.push(Err(ErrorObject::borrowed(0, &"", None))); + responses.push(Err(ErrorObject::borrowed(0, "", None))); } for rp in json_rps { diff --git a/core/Cargo.toml b/core/Cargo.toml index d1ac9d0d4d..d1f7107598 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -34,15 +34,13 @@ parking_lot = { version = "0.12", optional = true } tokio = { version = "1.16", optional = true } wasm-bindgen-futures = { version = "0.4.19", optional = true } futures-timer = { version = "3", optional = true } -route-recognizer = { version = "0.3.1", optional = true } -http = { version = "0.2.9", optional = true } + [features] default = [] http-helpers = ["hyper", "futures-util"] server = [ "futures-util/alloc", - "route-recognizer", "rustc-hash/std", "parking_lot", "rand", @@ -50,7 +48,6 @@ server = [ "tokio/sync", "tokio/macros", "tokio/time", - "http", ] client = ["futures-util/sink", "tokio/sync"] async-client = [ diff --git a/core/src/client/async_client/helpers.rs b/core/src/client/async_client/helpers.rs index 0a0509c7d1..e4ededea66 100644 --- a/core/src/client/async_client/helpers.rs +++ b/core/src/client/async_client/helpers.rs @@ -69,7 +69,7 @@ pub(crate) fn process_batch_response( }; for _ in range { - let err_obj = ErrorObject::borrowed(0, &"", None); + let err_obj = ErrorObject::borrowed(0, "", None); responses.push(Err(err_obj)); } diff --git a/core/src/server/helpers.rs b/core/src/server/helpers.rs index 493874144d..15183b9f84 100644 --- a/core/src/server/helpers.rs +++ b/core/src/server/helpers.rs @@ -256,7 +256,7 @@ impl MethodResponse { let err = ResponsePayload::error_borrowed(ErrorObject::borrowed( err_code, - &OVERSIZED_RESPONSE_MSG, + OVERSIZED_RESPONSE_MSG, data.as_deref(), )); let result = diff --git a/core/src/server/mod.rs b/core/src/server/mod.rs index 61cb62698f..9a05d29cc8 100644 --- a/core/src/server/mod.rs +++ b/core/src/server/mod.rs @@ -30,8 +30,6 @@ mod error; /// Helpers. pub mod helpers; -/// Host filtering. -mod host_filtering; /// JSON-RPC "modules" group sets of methods that belong together and handles method/subscription registration. mod rpc_module; /// Subscription related types. @@ -39,7 +37,6 @@ mod subscription; pub use error::*; pub use helpers::{BatchResponseBuilder, BoundedWriter, MethodResponse, MethodSink}; -pub use host_filtering::*; pub use rpc_module::*; pub use subscription::*; diff --git a/examples/examples/cors_server.rs b/examples/examples/cors_server.rs index 04a91432bf..809b07b281 100644 --- a/examples/examples/cors_server.rs +++ b/examples/examples/cors_server.rs @@ -85,11 +85,7 @@ async fn run_server() -> anyhow::Result { // modifying requests / responses. These features are independent of one another // and can also be used separately. // In this example, we use both features. - let server = Server::builder() - .disable_host_filtering() - .set_middleware(middleware) - .build("127.0.0.1:0".parse::()?) - .await?; + let server = Server::builder().set_middleware(middleware).build("127.0.0.1:0".parse::()?).await?; let mut module = RpcModule::new(()); module.register_method("say_hello", |_, _| { diff --git a/examples/examples/host_filter_middleware.rs b/examples/examples/host_filter_middleware.rs new file mode 100644 index 0000000000..39ca7c84fc --- /dev/null +++ b/examples/examples/host_filter_middleware.rs @@ -0,0 +1,82 @@ +// Copyright 2019-2022 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! This example shows how to configure `host filtering` by tower middleware on the jsonrpsee server. +//! +//! The server whitelist's only `example.com` and any call from localhost will be +//! rejected both by HTTP and WebSocket transports. + +use std::net::SocketAddr; + +use jsonrpsee::core::client::ClientT; +use jsonrpsee::http_client::HttpClientBuilder; +use jsonrpsee::rpc_params; +use jsonrpsee::server::middleware::HostFilterLayer; +use jsonrpsee::server::{RpcModule, Server}; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init() + .expect("setting default subscriber failed"); + + let addr = run_server().await?; + let url = format!("http://{}", addr); + + // Use RPC client to get the response of `say_hello` method. + let client = HttpClientBuilder::default().build(&url)?; + // This call will be denied because only `example.com` URIs/hosts are allowed by the host filter. + let response = client.request::("say_hello", rpc_params![]).await.unwrap_err(); + println!("[main]: response: {}", response); + + Ok(()) +} + +async fn run_server() -> anyhow::Result { + // Custom tower service to handle the RPC requests + let service_builder = tower::ServiceBuilder::new() + // For this example we only want to permit requests from `example.com` + // all other request are denied. + // + // `HostFilerLayer::new` only fails on invalid URIs.. + .layer(HostFilterLayer::new(["example.com"]).unwrap()); + + let server = Server::builder().set_middleware(service_builder).build("127.0.0.1:0".parse::()?).await?; + + let addr = server.local_addr()?; + + let mut module = RpcModule::new(()); + module.register_method("say_hello", |_, _| "lo").unwrap(); + + let handle = server.start(module); + + // In this example we don't care about doing shutdown so let's it run forever. + // You may use the `ServerHandle` to shut it down or manage it yourself. + tokio::spawn(handle.stopped()); + + Ok(addr) +} diff --git a/examples/examples/http_proxy_middleware.rs b/examples/examples/http_proxy_middleware.rs index 74b3eeb8a6..61686fd78e 100644 --- a/examples/examples/http_proxy_middleware.rs +++ b/examples/examples/http_proxy_middleware.rs @@ -44,7 +44,7 @@ use std::time::Duration; use jsonrpsee::core::client::ClientT; use jsonrpsee::http_client::HttpClientBuilder; use jsonrpsee::rpc_params; -use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer; +use jsonrpsee::server::middleware::ProxyGetRequestLayer; use jsonrpsee::server::{RpcModule, Server}; #[tokio::main] diff --git a/server/Cargo.toml b/server/Cargo.toml index c6209292c7..cc8c31c01d 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -26,6 +26,9 @@ tokio-util = { version = "0.7", features = ["compat"] } tokio-stream = "0.1.7" hyper = { version = "0.14", features = ["server", "http1", "http2"] } tower = "0.4.13" +route-recognizer = "0.3.1" +http = "0.2.9" +thiserror = "1.0.44" [dev-dependencies] anyhow = "1" diff --git a/core/src/server/host_filtering.rs b/server/src/middleware/authority.rs similarity index 62% rename from core/src/server/host_filtering.rs rename to server/src/middleware/authority.rs index 0270196c14..f10c3260da 100644 --- a/core/src/server/host_filtering.rs +++ b/server/src/middleware/authority.rs @@ -24,13 +24,36 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -//! HTTP Host Header validation. +//! Utility and types related to the authority of an URI. -use std::net::SocketAddr; - -use crate::Error; use http::uri::{InvalidUri, Uri}; -use route_recognizer::Router; +use hyper::{Body, Request}; +use jsonrpsee_core::http_helpers; + +/// Represent the http URI scheme that is returned by the HTTP host header +/// +/// Further information can be found: +#[derive(Clone, Hash, PartialEq, Eq, Debug)] +pub struct Authority { + /// The host. + pub host: String, + /// The port. + pub port: Port, +} + +/// Error that can happen when parsing an URI authority fails. +#[derive(Debug, thiserror::Error)] +pub enum AuthorityError { + /// Invalid URI. + #[error("{0}")] + InvalidUri(InvalidUri), + /// Invalid port. + #[error("{0}")] + InvalidPort(String), + /// The host was not found. + #[error("The host was not found")] + MissingHost, +} /// Port pattern #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)] @@ -43,27 +66,12 @@ pub enum Port { Fixed(u16), } -impl From for Port { - fn from(port: u16) -> Port { - Port::Fixed(port) - } -} - -/// Represent the http URI scheme that is returned by the HTTP host header -/// -/// Further information can be found: -#[derive(Clone, Hash, PartialEq, Eq, Debug)] -pub struct Authority { - hostname: String, - port: Port, -} - impl Authority { fn inner_from_str(value: &str) -> Result { let uri: Uri = value.parse().map_err(AuthorityError::InvalidUri)?; let authority = uri.authority().ok_or(AuthorityError::MissingHost)?; - let hostname = authority.host(); - let maybe_port = &authority.as_str()[hostname.len()..]; + let host = authority.host(); + let maybe_port = &authority.as_str()[host.len()..]; // After the host segment, the authority may contain a port such as `fooo:33`, `foo:*` or `foo` let port = match maybe_port.split_once(':') { @@ -81,24 +89,35 @@ impl Authority { None => Port::Default, }; - Ok(Self { hostname: hostname.to_string(), port }) + Ok(Self { host: host.to_owned(), port }) + } + + /// Attempts to parse the authority from a HTTP request. + /// + /// The `Authority` can be sent by the client in the `Host header` or in the `URI` + /// such that both must be checked. + pub fn from_http_request(request: &Request) -> Option { + // NOTE: we use our own `Authority type` here because an invalid port number would return `None` here + // and that should be denied. + let host_header = + http_helpers::read_header_value(request.headers(), hyper::header::HOST).map(Authority::try_from); + let uri = request.uri().authority().map(|v| Authority::try_from(v.as_str())); + + match (host_header, uri) { + (Some(Ok(a1)), Some(Ok(a2))) => { + if a1 == a2 { + Some(a1) + } else { + None + } + } + (Some(Ok(a)), _) => Some(a), + (_, Some(Ok(a))) => Some(a), + _ => None, + } } } -/// Error that can happen when parsing an URI authority fails. -#[derive(Debug, thiserror::Error)] -pub enum AuthorityError { - /// Invalid URI. - #[error("{0}")] - InvalidUri(InvalidUri), - /// Invalid port. - #[error("{0}")] - InvalidPort(String), - /// The host was not found. - #[error("The host was not found")] - MissingHost, -} - impl<'a> TryFrom<&'a str> for Authority { type Error = AuthorityError; @@ -118,69 +137,14 @@ impl TryFrom for Authority { impl TryFrom for Authority { type Error = AuthorityError; - fn try_from(sockaddr: SocketAddr) -> Result { + fn try_from(sockaddr: std::net::SocketAddr) -> Result { Self::inner_from_str(&sockaddr.to_string()) } } -/// Represent the URL patterns that is whitelisted. -#[derive(Default, Debug, Clone)] -pub struct WhitelistedHosts(Router); - -impl From for WhitelistedHosts -where - T: IntoIterator, -{ - fn from(value: T) -> Self { - let mut router = Router::new(); - - for auth in value.into_iter() { - router.add(&auth.hostname, auth.port); - } - - Self(router) - } -} - -impl WhitelistedHosts { - fn recognize(&self, other: &Authority) -> bool { - if let Ok(p) = self.0.recognize(&other.hostname) { - let p = p.handler(); - - match (p, &other.port) { - (Port::Any, _) => true, - (Port::Default, Port::Default) => true, - (Port::Fixed(p1), Port::Fixed(p2)) if p1 == p2 => true, - _ => false, - } - } else { - false - } - } -} - -/// Policy for validating the `HTTP host header`. -#[derive(Debug, Clone)] -pub enum AllowHosts { - /// Allow all hosts (no filter). - Any, - /// Allow only specified hosts. - Only(WhitelistedHosts), -} - -impl AllowHosts { - /// Verify a host. - pub fn verify(&self, value: &str) -> Result<(), Error> { - let auth = Authority::try_from(value) - .map_err(|_| Error::HttpHeaderRejected("host", format!("Invalid authority: {value}")))?; - - if let AllowHosts::Only(url_pat) = self { - if !url_pat.recognize(&auth) { - return Err(Error::HttpHeaderRejected("host", value.into())); - } - } - - Ok(()) +impl From for Port { + fn from(port: u16) -> Port { + Port::Fixed(port) } } @@ -195,10 +159,12 @@ fn default_port(scheme: Option<&str>) -> Option { #[cfg(test)] mod tests { - use super::{AllowHosts, Authority, Port}; + use super::{Authority, Port}; + use hyper::header::HOST; + use hyper::Body; fn authority(host: &str, port: Port) -> Authority { - Authority { hostname: host.to_owned(), port } + Authority { host: host.to_owned(), port } } #[test] @@ -229,51 +195,41 @@ mod tests { assert!(Authority::try_from("user:password").is_err()); assert!(Authority::try_from("parity.io/somepath").is_err()); assert!(Authority::try_from("127.0.0.1:8545/somepath").is_err()); + assert!(Authority::try_from("127.0.0.1:-1337").is_err()); } #[test] - fn should_allow_when_validation_is_disabled() { - assert!((AllowHosts::Any).verify("any").is_ok()); + fn authority_from_http_only_host_works() { + let req = hyper::Request::builder().header(HOST, "example.com").body(Body::empty()).unwrap(); + assert!(Authority::from_http_request(&req).is_some()); } #[test] - fn should_reject_if_header_not_on_the_list() { - assert!((AllowHosts::Only(vec![].into())).verify("parity.io").is_err()); + fn authority_only_uri_works() { + let req = hyper::Request::builder().uri("example.com").body(Body::empty()).unwrap(); + assert!(Authority::from_http_request(&req).is_some()); } #[test] - fn should_accept_if_on_the_list() { - assert!(AllowHosts::Only(vec![Authority::try_from("parity.io").unwrap()].into()).verify("parity.io").is_ok()); + fn authority_host_and_uri_works() { + let req = hyper::Request::builder() + .header(HOST, "example.com:9999") + .uri("example.com:9999") + .body(Body::empty()) + .unwrap(); + assert!(Authority::from_http_request(&req).is_some()); } #[test] - fn should_accept_if_on_the_list_with_port() { - assert!((AllowHosts::Only(vec![Authority::try_from("parity.io:443").unwrap()].into())) - .verify("parity.io:443") - .is_ok()); - assert!(AllowHosts::Only(vec![Authority::try_from("parity.io").unwrap()].into()) - .verify("parity.io:443") - .is_err()); + fn authority_host_and_uri_mismatch() { + let req = + hyper::Request::builder().header(HOST, "example.com:9999").uri("example.com").body(Body::empty()).unwrap(); + assert!(Authority::from_http_request(&req).is_none()); } #[test] - fn should_support_wildcards() { - assert!((AllowHosts::Only(vec![Authority::try_from("*.web3.site:*").unwrap()].into())) - .verify("parity.web3.site:8180") - .is_ok()); - assert!((AllowHosts::Only(vec![Authority::try_from("*.web3.site:*").unwrap()].into())) - .verify("parity.web3.site") - .is_ok()); - } - - #[test] - fn should_accept_with_and_without_default_port() { - assert!(AllowHosts::Only(vec![Authority::try_from("https://parity.io:443").unwrap()].into()) - .verify("https://parity.io") - .is_ok()); - - assert!(AllowHosts::Only(vec![Authority::try_from("https://parity.io").unwrap()].into()) - .verify("https://parity.io:443") - .is_ok()); + fn authority_missing_host_and_uri() { + let req = hyper::Request::builder().body(Body::empty()).unwrap(); + assert!(Authority::from_http_request(&req).is_none()); } } diff --git a/server/src/middleware/host_filter.rs b/server/src/middleware/host_filter.rs new file mode 100644 index 0000000000..417582778a --- /dev/null +++ b/server/src/middleware/host_filter.rs @@ -0,0 +1,181 @@ +// Copyright 2019-2023 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//! HTTP host validation middleware. + +use crate::middleware::authority::{Authority, AuthorityError, Port}; +use crate::transport::http; +use futures_util::{Future, FutureExt, TryFutureExt}; +use hyper::{Body, Request, Response}; +use route_recognizer::Router; +use std::error::Error as StdError; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tower::{Layer, Service}; + +/// Middleware to enable host filtering. +#[derive(Debug)] +pub struct HostFilterLayer(Arc); + +impl HostFilterLayer { + /// Enables host filtering and allow only the specified hosts. + pub fn new, U: TryInto>(allow_only: T) -> Result + where + T: IntoIterator, + U: TryInto, + { + let allow_only: Result, _> = allow_only.into_iter().map(|a| a.try_into()).collect(); + Ok(Self(Arc::new(WhitelistedHosts::from(allow_only?)))) + } +} + +impl Layer for HostFilterLayer { + type Service = HostFilter; + + fn layer(&self, inner: S) -> Self::Service { + HostFilter { inner, filter: self.0.clone() } + } +} + +/// Middleware to enable host filtering. +#[derive(Debug)] +pub struct HostFilter { + inner: S, + filter: Arc, +} + +impl Service> for HostFilter +where + S: Service, Response = Response>, + S::Response: 'static, + S::Error: Into> + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = Box; + type Future = Pin> + Send + 'static>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, request: Request) -> Self::Future { + let Some(authority) = Authority::from_http_request(&request) else { + return async { Ok(http::response::malformed()) }.boxed(); + }; + + if self.filter.recognize(&authority) { + Box::pin(self.inner.call(request).map_err(Into::into)) + } else { + tracing::debug!("Denied request: {:?}", request); + async { Ok(http::response::host_not_allowed()) }.boxed() + } + } +} + +/// Represent the URL patterns that is whitelisted. +#[derive(Default, Debug, Clone)] +pub struct WhitelistedHosts(Router); + +impl From for WhitelistedHosts +where + T: IntoIterator, +{ + fn from(value: T) -> Self { + let mut router = Router::new(); + + for auth in value.into_iter() { + router.add(&auth.host, auth.port); + } + + Self(router) + } +} + +impl WhitelistedHosts { + fn recognize(&self, other: &Authority) -> bool { + if let Ok(p) = self.0.recognize(&other.host) { + let p = p.handler(); + + match (p, &other.port) { + (Port::Any, _) => true, + (Port::Default, Port::Default) => true, + (Port::Fixed(p1), Port::Fixed(p2)) if p1 == p2 => true, + _ => false, + } + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use super::{Authority, WhitelistedHosts}; + + fn unwrap_auth(a: &str) -> Authority { + a.try_into().unwrap() + } + + fn unwrap_filter(list: &[&str]) -> WhitelistedHosts { + let l: Vec<_> = list.into_iter().map(|&a| a.try_into().unwrap()).collect(); + WhitelistedHosts::from(l) + } + + #[test] + fn should_reject_if_header_not_on_the_list() { + let filter = unwrap_filter(&[]); + assert!(!filter.recognize(&unwrap_auth("parity.io"))); + } + + #[test] + fn should_accept_if_on_the_list() { + let filter = unwrap_filter(&["parity.io"]); + assert!(filter.recognize(&unwrap_auth("parity.io"))); + } + + #[test] + fn should_accept_if_on_the_list_with_port() { + let filter = unwrap_filter(&["parity.io:443"]); + assert!(filter.recognize(&unwrap_auth("parity.io:443"))); + assert!(!filter.recognize(&unwrap_auth("parity.io"))); + } + + #[test] + fn should_support_wildcards() { + let filter = unwrap_filter(&["*.web3.site:*"]); + assert!(filter.recognize(&unwrap_auth("parity.web3.site:8180"))); + assert!(filter.recognize(&unwrap_auth("parity.web3.site"))); + } + + #[test] + fn should_accept_with_and_without_default_port() { + let filter = unwrap_filter(&["https://parity.io:443"]); + assert!(filter.recognize(&unwrap_auth("https://parity.io"))); + assert!(filter.recognize(&unwrap_auth("https://parity.io:443"))); + } +} diff --git a/server/src/middleware/mod.rs b/server/src/middleware/mod.rs index d1a829e423..868aea9cba 100644 --- a/server/src/middleware/mod.rs +++ b/server/src/middleware/mod.rs @@ -1,4 +1,38 @@ +// Copyright 2019-2021 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + //! Various middleware implementations for RPC specific purposes. +/// Utility and types related to the authority of an URI. +mod authority; +/// HTTP Host filtering middleware. +mod host_filter; /// Proxy `GET /path` to internal RPC methods. -pub mod proxy_get_request; +mod proxy_get_request; + +pub use authority::*; +pub use host_filter::*; +pub use proxy_get_request::*; diff --git a/server/src/middleware/proxy_get_request.rs b/server/src/middleware/proxy_get_request.rs index 1a46e7e76b..40acd11f5d 100644 --- a/server/src/middleware/proxy_get_request.rs +++ b/server/src/middleware/proxy_get_request.rs @@ -1,3 +1,29 @@ +// Copyright 2019-2021 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any +// person obtaining a copy of this software and associated +// documentation files (the "Software"), to deal in the +// Software without restriction, including without +// limitation the rights to use, copy, modify, merge, +// publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software +// is furnished to do so, subject to the following +// conditions: +// +// The above copyright notice and this permission notice +// shall be included in all copies or substantial portions +// of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + //! Middleware that proxies requests at a specified URI to internal //! RPC method calls. diff --git a/server/src/server.rs b/server/src/server.rs index c844bbba54..b9b1bafafb 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -34,7 +34,6 @@ use std::time::Duration; use crate::future::{ConnectionGuard, ServerHandle, StopHandle}; use crate::logger::{Logger, TransportProtocol}; -use crate::transport::http::fetch_authority; use crate::transport::{http, ws}; use futures_util::future::{self, Either, FutureExt}; @@ -43,7 +42,7 @@ use futures_util::io::{BufReader, BufWriter}; use hyper::body::HttpBody; use jsonrpsee_core::id_providers::RandomIntegerIdProvider; -use jsonrpsee_core::server::{AllowHosts, Authority, AuthorityError, Methods, WhitelistedHosts}; +use jsonrpsee_core::server::Methods; use jsonrpsee_core::traits::IdProvider; use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES}; @@ -128,7 +127,6 @@ where let max_response_body_size = self.cfg.max_response_body_size; let max_log_length = self.cfg.max_log_length; let max_subscriptions_per_connection = self.cfg.max_subscriptions_per_connection; - let allow_hosts = self.cfg.allow_hosts; let logger = self.logger; let batch_requests_config = self.cfg.batch_requests_config; let id_provider = self.id_provider; @@ -148,7 +146,6 @@ where let data = ProcessConnection { remote_addr, methods: methods.clone(), - allow_hosts: allow_hosts.clone(), max_request_body_size, max_response_body_size, max_log_length, @@ -209,8 +206,6 @@ struct Settings { max_log_length: u32, /// Maximum number of subscriptions per connection. max_subscriptions_per_connection: u32, - /// Host filtering. - allow_hosts: AllowHosts, /// Whether batch requests are supported by this server or not. batch_requests_config: BatchRequestConfig, /// Custom tokio runtime to run the server on. @@ -245,7 +240,6 @@ impl Default for Settings { max_connections: MAX_CONNECTIONS, max_subscriptions_per_connection: 1024, batch_requests_config: BatchRequestConfig::Unlimited, - allow_hosts: AllowHosts::Any, tokio_runtime: None, ping_interval: Duration::from_secs(60), enable_http: true, @@ -420,30 +414,6 @@ impl Builder { self } - /// Enables host filtering and allow only the specified hosts. - /// - /// Default: no host filtering is enabled. - pub fn host_filter, U: TryInto>( - mut self, - allow_only: T, - ) -> Result - where - T: IntoIterator, - U: TryInto, - { - let allow_only: Result, _> = allow_only.into_iter().map(|a| a.try_into()).collect(); - self.settings.allow_hosts = AllowHosts::Only(WhitelistedHosts::from(allow_only?)); - Ok(self) - } - - /// Disable host filtering and allow all. - /// - /// Default: no host filtering is enabled. - pub fn disable_host_filtering(mut self) -> Self { - self.settings.allow_hosts = AllowHosts::Any; - self - } - /// Configure a custom [`tower::ServiceBuilder`] middleware for composing layers to be applied to the RPC service. /// /// Default: No tower layers are applied to the RPC service. @@ -592,8 +562,6 @@ pub(crate) struct ServiceData { pub(crate) remote_addr: SocketAddr, /// Registered server methods. pub(crate) methods: Methods, - /// Access control. - pub(crate) allow_hosts: AllowHosts, /// Max request body size. pub(crate) max_request_body_size: u32, /// Max response body size. @@ -652,15 +620,6 @@ impl hyper::service::Service> for TowerSe fn call(&mut self, request: hyper::Request) -> Self::Future { tracing::trace!("{:?}", request); - let Some(authority) = fetch_authority(&request) else { - return async { Ok(http::response::malformed()) }.boxed(); - }; - - if let Err(e) = self.inner.allow_hosts.verify(authority) { - tracing::debug!("Denied request: {}", e); - return async { Ok(http::response::host_not_allowed()) }.boxed(); - } - let is_upgrade_request = is_upgrade_request(&request); if self.inner.enable_ws && is_upgrade_request { @@ -727,8 +686,6 @@ struct ProcessConnection { remote_addr: SocketAddr, /// Registered server methods. methods: Methods, - /// Access control. - allow_hosts: AllowHosts, /// Max request body size. max_request_body_size: u32, /// Max response body size. @@ -806,7 +763,6 @@ fn process_connection<'a, L: Logger, B, U>( inner: ServiceData { remote_addr: cfg.remote_addr, methods: cfg.methods, - allow_hosts: cfg.allow_hosts, max_request_body_size: cfg.max_request_body_size, max_response_body_size: cfg.max_response_body_size, max_log_length: cfg.max_log_length, diff --git a/server/src/transport/http.rs b/server/src/transport/http.rs index b9bb87c75e..be41b683b8 100644 --- a/server/src/transport/http.rs +++ b/server/src/transport/http.rs @@ -9,7 +9,7 @@ use futures_util::future::Either; use futures_util::stream::{FuturesOrdered, StreamExt}; use hyper::Method; use jsonrpsee_core::error::GenericTransportError; -use jsonrpsee_core::http_helpers::{self, read_body}; +use jsonrpsee_core::http_helpers::read_body; use jsonrpsee_core::server::helpers::{ batch_response_error, prepare_error, BatchResponseBuilder, MethodResponse, MethodResponseResult, }; @@ -50,20 +50,6 @@ pub(crate) async fn reject_connection(socket: tokio::net::TcpStream) { } } -/// The `Authority` can be sent by the client in the `Host header` or in the `URI` -/// such that we must check both. -pub(crate) fn fetch_authority(request: &hyper::Request) -> Option<&str> { - let host_header = http_helpers::read_header_value(request.headers(), hyper::header::HOST); - let uri = request.uri().authority(); - - match (host_header, uri) { - (Some(a1), Some(a2)) if a1 == a2.as_str() => Some(a1), - (Some(a), None) => Some(a), - (None, Some(a)) => Some(a.as_str()), - _ => None, - } -} - #[derive(Debug)] pub(crate) struct ProcessValidatedRequest<'a, L: Logger> { pub(crate) request: hyper::Request, @@ -117,9 +103,9 @@ pub(crate) async fn process_validated_request( BatchRequestConfig::Disabled => { let response = MethodResponse::error( Id::Null, - ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, &BATCHES_NOT_SUPPORTED_MSG, None), + ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, None), ); - logger.on_response(&response.result, request_start, TransportProtocol::WebSocket); + logger.on_response(&response.result, request_start, TransportProtocol::Http); return response::ok_response(response.result); } BatchRequestConfig::Limit(limit) => limit as usize, diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 742c92a809..ffd8ed7448 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -514,7 +514,7 @@ async fn execute_unchecked_call(params: ExecuteCallParams) { BatchRequestConfig::Disabled => { let response = MethodResponse::error( Id::Null, - ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, &BATCHES_NOT_SUPPORTED_MSG, None), + ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, None), ); logger.on_response(&response.result, request_start, TransportProtocol::WebSocket); _ = sink.send(response.result).await; diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 26f9e48cf0..269e71f0b8 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -31,9 +31,9 @@ use std::time::Duration; use futures::{SinkExt, Stream, StreamExt}; use jsonrpsee::core::Error; -use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer; +use jsonrpsee::server::middleware::ProxyGetRequestLayer; use jsonrpsee::server::{ - PendingSubscriptionSink, RpcModule, ServerBuilder, ServerHandle, SubscriptionMessage, TrySendError, + PendingSubscriptionSink, RpcModule, Server, ServerBuilder, ServerHandle, SubscriptionMessage, TrySendError, }; use jsonrpsee::types::{ErrorObject, ErrorObjectOwned}; use jsonrpsee::SubscriptionCloseResponse; @@ -195,26 +195,17 @@ pub async fn server_with_sleeping_subscription(tx: futures::channel::mpsc::Sende #[allow(dead_code)] pub async fn server_with_health_api() -> (SocketAddr, ServerHandle) { - server_with_access_control(None, CorsLayer::new()).await + server_with_cors(CorsLayer::new()).await } -pub async fn server_with_access_control( - allowed_hosts: Option>, - cors: CorsLayer, -) -> (SocketAddr, ServerHandle) { +pub async fn server_with_cors(cors: CorsLayer) -> (SocketAddr, ServerHandle) { let middleware = tower::ServiceBuilder::new() // Proxy `GET /health` requests to internal `system_health` method. .layer(ProxyGetRequestLayer::new("/health", "system_health").unwrap()) // Add `CORS` layer. .layer(cors); - let mut builder = jsonrpsee::server::Server::builder(); - - if let Some(filter) = allowed_hosts { - builder = builder.host_filter(filter).unwrap(); - } - - let server = builder.set_middleware(middleware).build("127.0.0.1:0").await.unwrap(); + let server = Server::builder().set_middleware(middleware).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| "hello").unwrap(); diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 81459e9e26..80235f968c 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -36,8 +36,8 @@ use std::time::Duration; use futures::stream::FuturesUnordered; use futures::{channel::mpsc, StreamExt, TryStreamExt}; use helpers::{ - init_logger, pipe_from_stream_and_drop, server, server_with_access_control, server_with_health_api, - server_with_subscription, server_with_subscription_and_handle, + init_logger, pipe_from_stream_and_drop, server, server_with_cors, server_with_health_api, server_with_subscription, + server_with_subscription_and_handle, }; use hyper::http::HeaderValue; use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT}; @@ -45,6 +45,7 @@ use jsonrpsee::core::params::{ArrayParams, BatchRequestBuilder}; use jsonrpsee::core::server::SubscriptionMessage; use jsonrpsee::core::{Error, JsonValue}; use jsonrpsee::http_client::HttpClientBuilder; +use jsonrpsee::server::middleware::HostFilterLayer; use jsonrpsee::server::{ServerBuilder, ServerHandle}; use jsonrpsee::types::error::{ErrorObject, UNKNOWN_ERROR_CODE}; use jsonrpsee::ws_client::WsClientBuilder; @@ -905,7 +906,7 @@ async fn http_cors_preflight_works() { .allow_methods([Method::POST]) .allow_origin("https://foo.com".parse::().unwrap()) .allow_headers([hyper::header::CONTENT_TYPE]); - let (server_addr, _handle) = server_with_access_control(None, cors).await; + let (server_addr, _handle) = server_with_cors(cors).await; let http_client = Client::new(); let uri = format!("http://{}", server_addr); @@ -1007,12 +1008,10 @@ async fn ws_host_filtering_wildcard_works() { init_logger(); - let server = ServerBuilder::default() - .host_filter(["http://localhost:*".to_string(), "http://127.0.0.1:*".to_string()]) - .unwrap() - .build("127.0.0.1:0") - .await - .unwrap(); + let middleware = + tower::ServiceBuilder::new().layer(HostFilterLayer::new(["http://localhost:*", "http://127.0.0.1:*"]).unwrap()); + + let server = ServerBuilder::default().set_middleware(middleware).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| "hello").unwrap(); @@ -1031,12 +1030,10 @@ async fn http_host_filtering_wildcard_works() { init_logger(); - let server = ServerBuilder::default() - .host_filter(vec!["http://localhost:*", "http://127.0.0.1:*"]) - .unwrap() - .build("127.0.0.1:0") - .await - .unwrap(); + let middleware = + tower::ServiceBuilder::new().layer(HostFilterLayer::new(["http://localhost:*", "http://127.0.0.1:*"]).unwrap()); + + let server = ServerBuilder::default().set_middleware(middleware).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| "hello").unwrap(); @@ -1055,8 +1052,9 @@ async fn deny_invalid_host() { init_logger(); - let server = - ServerBuilder::default().host_filter(["http://example.com"]).unwrap().build("127.0.0.1:0").await.unwrap(); + let middleware = tower::ServiceBuilder::new().layer(HostFilterLayer::new(["example.com"]).unwrap()); + + let server = Server::builder().set_middleware(middleware).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| "hello").unwrap();