Skip to content

Commit

Permalink
Add config option to specify additional headers for serve
Browse files Browse the repository at this point in the history
  • Loading branch information
oberien committed Nov 12, 2022
1 parent c620d96 commit a06d1dd
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Subheadings to categorize changes are `added, changed, deprecated, removed, fixe
- It is now possible to disable the hashes in output file names with the new `--filehash` flag (for example `cargo build --filehash false`). Alternatively the `build.filehash` setting in `Trunk.toml` or the env var `CARGO_BUILD_FILEHASH` can be used.
- Flags for enabling reference types & weak references in `wasm-bindgen`.
- Added the `data-typescript` attribute to Rust assets. When present, `wasm-bindgen` will emit TS files for the WASM module.
- Added `headers` config option for `trunk serve`.

### changed
- Bump notify to 5.0.0-pre.13, which fixes [notify-rs/notify#356](https://github.com/notify-rs/notify/issues/356)
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ tokio = { version = "1", default-features = false, features = ["full"] }
tokio-stream = { version = "0.1", default-features = false, features = ["fs", "sync"] }
tokio-tungstenite = "0.17"
toml = "0.5"
tower-http = { version = "0.3", features = ["fs", "trace"] }
tower-http = { version = "0.3", features = ["fs", "trace", "set-header"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
which = "4"
Expand Down
2 changes: 2 additions & 0 deletions Trunk.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ port = 8080
open = false
# Disable auto-reload of the web app.
no_autoreload = false
# Additional headers set for responses.
#headers = { "test-header" = "header value", "test-header2" = "header value 2" }

[clean]
# The output dir for all final assets.
Expand Down
6 changes: 6 additions & 0 deletions src/config/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ pub struct ConfigOptsServe {
#[clap(long = "no-autoreload")]
#[serde(default)]
pub no_autoreload: bool,
/// Additional headers to send in responses [default: none]
#[clap(skip)]
#[serde(default)]
pub headers: HashMap<String, String>,
}

/// Config options for the serve system.
Expand Down Expand Up @@ -336,6 +340,7 @@ impl ConfigOpts {
proxy_insecure: cli.proxy_insecure,
proxy_ws: cli.proxy_ws,
no_autoreload: cli.no_autoreload,
headers: cli.headers,
};
let cfg = ConfigOpts {
build: None,
Expand Down Expand Up @@ -510,6 +515,7 @@ impl ConfigOpts {
if l.open {
g.open = true;
}
g.headers.extend(l.headers);
Some(g)
}
};
Expand Down
3 changes: 3 additions & 0 deletions src/config/rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ pub struct RtcServe {
pub proxies: Option<Vec<ConfigOptsProxy>>,
/// Whether to disable auto-reload of the web page when a build completes.
pub no_autoreload: bool,
/// Additional headers to include in responses.
pub headers: HashMap<String, String>,
}

impl RtcServe {
Expand Down Expand Up @@ -247,6 +249,7 @@ impl RtcServe {
proxy_ws: opts.proxy_ws,
proxies,
no_autoreload: opts.no_autoreload,
headers: opts.headers,
})
}
}
Expand Down
41 changes: 28 additions & 13 deletions src/serve.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;

use anyhow::{Context, Result};
use axum::body::{self, Body};
use axum::extract::ws::{WebSocket, WebSocketUpgrade};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::http::header::HeaderName;
use axum::http::{HeaderValue, StatusCode};
use axum::response::Response;
use axum::routing::{get, get_service, Router};
use axum::Server;
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use tower_http::services::{ServeDir, ServeFile};
use tower_http::set_header::SetResponseHeaderLayer;
use tower_http::trace::TraceLayer;

use crate::common::SERVER;
Expand Down Expand Up @@ -117,7 +120,7 @@ impl ServeSystem {
&cfg,
build_done_chan,
));
let router = router(state, cfg.clone());
let router = router(state, cfg.clone())?;
let addr = (cfg.address, cfg.port).into();
let server = Server::bind(&addr)
.serve(router.into_make_service())
Expand Down Expand Up @@ -147,6 +150,8 @@ pub struct State {
pub build_done_chan: broadcast::Sender<()>,
/// Whether to disable autoreload
pub no_autoreload: bool,
/// Additional headers to add to responses.
pub headers: HashMap<String, String>,
}

impl State {
Expand All @@ -166,13 +171,14 @@ impl State {
public_url,
build_done_chan,
no_autoreload: cfg.no_autoreload,
headers: cfg.headers.clone(),
}
}
}

/// Build the Trunk router, this includes that static file server, the WebSocket server,
/// (for autoreload & HMR in the future), as well as any user-defined proxies.
fn router(state: Arc<State>, cfg: Arc<RtcServe>) -> Router {
fn router(state: Arc<State>, cfg: Arc<RtcServe>) -> Result<Router> {
// Build static file server, middleware, error handler & WS route for reloads.
let public_route = if state.public_url == "/" {
&state.public_url
Expand All @@ -183,19 +189,28 @@ fn router(state: Arc<State>, cfg: Arc<RtcServe>) -> Router {
.unwrap_or(&state.public_url)
};

let mut serve_dir = get_service(
ServeDir::new(&state.dist_dir).fallback(ServeFile::new(&state.dist_dir.join(INDEX_HTML))),
);
for (key, value) in &state.headers {
let name = HeaderName::from_bytes(key.as_bytes())
.with_context(|| format!("invalid header {:?}", key))?;
let value: HeaderValue = value
.parse()
.with_context(|| format!("invalid header value {:?} for header {}", value, name))?;
serve_dir = serve_dir.layer(SetResponseHeaderLayer::overriding(name, value))
}

let mut router = Router::new()
.fallback(
Router::new().nest(
public_route,
get_service(
ServeDir::new(&state.dist_dir)
.fallback(ServeFile::new(&state.dist_dir.join(INDEX_HTML))),
)
.handle_error(|error| async move {
tracing::error!(?error, "failed serving static file");
StatusCode::INTERNAL_SERVER_ERROR
})
.layer(TraceLayer::new_for_http()),
get_service(serve_dir)
.handle_error(|error| async move {
tracing::error!(?error, "failed serving static file");
StatusCode::INTERNAL_SERVER_ERROR
})
.layer(TraceLayer::new_for_http()),
),
)
.route(
Expand Down Expand Up @@ -268,7 +283,7 @@ fn router(state: Arc<State>, cfg: Arc<RtcServe>) -> Router {
}
}

router
Ok(router)
}

async fn handle_ws(mut ws: WebSocket, state: Arc<State>) {
Expand Down

0 comments on commit a06d1dd

Please sign in to comment.