Skip to content

Commit

Permalink
feat(cli,serverless): forward client IP through X-Forwarded-For header (
Browse files Browse the repository at this point in the history
  • Loading branch information
QuiiBz authored Oct 28, 2022
1 parent b271c2a commit 4d368dc
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 13 deletions.
6 changes: 6 additions & 0 deletions .changeset/kind-boats-attack.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
'@lagon/cli': patch
'@lagon/serverless': patch
---

Forward client IP through X-Forwarded-For header
17 changes: 14 additions & 3 deletions packages/cli/src/commands/dev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use colored::Colorize;
use envfile::EnvFile;
use hyper::body::Bytes;
use hyper::http::response::Builder;
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server};
use lagon_runtime::http::{Request, Response, RunResult, StreamResult};
Expand Down Expand Up @@ -41,6 +42,7 @@ fn parse_environment_variables(env: Option<PathBuf>) -> io::Result<HashMap<Strin
// threads to manage.
async fn handle_request(
req: HyperRequest<Body>,
ip: String,
content: Arc<Mutex<(FileCursor, HashMap<String, FileCursor>)>>,
environment_variables: HashMap<String, String>,
) -> Result<HyperResponse<Body>, Infallible> {
Expand Down Expand Up @@ -87,7 +89,8 @@ async fn handle_request(

tx.send_async(RunResult::Response(response)).await.unwrap();
} else {
let request = Request::from_hyper(req).await;
let mut request = Request::from_hyper(req).await;
request.add_header("X-Forwarded-For".into(), ip);

let mut isolate = Isolate::new(
IsolateOptions::new(String::from_utf8(index.get_ref().to_vec()).unwrap())
Expand Down Expand Up @@ -193,13 +196,21 @@ pub async fn dev(
let server_content = content.clone();
let environment_variables = parse_environment_variables(env)?;

let server = Server::bind(&addr).serve(make_service_fn(move |_conn| {
let server = Server::bind(&addr).serve(make_service_fn(move |conn: &AddrStream| {
let content = server_content.clone();
let environment_variables = environment_variables.clone();

let addr = conn.remote_addr();
let ip = addr.ip().to_string();

async move {
Ok::<_, Infallible>(service_fn(move |req| {
handle_request(req, content.clone(), environment_variables.clone())
handle_request(
req,
ip.clone(),
content.clone(),
environment_variables.clone(),
)
}))
}
}));
Expand Down
6 changes: 6 additions & 0 deletions packages/runtime/src/http/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,10 @@ impl Request {
url,
}
}

pub fn add_header(&mut self, key: String, value: String) {
if let Some(ref mut headers) = self.headers {
headers.insert(key, value);
}
}
}
36 changes: 26 additions & 10 deletions packages/serverless/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use deployments::Deployment;
use hyper::body::Bytes;
use hyper::header::HOST;
use hyper::http::response::Builder;
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server};
use lagon_runtime::http::{Request, RunResult, StreamResult};
use lagon_runtime::isolate::{Isolate, IsolateOptions};
use lagon_runtime::runtime::{Runtime, RuntimeOptions};
use lazy_static::lazy_static;
use log::error;
use metrics::{counter, /*histogram,*/ increment_counter};
use metrics::increment_counter;
use metrics_exporter_prometheus::PrometheusBuilder;
use mysql::{Opts, Pool};
#[cfg(not(debug_assertions))]
Expand All @@ -35,12 +37,14 @@ mod logger;
lazy_static! {
static ref ISOLATES: RwLock<HashMap<usize, HashMap<String, Isolate>>> =
RwLock::new(HashMap::new());
static ref X_FORWARDED_FOR: String = String::from("X-Forwarded-For");
}

const POOL_SIZE: usize = 8;

async fn handle_request(
req: HyperRequest<Body>,
ip: String,
pool: LocalPoolHandle,
deployments: Arc<RwLock<HashMap<String, Deployment>>>,
thread_ids: Arc<RwLock<HashMap<String, usize>>>,
Expand All @@ -49,14 +53,13 @@ async fn handle_request(
// Remove the leading '/' from the url
url.remove(0);

let request = Request::from_hyper(req).await;
let hostname = request
.headers
.as_ref()
let hostname = req
.headers()
.get(HOST)
.unwrap()
.get("host")
.to_str()
.unwrap()
.clone();
.to_string();

let thread_ids_reader = thread_ids.read().await;

Expand Down Expand Up @@ -88,7 +91,8 @@ async fn handle_request(
];

increment_counter!("lagon_requests", &labels);
counter!("lagon_bytes_in", request.len() as u64, &labels);
// TODO: find the right request bytes length
// counter!("lagon_bytes_in", request.len() as u64, &labels);

if let Some(asset) = deployment.assets.iter().find(|asset| *asset == &url) {
let run_result = match handle_asset(deployment, asset) {
Expand All @@ -105,6 +109,9 @@ async fn handle_request(

tx.send_async(run_result).await.unwrap();
} else {
let mut request = Request::from_hyper(req).await;
request.add_header(X_FORWARDED_FOR.to_string(), ip);

// Only acquire the lock when we are sure we have a
// deployment and that the isolate should be called.
// TODO: read() then write() if not present
Expand Down Expand Up @@ -256,14 +263,23 @@ async fn main() {
let pool = LocalPoolHandle::new(POOL_SIZE);
let thread_ids = Arc::new(RwLock::new(HashMap::new()));

let server = Server::bind(&addr).serve(make_service_fn(move |_conn| {
let server = Server::bind(&addr).serve(make_service_fn(move |conn: &AddrStream| {
let deployments = deployments.clone();
let pool = pool.clone();
let thread_ids = thread_ids.clone();

let addr = conn.remote_addr();
let ip = addr.ip().to_string();

async move {
Ok::<_, Infallible>(service_fn(move |req| {
handle_request(req, pool.clone(), deployments.clone(), thread_ids.clone())
handle_request(
req,
ip.clone(),
pool.clone(),
deployments.clone(),
thread_ids.clone(),
)
}))
}
}));
Expand Down

0 comments on commit 4d368dc

Please sign in to comment.