diff --git a/Cargo.lock b/Cargo.lock index 0907f9fdf..a499608b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -412,6 +412,21 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "nts-pool-ke" +version = "1.0.0" +dependencies = [ + "ntp-proto", + "rustls", + "rustls-pemfile", + "serde", + "thiserror", + "tokio", + "toml", + "tracing", + "tracing-subscriber", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" diff --git a/Cargo.toml b/Cargo.toml index 61130df27..e1e1f1e45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "ntp-proto", "ntp-os-clock", "ntp-udp", + "nts-pool-ke", "ntpd" ] exclude = [ ] diff --git a/nts-pool-ke/Cargo.toml b/nts-pool-ke/Cargo.toml new file mode 100644 index 000000000..f475183a3 --- /dev/null +++ b/nts-pool-ke/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "nts-pool-ke" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +homepage.workspace = true +readme.workspace = true +description.workspace = true +publish.workspace = true +rust-version.workspace = true + +[dependencies] +tokio = { workspace = true, features = ["rt-multi-thread", "io-util", "io-std", "fs", "sync", "net", "macros", "time"] } +toml.workspace = true +tracing.workspace = true +tracing-subscriber = { version = "0.3.0", default-features = false, features = ["std", "fmt", "ansi"] } +rustls.workspace = true +rustls-pemfile.workspace = true +serde.workspace = true +ntp-proto.workspace = true +thiserror.workspace = true + +[[bin]] +name = "nts-pool-ke" +path = "bin/nts-pool-ke.rs" diff --git a/nts-pool-ke/bin/nts-pool-ke.rs b/nts-pool-ke/bin/nts-pool-ke.rs new file mode 100644 index 000000000..c6c956d68 --- /dev/null +++ b/nts-pool-ke/bin/nts-pool-ke.rs @@ -0,0 +1,5 @@ +#[tokio::main] +async fn main() -> ! { + let result = nts_pool_ke::nts_pool_ke_main().await; + std::process::exit(if result.is_ok() { 0 } else { 1 }); +} diff --git a/nts-pool-ke/src/cli.rs b/nts-pool-ke/src/cli.rs new file mode 100644 index 000000000..17a8f8963 --- /dev/null +++ b/nts-pool-ke/src/cli.rs @@ -0,0 +1,190 @@ +use crate::daemon_tracing::LogLevel; +use std::path::PathBuf; +use std::str::FromStr; + +const USAGE_MSG: &str = "\ +usage: nts-pool-ke [-c PATH] [-l LOG_LEVEL] + nts-pool-ke -h + nts-pool-ke -v"; + +const DESCRIPTOR: &str = "ntp-daemon - synchronize system time"; + +const HELP_MSG: &str = "Options: + -c, --config=PATH change the config .toml file + -l, --log-level=LOG_LEVEL change the log level + -h, --help display this help text + -v, --version display version information"; + +pub fn long_help_message() -> String { + format!("{DESCRIPTOR}\n\n{USAGE_MSG}\n\n{HELP_MSG}") +} + +#[derive(Debug, Default)] +pub(crate) struct NtsPoolKeOptions { + /// Path of the configuration file + pub config: Option, + /// Level for messages to display in logs + pub log_level: Option, + help: bool, + version: bool, + pub action: NtsPoolKeAction, +} + +pub enum CliArg { + Flag(String), + Argument(String, String), + Rest(Vec), +} + +impl CliArg { + pub fn normalize_arguments( + takes_argument: &[&str], + takes_argument_short: &[char], + iter: I, + ) -> Result, String> + where + I: IntoIterator, + { + // the first argument is the sudo command - so we can skip it + let mut arg_iter = iter.into_iter().skip(1); + let mut processed = vec![]; + let mut rest = vec![]; + + while let Some(arg) = arg_iter.next() { + match arg.as_str() { + "--" => { + rest.extend(arg_iter); + break; + } + long_arg if long_arg.starts_with("--") => { + // --config=/path/to/config.toml + let invalid = Err(format!("invalid option: '{long_arg}'")); + + if let Some((key, value)) = long_arg.split_once('=') { + if takes_argument.contains(&key) { + processed.push(CliArg::Argument(key.to_string(), value.to_string())) + } else { + invalid? + } + } else if takes_argument.contains(&long_arg) { + if let Some(next) = arg_iter.next() { + processed.push(CliArg::Argument(long_arg.to_string(), next)) + } else { + Err(format!("'{}' expects an argument", &long_arg))?; + } + } else { + processed.push(CliArg::Flag(arg)); + } + } + short_arg if short_arg.starts_with('-') => { + // split combined shorthand options + for (n, char) in short_arg.trim_start_matches('-').chars().enumerate() { + let flag = format!("-{char}"); + // convert option argument to seperate segment + if takes_argument_short.contains(&char) { + let rest = short_arg[(n + 2)..].trim().to_string(); + // assignment syntax is not accepted for shorthand arguments + if rest.starts_with('=') { + Err("invalid option '='")?; + } + if !rest.is_empty() { + processed.push(CliArg::Argument(flag, rest)); + } else if let Some(next) = arg_iter.next() { + processed.push(CliArg::Argument(flag, next)); + } else if char == 'h' { + // short version of --help has no arguments + processed.push(CliArg::Flag(flag)); + } else { + Err(format!("'-{}' expects an argument", char))?; + } + break; + } else { + processed.push(CliArg::Flag(flag)); + } + } + } + _argument => rest.push(arg), + } + } + + if !rest.is_empty() { + processed.push(CliArg::Rest(rest)); + } + + Ok(processed) + } +} + +#[derive(Debug, Default, PartialEq, Eq)] +pub enum NtsPoolKeAction { + #[default] + Help, + Version, + Run, +} + +impl NtsPoolKeOptions { + const TAKES_ARGUMENT: &'static [&'static str] = &["--config", "--log-level"]; + const TAKES_ARGUMENT_SHORT: &'static [char] = &['c', 'l']; + + /// parse an iterator over command line arguments + pub fn try_parse_from(iter: I) -> Result + where + I: IntoIterator, + T: AsRef + Clone, + { + let mut options = NtsPoolKeOptions::default(); + let arg_iter = CliArg::normalize_arguments( + Self::TAKES_ARGUMENT, + Self::TAKES_ARGUMENT_SHORT, + iter.into_iter().map(|x| x.as_ref().to_string()), + )? + .into_iter() + .peekable(); + + for arg in arg_iter { + match arg { + CliArg::Flag(flag) => match flag.as_str() { + "-h" | "--help" => { + options.help = true; + } + "-v" | "--version" => { + options.version = true; + } + option => { + Err(format!("invalid option provided: {option}"))?; + } + }, + CliArg::Argument(option, value) => match option.as_str() { + "-c" | "--config" => { + options.config = Some(PathBuf::from(value)); + } + "-l" | "--log-level" => match LogLevel::from_str(&value) { + Ok(level) => options.log_level = Some(level), + Err(_) => return Err("invalid log level".into()), + }, + option => { + Err(format!("invalid option provided: {option}"))?; + } + }, + CliArg::Rest(_rest) => { /* do nothing, drop remaining arguments */ } + } + } + + options.resolve_action(); + // nothing to validate at the moment + + Ok(options) + } + + /// from the arguments resolve which action should be performed + fn resolve_action(&mut self) { + if self.help { + self.action = NtsPoolKeAction::Help; + } else if self.version { + self.action = NtsPoolKeAction::Version; + } else { + self.action = NtsPoolKeAction::Run; + } + } +} diff --git a/nts-pool-ke/src/config.rs b/nts-pool-ke/src/config.rs new file mode 100644 index 000000000..b752df07f --- /dev/null +++ b/nts-pool-ke/src/config.rs @@ -0,0 +1,104 @@ +use std::{ + net::SocketAddr, + os::unix::fs::PermissionsExt, + path::{Path, PathBuf}, +}; + +use serde::Deserialize; +use thiserror::Error; +use tracing::{info, warn}; + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +pub struct Config { + pub nts_pool_ke_server: NtsPoolKeConfig, + #[serde(default)] + pub observability: ObservabilityConfig, +} + +#[derive(Error, Debug)] +pub enum ConfigError { + #[error("io error while reading config: {0}")] + Io(#[from] std::io::Error), + #[error("config toml parsing error: {0}")] + Toml(#[from] toml::de::Error), +} + +impl Config { + pub fn check(&self) -> bool { + true + } + + async fn from_file(file: impl AsRef) -> Result { + let meta = std::fs::metadata(&file)?; + let perm = meta.permissions(); + + const S_IWOTH: u32 = 2; + if perm.mode() & S_IWOTH != 0 { + warn!("Unrestricted config file permissions: Others can write."); + } + + let contents = tokio::fs::read_to_string(file).await?; + Ok(toml::de::from_str(&contents)?) + } + + pub async fn from_args(file: impl AsRef) -> Result { + let path = file.as_ref(); + info!(?path, "using config file"); + + let config = Config::from_file(path).await?; + + Ok(config) + } +} + +#[derive(Deserialize, Debug, Clone, Default)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +pub struct ObservabilityConfig { + #[serde(default)] + pub log_level: Option, +} + +#[derive(Debug, PartialEq, Eq, Clone, Deserialize)] +#[serde(rename_all = "kebab-case", deny_unknown_fields)] +pub struct NtsPoolKeConfig { + pub certificate_chain_path: PathBuf, + pub private_key_path: PathBuf, + #[serde(default = "default_nts_ke_timeout")] + pub key_exchange_timeout_ms: u64, + pub listen: SocketAddr, +} + +fn default_nts_ke_timeout() -> u64 { + 1000 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_deserialize_nts_pool_ke() { + let test: Config = toml::from_str( + r#" + [nts-pool-ke-server] + listen = "0.0.0.0:4460" + certificate-chain-path = "/foo/bar/baz.pem" + private-key-path = "spam.der" + "#, + ) + .unwrap(); + + let pem = PathBuf::from("/foo/bar/baz.pem"); + assert_eq!(test.nts_pool_ke_server.certificate_chain_path, pem); + assert_eq!( + test.nts_pool_ke_server.private_key_path, + PathBuf::from("spam.der") + ); + assert_eq!(test.nts_pool_ke_server.key_exchange_timeout_ms, 1000,); + assert_eq!( + test.nts_pool_ke_server.listen, + "0.0.0.0:4460".parse().unwrap(), + ); + } +} diff --git a/nts-pool-ke/src/lib.rs b/nts-pool-ke/src/lib.rs new file mode 100644 index 000000000..7db6b5c81 --- /dev/null +++ b/nts-pool-ke/src/lib.rs @@ -0,0 +1,211 @@ +mod cli; +mod config; + +mod tracing; + +use std::{io::BufRead, path::PathBuf, sync::Arc}; + +use cli::NtsPoolKeOptions; +use config::{Config, NtsPoolKeConfig}; +use tokio::net::{TcpListener, ToSocketAddrs}; + +use crate::tracing as daemon_tracing; +use daemon_tracing::LogLevel; +use tracing_subscriber::util::SubscriberInitExt; + +pub(crate) mod exitcode { + /// An internal software error has been detected. This + /// should be limited to non-operating system related + /// errors as possible. + pub const SOFTWARE: i32 = 70; + + /// Something was found in an unconfigured or misconfigured state. + pub const CONFIG: i32 = 78; +} + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +pub async fn nts_pool_ke_main() -> Result<(), Box> { + let options = NtsPoolKeOptions::try_parse_from(std::env::args())?; + + match options.action { + cli::NtsPoolKeAction::Help => { + println!("{}", cli::long_help_message()); + } + cli::NtsPoolKeAction::Version => { + eprintln!("nts-pool-ke {VERSION}"); + } + cli::NtsPoolKeAction::Run => run(options).await?, + } + + Ok(()) +} + +// initializes the logger so that logs during config parsing are reported. Then it overrides the +// log level based on the config if required. +pub(crate) async fn initialize_logging_parse_config( + initial_log_level: Option, + config_path: Option, +) -> Config { + let mut log_level = initial_log_level.unwrap_or_default(); + + let config_tracing = daemon_tracing::tracing_init(log_level); + let config = ::tracing::subscriber::with_default(config_tracing, || { + async { + match config_path { + None => { + eprintln!("no configuration path specified"); + std::process::exit(exitcode::CONFIG); + } + Some(config_path) => { + match Config::from_args(config_path).await { + Ok(c) => c, + Err(e) => { + // print to stderr because tracing is not yet setup + eprintln!("There was an error loading the config: {e}"); + std::process::exit(exitcode::CONFIG); + } + } + } + } + } + }) + .await; + + if let Some(config_log_level) = config.observability.log_level { + if initial_log_level.is_none() { + log_level = config_log_level; + } + } + + // set a default global subscriber from now on + let tracing_inst = daemon_tracing::tracing_init(log_level); + tracing_inst.init(); + + config +} + +async fn run(options: NtsPoolKeOptions) -> Result<(), Box> { + let config = initialize_logging_parse_config(options.log_level, options.config).await; + + // give the user a warning that we use the command line option + if config.observability.log_level.is_some() && options.log_level.is_some() { + ::tracing::info!("Log level override from command line arguments is active"); + } + + // Warn/error if the config is unreasonable. We do this after finishing + // tracing setup to ensure logging is fully configured. + config.check(); + + let result = run_nts_pool_ke(config.nts_pool_ke_server).await; + + match result { + Ok(v) => Ok(v), + Err(e) => { + ::tracing::error!("Abnormal termination of NTS KE server: {e}"); + std::process::exit(exitcode::SOFTWARE) + } + } +} + +async fn run_nts_pool_ke(nts_pool_ke_config: NtsPoolKeConfig) -> std::io::Result<()> { + let certificate_chain_file = std::fs::File::open(&nts_pool_ke_config.certificate_chain_path) + .map_err(|e| { + io_error(&format!( + "error reading certificate_chain_path at `{:?}`: {:?}", + nts_pool_ke_config.certificate_chain_path, e + )) + })?; + + let private_key_file = + std::fs::File::open(&nts_pool_ke_config.private_key_path).map_err(|e| { + io_error(&format!( + "error reading key_der_path at `{:?}`: {:?}", + nts_pool_ke_config.private_key_path, e + )) + })?; + + let cert_chain: Vec = + rustls_pemfile::certs(&mut std::io::BufReader::new(certificate_chain_file))? + .into_iter() + .map(rustls::Certificate) + .collect(); + + let private_key = private_key_from_bufread(&mut std::io::BufReader::new(private_key_file))? + .ok_or(io_error("could not parse private key"))?; + + pool_key_exchange_server( + nts_pool_ke_config.listen, + cert_chain, + private_key, + nts_pool_ke_config.key_exchange_timeout_ms, + ) + .await +} + +fn io_error(msg: &str) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::Other, msg) +} + +async fn pool_key_exchange_server( + address: impl ToSocketAddrs, + certificate_chain: Vec, + private_key: rustls::PrivateKey, + timeout_ms: u64, +) -> std::io::Result<()> { + let listener = TcpListener::bind(address).await?; + + let mut config = rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certificate_chain, private_key) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?; + + config.alpn_protocols.clear(); + config.alpn_protocols.push(b"ntske/1".to_vec()); + + let config = Arc::new(config); + + loop { + let (stream, peer_address) = listener.accept().await?; + + let config = config.clone(); + + let fut = async move { + // BoundKeyExchangeServer::run(stream, config) + // .await + // .map_err(|ke_error| std::io::Error::new(std::io::ErrorKind::Other, ke_error)) + let _ = stream; + let _ = config; + + std::io::Result::Ok(()) + }; + + tokio::spawn(async move { + let timeout = std::time::Duration::from_millis(timeout_ms); + match tokio::time::timeout(timeout, fut).await { + Err(_) => ::tracing::debug!(?peer_address, "NTS Pool KE timed out"), + Ok(Err(err)) => ::tracing::debug!(?err, ?peer_address, "NTS Pool KE failed"), + Ok(Ok(())) => ::tracing::debug!(?peer_address, "NTS Pool KE completed"), + } + }); + } +} + +fn private_key_from_bufread( + mut reader: impl BufRead, +) -> std::io::Result> { + use rustls_pemfile::Item; + + loop { + match rustls_pemfile::read_one(&mut reader)? { + Some(Item::RSAKey(key)) => return Ok(Some(rustls::PrivateKey(key))), + Some(Item::PKCS8Key(key)) => return Ok(Some(rustls::PrivateKey(key))), + Some(Item::ECKey(key)) => return Ok(Some(rustls::PrivateKey(key))), + None => break, + _ => {} + } + } + + Ok(None) +} diff --git a/nts-pool-ke/src/tracing.rs b/nts-pool-ke/src/tracing.rs new file mode 100644 index 000000000..1624f1e2b --- /dev/null +++ b/nts-pool-ke/src/tracing.rs @@ -0,0 +1,69 @@ +use std::str::FromStr; + +use serde::Deserialize; +use tracing::metadata::LevelFilter; + +#[derive(Debug, Default, Copy, Clone, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + /// The "trace" level. + /// + /// Designates very low priority, often extremely verbose, information. + Trace = 0, + /// The "debug" level. + /// + /// Designates lower priority information. + Debug = 1, + /// The "info" level. + /// + /// Designates useful information. + #[default] + Info = 2, + /// The "warn" level. + /// + /// Designates hazardous situations. + Warn = 3, + /// The "error" level. + /// + /// Designates very serious errors. + Error = 4, +} + +pub struct UnknownLogLevel; + +impl FromStr for LogLevel { + type Err = UnknownLogLevel; + + fn from_str(s: &str) -> Result { + match s { + "trace" => Ok(LogLevel::Trace), + "debug" => Ok(LogLevel::Debug), + "info" => Ok(LogLevel::Info), + "warn" => Ok(LogLevel::Warn), + "error" => Ok(LogLevel::Error), + _ => Err(UnknownLogLevel), + } + } +} + +impl From for tracing::Level { + fn from(value: LogLevel) -> Self { + match value { + LogLevel::Trace => tracing::Level::TRACE, + LogLevel::Debug => tracing::Level::DEBUG, + LogLevel::Info => tracing::Level::INFO, + LogLevel::Warn => tracing::Level::WARN, + LogLevel::Error => tracing::Level::ERROR, + } + } +} + +impl From for LevelFilter { + fn from(value: LogLevel) -> Self { + LevelFilter::from_level(value.into()) + } +} + +pub fn tracing_init(level: impl Into) -> tracing_subscriber::fmt::Subscriber { + tracing_subscriber::fmt().with_max_level(level).finish() +}