Skip to content

Commit

Permalink
[http server]: use similar API for host and origin filtering as WS (#…
Browse files Browse the repository at this point in the history
…473)

* use similar API for HTTP ACL builder

* revert unintentional change

* fix nits

* Update http-server/src/access_control/mod.rs

Co-authored-by: David <[email protected]>

* grumbles

Co-authored-by: David <[email protected]>
  • Loading branch information
niklasad1 and dvdplm authored Oct 11, 2021
1 parent b3e4297 commit 7a9ebc9
Showing 1 changed file with 79 additions and 49 deletions.
128 changes: 79 additions & 49 deletions http-server/src/access_control/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,33 +30,34 @@ pub(crate) mod cors;
pub(crate) mod hosts;
mod matcher;

use hosts::{AllowHosts, Host};
use crate::types::Error;

use cors::{AccessControlAllowHeaders, AccessControlAllowOrigin};
use hosts::{AllowHosts, Host};
use hyper::header;
use jsonrpsee_utils::http_helpers;

/// Define access on control on HTTP layer.
#[derive(Clone, Debug)]
pub struct AccessControl {
allow_hosts: AllowHosts,
cors_allow_origin: Option<Vec<AccessControlAllowOrigin>>,
cors_allow_headers: AccessControlAllowHeaders,
allowed_hosts: AllowHosts,
allowed_origins: Option<Vec<AccessControlAllowOrigin>>,
allowed_headers: AccessControlAllowHeaders,
continue_on_invalid_cors: bool,
}

impl AccessControl {
/// Validate incoming request by http HOST
pub fn deny_host(&self, request: &hyper::Request<hyper::Body>) -> bool {
!hosts::is_host_valid(http_helpers::read_header_value(request.headers(), "host"), &self.allow_hosts)
!hosts::is_host_valid(http_helpers::read_header_value(request.headers(), "host"), &self.allowed_hosts)
}

/// Validate incoming request by CORS origin
pub fn deny_cors_origin(&self, request: &hyper::Request<hyper::Body>) -> bool {
let header = cors::get_cors_allow_origin(
http_helpers::read_header_value(request.headers(), "origin"),
http_helpers::read_header_value(request.headers(), "host"),
&self.cors_allow_origin,
&self.allowed_origins,
)
.map(|origin| {
use self::cors::AccessControlAllowOrigin::*;
Expand All @@ -79,7 +80,7 @@ impl AccessControl {
.flat_map(|val| val.split(", "))
.flat_map(|val| val.split(','));

let header = cors::get_cors_allow_headers(headers, requested_headers, &self.cors_allow_headers, |name| {
let header = cors::get_cors_allow_headers(headers, requested_headers, &self.allowed_headers, |name| {
header::HeaderValue::from_str(name).unwrap_or_else(|_| header::HeaderValue::from_static("unknown"))
});
header == cors::AllowCors::Invalid && !self.continue_on_invalid_cors
Expand All @@ -89,9 +90,9 @@ impl AccessControl {
impl Default for AccessControl {
fn default() -> Self {
Self {
allow_hosts: AllowHosts::Any,
cors_allow_origin: None,
cors_allow_headers: AccessControlAllowHeaders::Any,
allowed_hosts: AllowHosts::Any,
allowed_origins: None,
allowed_headers: AccessControlAllowHeaders::Any,
continue_on_invalid_cors: false,
}
}
Expand All @@ -100,18 +101,18 @@ impl Default for AccessControl {
/// Convenience builder pattern
#[derive(Debug)]
pub struct AccessControlBuilder {
allow_hosts: AllowHosts,
cors_allow_origin: Option<Vec<AccessControlAllowOrigin>>,
cors_allow_headers: AccessControlAllowHeaders,
allowed_hosts: AllowHosts,
allowed_origins: Option<Vec<AccessControlAllowOrigin>>,
allowed_headers: AccessControlAllowHeaders,
continue_on_invalid_cors: bool,
}

impl Default for AccessControlBuilder {
fn default() -> Self {
Self {
allow_hosts: AllowHosts::Any,
cors_allow_origin: None,
cors_allow_headers: AccessControlAllowHeaders::Any,
allowed_hosts: AllowHosts::Any,
allowed_origins: None,
allowed_headers: AccessControlAllowHeaders::Any,
continue_on_invalid_cors: false,
}
}
Expand All @@ -123,46 +124,75 @@ impl AccessControlBuilder {
Self::default()
}

/// Configure allow host.
pub fn allow_host(mut self, host: Host) -> Self {
let allow_hosts = match self.allow_hosts {
AllowHosts::Any => vec![host],
AllowHosts::Only(mut allow_hosts) => {
allow_hosts.push(host);
allow_hosts
}
};
self.allow_hosts = AllowHosts::Only(allow_hosts);
/// Allow all hosts.
pub fn allow_all_hosts(mut self) -> Self {
self.allowed_hosts = AllowHosts::Any;
self
}

/// Configure CORS origin.
pub fn cors_allow_origin(mut self, allow_origin: AccessControlAllowOrigin) -> Self {
let cors_allow_origin = match self.cors_allow_origin {
Some(mut cors_allow_origin) => {
cors_allow_origin.push(allow_origin);
cors_allow_origin
}
None => vec![allow_origin],
};
self.cors_allow_origin = Some(cors_allow_origin);
/// Allow all origins.
pub fn allow_all_origins(mut self) -> Self {
self.allowed_headers = AccessControlAllowHeaders::Any;
self
}

/// Configure which CORS header that is allowed.
pub fn cors_allow_header(mut self, header: String) -> Self {
let allow_headers = match self.cors_allow_headers {
AccessControlAllowHeaders::Any => vec![header],
AccessControlAllowHeaders::Only(mut allow_headers) => {
allow_headers.push(header);
allow_headers
}
};
self.cors_allow_headers = AccessControlAllowHeaders::Only(allow_headers);
/// Allow all headers.
pub fn allow_all_headers(mut self) -> Self {
self.allowed_origins = None;
self
}

/// Configure allowed hosts.
///
/// Default - allow all.
pub fn set_allowed_hosts<List, H>(mut self, list: List) -> Result<Self, Error>
where
List: IntoIterator<Item = H>,
H: Into<Host>,
{
let allowed_hosts: Vec<Host> = list.into_iter().map(Into::into).collect();
if allowed_hosts.is_empty() {
return Err(Error::EmptyAllowList("Host"));
}
self.allowed_hosts = AllowHosts::Only(allowed_hosts);
Ok(self)
}

/// Configure allowed origins.
///
/// Default - allow all.
pub fn set_allowed_origins<Origin, List>(mut self, list: List) -> Result<Self, Error>
where
List: IntoIterator<Item = Origin>,
Origin: Into<AccessControlAllowOrigin>,
{
let allowed_origins: Vec<AccessControlAllowOrigin> = list.into_iter().map(Into::into).collect();
if allowed_origins.is_empty() {
return Err(Error::EmptyAllowList("Origin"));
}
self.allowed_origins = Some(allowed_origins);
Ok(self)
}

/// Configure allowed CORS headers.
///
/// Default - allow all.
pub fn set_allowed_headers<Header, List>(mut self, list: List) -> Result<Self, Error>
where
List: IntoIterator<Item = Header>,
Header: Into<String>,
{
let allowed_headers: Vec<String> = list.into_iter().map(Into::into).collect();
if allowed_headers.is_empty() {
return Err(Error::EmptyAllowList("Header"));
}
self.allowed_headers = AccessControlAllowHeaders::Only(allowed_headers);
Ok(self)
}

/// Enable or disable to continue with invalid CORS.
///
/// Default: false.
pub fn continue_on_invalid_cors(mut self, continue_on_invalid_cors: bool) -> Self {
self.continue_on_invalid_cors = continue_on_invalid_cors;
self
Expand All @@ -171,9 +201,9 @@ impl AccessControlBuilder {
/// Build.
pub fn build(self) -> AccessControl {
AccessControl {
allow_hosts: self.allow_hosts,
cors_allow_origin: self.cors_allow_origin,
cors_allow_headers: self.cors_allow_headers,
allowed_hosts: self.allowed_hosts,
allowed_origins: self.allowed_origins,
allowed_headers: self.allowed_headers,
continue_on_invalid_cors: self.continue_on_invalid_cors,
}
}
Expand Down

0 comments on commit 7a9ebc9

Please sign in to comment.