From 7a9ebc918f2e3e8ecaf22eb2bd8bfdcf7f13a11e Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Mon, 11 Oct 2021 14:02:32 +0200 Subject: [PATCH] [http server]: use similar API for host and origin filtering as `WS` (#473) * use similar API for HTTP ACL builder * revert unintentional change * fix nits * Update http-server/src/access_control/mod.rs Co-authored-by: David * grumbles Co-authored-by: David --- http-server/src/access_control/mod.rs | 128 ++++++++++++++++---------- 1 file changed, 79 insertions(+), 49 deletions(-) diff --git a/http-server/src/access_control/mod.rs b/http-server/src/access_control/mod.rs index 7a9a7277a8..b749bfff4e 100644 --- a/http-server/src/access_control/mod.rs +++ b/http-server/src/access_control/mod.rs @@ -30,25 +30,26 @@ pub(crate) mod cors; pub(crate) mod hosts; mod matcher; -use hosts::{AllowHosts, Host}; +use crate::types::Error; use cors::{AccessControlAllowHeaders, AccessControlAllowOrigin}; +use hosts::{AllowHosts, Host}; use hyper::header; use jsonrpsee_utils::http_helpers; /// Define access on control on HTTP layer. #[derive(Clone, Debug)] pub struct AccessControl { - allow_hosts: AllowHosts, - cors_allow_origin: Option>, - cors_allow_headers: AccessControlAllowHeaders, + allowed_hosts: AllowHosts, + allowed_origins: Option>, + allowed_headers: AccessControlAllowHeaders, continue_on_invalid_cors: bool, } impl AccessControl { /// Validate incoming request by http HOST pub fn deny_host(&self, request: &hyper::Request) -> bool { - !hosts::is_host_valid(http_helpers::read_header_value(request.headers(), "host"), &self.allow_hosts) + !hosts::is_host_valid(http_helpers::read_header_value(request.headers(), "host"), &self.allowed_hosts) } /// Validate incoming request by CORS origin @@ -56,7 +57,7 @@ impl AccessControl { let header = cors::get_cors_allow_origin( http_helpers::read_header_value(request.headers(), "origin"), http_helpers::read_header_value(request.headers(), "host"), - &self.cors_allow_origin, + &self.allowed_origins, ) .map(|origin| { use self::cors::AccessControlAllowOrigin::*; @@ -79,7 +80,7 @@ impl AccessControl { .flat_map(|val| val.split(", ")) .flat_map(|val| val.split(',')); - let header = cors::get_cors_allow_headers(headers, requested_headers, &self.cors_allow_headers, |name| { + let header = cors::get_cors_allow_headers(headers, requested_headers, &self.allowed_headers, |name| { header::HeaderValue::from_str(name).unwrap_or_else(|_| header::HeaderValue::from_static("unknown")) }); header == cors::AllowCors::Invalid && !self.continue_on_invalid_cors @@ -89,9 +90,9 @@ impl AccessControl { impl Default for AccessControl { fn default() -> Self { Self { - allow_hosts: AllowHosts::Any, - cors_allow_origin: None, - cors_allow_headers: AccessControlAllowHeaders::Any, + allowed_hosts: AllowHosts::Any, + allowed_origins: None, + allowed_headers: AccessControlAllowHeaders::Any, continue_on_invalid_cors: false, } } @@ -100,18 +101,18 @@ impl Default for AccessControl { /// Convenience builder pattern #[derive(Debug)] pub struct AccessControlBuilder { - allow_hosts: AllowHosts, - cors_allow_origin: Option>, - cors_allow_headers: AccessControlAllowHeaders, + allowed_hosts: AllowHosts, + allowed_origins: Option>, + allowed_headers: AccessControlAllowHeaders, continue_on_invalid_cors: bool, } impl Default for AccessControlBuilder { fn default() -> Self { Self { - allow_hosts: AllowHosts::Any, - cors_allow_origin: None, - cors_allow_headers: AccessControlAllowHeaders::Any, + allowed_hosts: AllowHosts::Any, + allowed_origins: None, + allowed_headers: AccessControlAllowHeaders::Any, continue_on_invalid_cors: false, } } @@ -123,46 +124,75 @@ impl AccessControlBuilder { Self::default() } - /// Configure allow host. - pub fn allow_host(mut self, host: Host) -> Self { - let allow_hosts = match self.allow_hosts { - AllowHosts::Any => vec![host], - AllowHosts::Only(mut allow_hosts) => { - allow_hosts.push(host); - allow_hosts - } - }; - self.allow_hosts = AllowHosts::Only(allow_hosts); + /// Allow all hosts. + pub fn allow_all_hosts(mut self) -> Self { + self.allowed_hosts = AllowHosts::Any; self } - /// Configure CORS origin. - pub fn cors_allow_origin(mut self, allow_origin: AccessControlAllowOrigin) -> Self { - let cors_allow_origin = match self.cors_allow_origin { - Some(mut cors_allow_origin) => { - cors_allow_origin.push(allow_origin); - cors_allow_origin - } - None => vec![allow_origin], - }; - self.cors_allow_origin = Some(cors_allow_origin); + /// Allow all origins. + pub fn allow_all_origins(mut self) -> Self { + self.allowed_headers = AccessControlAllowHeaders::Any; self } - /// Configure which CORS header that is allowed. - pub fn cors_allow_header(mut self, header: String) -> Self { - let allow_headers = match self.cors_allow_headers { - AccessControlAllowHeaders::Any => vec![header], - AccessControlAllowHeaders::Only(mut allow_headers) => { - allow_headers.push(header); - allow_headers - } - }; - self.cors_allow_headers = AccessControlAllowHeaders::Only(allow_headers); + /// Allow all headers. + pub fn allow_all_headers(mut self) -> Self { + self.allowed_origins = None; self } + /// Configure allowed hosts. + /// + /// Default - allow all. + pub fn set_allowed_hosts(mut self, list: List) -> Result + where + List: IntoIterator, + H: Into, + { + let allowed_hosts: Vec = list.into_iter().map(Into::into).collect(); + if allowed_hosts.is_empty() { + return Err(Error::EmptyAllowList("Host")); + } + self.allowed_hosts = AllowHosts::Only(allowed_hosts); + Ok(self) + } + + /// Configure allowed origins. + /// + /// Default - allow all. + pub fn set_allowed_origins(mut self, list: List) -> Result + where + List: IntoIterator, + Origin: Into, + { + let allowed_origins: Vec = list.into_iter().map(Into::into).collect(); + if allowed_origins.is_empty() { + return Err(Error::EmptyAllowList("Origin")); + } + self.allowed_origins = Some(allowed_origins); + Ok(self) + } + + /// Configure allowed CORS headers. + /// + /// Default - allow all. + pub fn set_allowed_headers(mut self, list: List) -> Result + where + List: IntoIterator, + Header: Into, + { + let allowed_headers: Vec = list.into_iter().map(Into::into).collect(); + if allowed_headers.is_empty() { + return Err(Error::EmptyAllowList("Header")); + } + self.allowed_headers = AccessControlAllowHeaders::Only(allowed_headers); + Ok(self) + } + /// Enable or disable to continue with invalid CORS. + /// + /// Default: false. pub fn continue_on_invalid_cors(mut self, continue_on_invalid_cors: bool) -> Self { self.continue_on_invalid_cors = continue_on_invalid_cors; self @@ -171,9 +201,9 @@ impl AccessControlBuilder { /// Build. pub fn build(self) -> AccessControl { AccessControl { - allow_hosts: self.allow_hosts, - cors_allow_origin: self.cors_allow_origin, - cors_allow_headers: self.cors_allow_headers, + allowed_hosts: self.allowed_hosts, + allowed_origins: self.allowed_origins, + allowed_headers: self.allowed_headers, continue_on_invalid_cors: self.continue_on_invalid_cors, } }