Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add configuration possibilities for CORS middleware #705

Merged
merged 9 commits into from
Feb 6, 2025
89 changes: 89 additions & 0 deletions server/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;

use actix_cors::Cors;
use actix_http::Method;
use cidr::{Ipv4Cidr, Ipv6Cidr};
use clap::{ArgGroup, Args, Parser, Subcommand, ValueEnum};

Expand Down Expand Up @@ -411,6 +413,53 @@ pub struct TlsOptions {
pub tls_server_port: u16,
}

pub fn parse_http_method(value: &str) -> Result<actix_http::Method, String> {
Method::from_bytes(value.as_bytes()).map_err(|f| format!("Failed to format method: {f:?}"))
}

#[derive(Args, Debug, Clone)]
pub struct CorsOptions {
#[clap(env, long, value_delimiter = ',')]
pub cors_origin: Option<Vec<String>>,
#[clap(env, long, value_delimiter = ',')]
pub cors_allowed_headers: Option<Vec<String>>,
#[clap(env, long, default_value_t = 172800)]
pub cors_max_age: usize,
#[clap(env, long, value_delimiter = ',')]
pub cors_exposed_headers: Option<Vec<String>>,
#[clap(env, long, value_delimiter = ',', value_parser = parse_http_method)]
pub cors_methods: Option<Vec<actix_http::Method>>,
}

impl CorsOptions {
pub fn middleware(&self) -> Cors {
let mut cors_middleware = Cors::default()
.max_age(self.cors_max_age)
.allow_any_method()
.allow_any_header();
if let Some(origins) = self.cors_origin.clone() {
for origin in origins {
cors_middleware = cors_middleware.allowed_origin(&origin);
}
cors_middleware = cors_middleware.supports_credentials();
} else {
cors_middleware = cors_middleware.allow_any_origin().send_wildcard();
}
if let Some(allowed_headers) = self.cors_allowed_headers.clone() {
for header in allowed_headers {
cors_middleware = cors_middleware.allowed_header(header);
}
}
if let Some(allowed_methods) = self.cors_methods.clone() {
cors_middleware = cors_middleware.allowed_methods(allowed_methods);
}
if let Some(exposed_headers) = self.cors_exposed_headers.clone() {
cors_middleware = cors_middleware.expose_headers(exposed_headers);
}
cors_middleware
}
}

#[derive(Args, Debug, Clone)]
pub struct HttpServerArgs {
/// Which port should this server listen for HTTP traffic on
Expand All @@ -430,6 +479,9 @@ pub struct HttpServerArgs {

#[clap(flatten)]
pub tls: TlsOptions,

#[clap(flatten)]
pub cors: CorsOptions,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -478,6 +530,7 @@ impl HttpServerArgs {

#[cfg(test)]
mod tests {
use actix_web::http;
use clap::Parser;
use tracing::info;
use tracing_test::traced_test;
Expand Down Expand Up @@ -766,6 +819,42 @@ mod tests {
}
}

#[test]
pub fn cors_origin_can_be_set_via_cli() {
let args = vec![
"unleash-edge",
"--cors-origin",
"example.com",
"--cors-origin",
"otherexample.com",
"--cors-origin",
"one.com,two.com",
"edge",
"-u http://localhost:4242",
];
let args = CliArgs::parse_from(args);
assert_eq!(args.http.cors.cors_origin.clone().unwrap().len(), 4);
let _middleware = args.http.cors.middleware();
}

#[test]
pub fn can_set_custom_cors_method() {
let args = vec![
"unleash-edge",
"--cors-methods",
"GET",
"--cors-methods",
"PATCH",
"edge",
"-u http://localhost:4242",
];
let cli = CliArgs::parse_from(args);
assert_eq!(
cli.http.cors.cors_methods,
Some(vec![http::Method::GET, http::Method::PATCH])
);
}

#[test]
pub fn proxy_trusted_servers_accept_both_ipv4_and_ipv6_cidr_addresses() {
let args = vec![
Expand Down
8 changes: 2 additions & 6 deletions server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::sync::Arc;

use actix_cors::Cors;
use actix_middleware_etag::Etag;
use actix_web::middleware::Logger;
use actix_web::{web, App, HttpServer};
Expand Down Expand Up @@ -51,6 +50,7 @@ async fn main() -> Result<(), anyhow::Error> {
let schedule_args = args.clone();
let mode_arg = args.clone().mode;
let http_args = args.clone().http;
let cors_arg = http_args.cors.clone();
let token_header = args.clone().token_header;
let request_timeout = args.edge_request_timeout;
let keepalive_timeout = args.edge_keepalive_timeout;
Expand Down Expand Up @@ -96,11 +96,7 @@ async fn main() -> Result<(), anyhow::Error> {
let qs_config =
serde_qs::actix::QsQueryConfig::default().qs_config(serde_qs::Config::new(5, false));

let cors_middleware = Cors::default()
.allow_any_origin()
.send_wildcard()
.allow_any_header()
.allow_any_method();
let cors_middleware = cors_arg.middleware();
let mut app = App::new()
.app_data(qs_config)
.app_data(web::Data::new(token_header.clone()))
Expand Down