diff --git a/README.md b/README.md index 96e4d1cd1e0..afd0452a408 100644 --- a/README.md +++ b/README.md @@ -258,6 +258,9 @@ Here's the user data to change the message of the day setting, as we did in the motd = "my own value!" ``` +If your user data is over the size limit of the platform (e.g. 16KiB for EC2) you can compress the contents with gzip. +(With [aws-cli](https://aws.amazon.com/cli/), you can use `--user-data fileb:///path/to/gz-file` to pass binary data.) + ### Description of settings Here we'll describe each setting you can change. diff --git a/sources/Cargo.lock b/sources/Cargo.lock index e3f8371b0c6..817cab229d8 100644 --- a/sources/Cargo.lock +++ b/sources/Cargo.lock @@ -949,7 +949,10 @@ dependencies = [ "apiclient", "base64 0.13.0", "cargo-readme", + "flate2", + "hex-literal", "http", + "lazy_static", "log", "reqwest", "serde", diff --git a/sources/api/early-boot-config/Cargo.toml b/sources/api/early-boot-config/Cargo.toml index 7368ff8fcb0..46067334abc 100644 --- a/sources/api/early-boot-config/Cargo.toml +++ b/sources/api/early-boot-config/Cargo.toml @@ -12,6 +12,7 @@ exclude = ["README.md"] [dependencies] apiclient = { path = "../apiclient" } base64 = "0.13" +flate2 = { version = "1.0", default-features = false, features = ["rust_backend"] } http = "0.2" log = "0.4" reqwest = { version = "0.10", default-features = false, features = ["blocking"] } @@ -27,3 +28,7 @@ toml = "0.5" [build-dependencies] cargo-readme = "3.1" + +[dev-dependencies] +hex-literal = "0.3" +lazy_static = "1.4" diff --git a/sources/api/early-boot-config/src/compression.rs b/sources/api/early-boot-config/src/compression.rs new file mode 100644 index 00000000000..2c0ee052b4d --- /dev/null +++ b/sources/api/early-boot-config/src/compression.rs @@ -0,0 +1,226 @@ +//! This module supports reading from an input source that could be compressed or plain text. +//! +//! Currently gzip compression is supported. + +use flate2::read::GzDecoder; +use std::fs::File; +use std::io::{BufReader, Chain, Cursor, ErrorKind, Read, Result, Take}; +use std::path::Path; + +/// "File magic" that indicates file type is stored in a few bytes at the start at the start of the +/// data. For now we only need two bytes for gzip, but if adding new formats, we'd need to read +/// more. (The simplest approach may be to read the max length for any format we need and compare +/// the appropriate prefix length.) +/// https://en.wikipedia.org/wiki/List_of_file_signatures +const MAGIC_LEN: usize = 2; + +// We currently only support gzip, but it shouldn't be hard to add more. +/// These bytes are at the start of any gzip-compressed data. +const GZ_MAGIC: [u8; 2] = [0x1f, 0x8b]; + +/// This helper takes a slice of bytes representing UTF-8 text, which can optionally be +/// compressed, and returns an uncompressed string. +pub fn expand_slice_maybe(input: &[u8]) -> Result { + let mut output = String::new(); + let mut reader = OptionalCompressionReader::new(Cursor::new(input)); + reader.read_to_string(&mut output)?; + Ok(output) +} + +/// This helper takes the path to a file containing UTF-8 text, which can optionally be compressed, +/// and returns an uncompressed string of all its contents. File reads are done through BufReader. +pub fn expand_file_maybe

(path: P) -> Result +where + P: AsRef, +{ + let path = path.as_ref(); + let file = File::open(&path)?; + let mut output = String::new(); + let mut reader = OptionalCompressionReader::new(BufReader::new(file)); + reader.read_to_string(&mut output)?; + Ok(output) +} + +// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= + +/// This type lets you wrap a `Read` whose data may or may not be compressed, and its `read()` +/// calls will uncompress the data if needed. +pub struct OptionalCompressionReader(CompressionType); + +/// This represents the type of compression we've detected within a `Read`, or `Unknown` if we +/// haven't yet read any bytes to be able to detect it. +enum CompressionType { + /// This represents the starting state of the reader before we've read the magic bytes and + /// detected any compression. + /// + /// We need ownership of the `Read` to construct one of the variants below, so we use an + /// `Option` to allow `take`ing the value out, even if we only have a &mut reference in the + /// `read` implementation. This is safe because detection is a one-time process and we know we + /// construct this with Some value. + Unknown(Option), + + /// We haven't found recognizable compression. + None(Peek), + + /// We found gzip compression. + Gz(GzDecoder>), +} + +/// `Peek` lets us read the starting bytes (the "magic") of an input `Read` but maintain those +/// bytes in an internal buffer. We Take the number of bytes we read (to handle reads shorter than +/// MAGIC_LEN) and Chain them together with the rest of the input, to represent the full input. +type Peek = Chain>, T>; + +impl OptionalCompressionReader { + /// Build a new `OptionalCompressionReader` before we know the input compression type. + pub fn new(input: R) -> Self { + Self(CompressionType::Unknown(Some(input))) + } +} + +/// Implement `Read` by checking whether we've detected compression type yet, and if not, detecting +/// it and then replacing ourselves with the appropriate type so we can continue reading. +impl Read for OptionalCompressionReader { + fn read(&mut self, buf: &mut [u8]) -> Result { + match self.0 { + CompressionType::Unknown(ref mut input) => { + // Take ownership of our `Read` object so we can store it in a new variant. + let mut reader = input.take().expect( + "OptionalCompressionReader constructed with None input; programming error", + ); + + // Read the "magic" that tells us the compression type. + let mut magic = [0u8; MAGIC_LEN]; + let count = reader.retry_read(&mut magic)?; + + // We need to return all of the bytes, but we just consumed MAGIC_LEN of them. + // This chains together those initial bytes with the remainder so we have them all. + let magic_read = Cursor::new(magic).take(count as u64); + let full_input = magic_read.chain(reader); + + // Detect compression type based on the magic bytes. + if count == MAGIC_LEN && magic == GZ_MAGIC { + // Use a gzip decoder if gzip compressed. + self.0 = CompressionType::Gz(GzDecoder::new(full_input)) + } else { + // We couldn't detect any compression; just read the input. + self.0 = CompressionType::None(full_input) + } + + // We've replaced Unknown with a known compression type; defer to that for reading. + self.read(buf) + } + + // After initial detection, we just perform standard reads on the reader we prepared. + CompressionType::None(ref mut r) => r.read(buf), + CompressionType::Gz(ref mut r) => r.read(buf), + } + } +} + +// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= + +/// This trait represents a `Read` operation where we want to retry after standard interruptions +/// (unlike `read()`) but also need to know the number of bytes we read (unlike `read_exact()`). +trait RetryRead { + fn retry_read(&mut self, buf: &mut [u8]) -> Result; +} + +impl RetryRead for R { + // This implementation is based on stdlib Read::read_exact, but hitting EOF isn't a failure, we + // just want to return the number of bytes we could read. + fn retry_read(&mut self, mut buf: &mut [u8]) -> Result { + let mut count = 0; + + // Read until we have no more space in the output buffer + while !buf.is_empty() { + match self.read(buf) { + // No bytes left, done + Ok(0) => break, + // Read n bytes, slide ahead n in the output buffer and read more + Ok(n) => { + count += n; + let tmp = buf; + buf = &mut tmp[n..]; + } + // Retry on interrupt + Err(e) if e.kind() == ErrorKind::Interrupted => {} + // Other failures are fatal + Err(e) => return Err(e), + } + } + + Ok(count) + } +} + +// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= + +#[cfg(test)] +mod test { + use super::*; + use hex_literal::hex; + use lazy_static::lazy_static; + use std::io::Cursor; + + lazy_static! { + /// Some plain text strings and their gzip encodings. + static ref DATA: &'static [(&'static str, &'static [u8])] = &[ + ("", &hex!("1f8b 0808 3863 3960 0003 656d 7074 7900 0300 0000 0000 0000 0000")), + ("4", &hex!("1f8b 0808 6f63 3960 0003 666f 7572 0033 0100 381b b6f3 0100 0000")), + ("42", &hex!("1f8b 0808 7c6b 3960 0003 616e 7377 6572 0033 3102 0088 b024 3202 0000 00")), + ("hi there", &hex!("1f8b 0808 d24f 3960 0003 6869 7468 6572 6500 cbc8 5428 c948 2d4a 0500 ec76 a3e3 0800 0000")), + ]; + } + + #[test] + fn test_plain() { + for (plain, _gz) in *DATA { + let input = Cursor::new(plain); + let mut output = String::new(); + OptionalCompressionReader::new(input) + .read_to_string(&mut output) + .unwrap(); + assert_eq!(output, *plain); + } + } + + #[test] + fn test_gz() { + for (plain, gz) in *DATA { + let input = Cursor::new(gz); + let mut output = String::new(); + OptionalCompressionReader::new(input) + .read_to_string(&mut output) + .unwrap(); + assert_eq!(output, *plain); + } + } + + #[test] + fn test_helper_plain() { + for (plain, _gz) in *DATA { + assert_eq!(expand_slice_maybe(plain.as_bytes()).unwrap(), *plain); + } + } + + #[test] + fn test_helper_gz() { + for (plain, gz) in *DATA { + assert_eq!(expand_slice_maybe(gz).unwrap(), *plain); + } + } + + #[test] + fn test_magic_prefix() { + // Confirm that if we give a prefix of valid magic, but not the whole thing, we just get + // that input back. + let input = Cursor::new(&[0x1f]); + let mut output = Vec::new(); + let count = OptionalCompressionReader::new(input) + .read_to_end(&mut output) + .unwrap(); + assert_eq!(count, 1); + assert_eq!(output, &[0x1f]); + } +} diff --git a/sources/api/early-boot-config/src/main.rs b/sources/api/early-boot-config/src/main.rs index 732d51fd90b..f36017e9087 100644 --- a/sources/api/early-boot-config/src/main.rs +++ b/sources/api/early-boot-config/src/main.rs @@ -21,6 +21,7 @@ use std::fs; use std::str::FromStr; use std::{env, process}; +mod compression; mod provider; mod settings; use crate::provider::PlatformDataProvider; diff --git a/sources/api/early-boot-config/src/provider/aws.rs b/sources/api/early-boot-config/src/provider/aws.rs index 0c9b7dee73f..f650e58f404 100644 --- a/sources/api/early-boot-config/src/provider/aws.rs +++ b/sources/api/early-boot-config/src/provider/aws.rs @@ -1,6 +1,7 @@ //! The aws module implements the `PlatformDataProvider` trait for gathering userdata on AWS. use super::{PlatformDataProvider, SettingsJson}; +use crate::compression::expand_slice_maybe; use http::StatusCode; use reqwest::blocking::Client; use serde_json::json; @@ -48,7 +49,7 @@ impl AwsDataProvider { session_token: &str, uri: &str, description: &str, - ) -> Result> { + ) -> Result>> { debug!("Requesting {} from {}", description, uri); let response = client .get(uri) @@ -57,15 +58,35 @@ impl AwsDataProvider { .context(error::Request { method: "GET", uri })?; trace!("IMDS response: {:?}", &response); + // IMDS data can be larger than we'd want to log (50k+ compressed) so we don't necessarily + // want to show the whole thing, and don't want to show binary data. + fn response_string(response: &[u8]) -> String { + // arbitrary max len; would be nice to print the start of the data if it's + // uncompressed, but we'd need to break slice at a safe point for UTF-8, and without + // reading in the whole thing like String::from_utf8. + if response.len() > 2048 { + "".to_string() + } else if let Ok(s) = String::from_utf8(response.into()) { + s + } else { + "".to_string() + } + } + match response.status() { code @ StatusCode::OK => { info!("Received {}", description); - let response_body = response.text().context(error::ResponseBody { - method: "GET", - uri, - code, - })?; - trace!("Response text: {:?}", &response_body); + let response_body = response + .bytes() + .context(error::ResponseBody { + method: "GET", + uri, + code, + })? + .to_vec(); + + let response_str = response_string(&response_body); + trace!("Response: {:?}", response_str); Ok(Some(response_body)) } @@ -74,18 +95,24 @@ impl AwsDataProvider { StatusCode::NOT_FOUND => Ok(None), code @ _ => { - let response_body = response.text().context(error::ResponseBody { - method: "GET", - uri, - code, - })?; - trace!("Response text: {:?}", &response_body); + let response_body = response + .bytes() + .context(error::ResponseBody { + method: "GET", + uri, + code, + })? + .to_vec(); + + let response_str = response_string(&response_body); + + trace!("Response: {:?}", response_str); error::Response { method: "GET", uri, code, - response_body, + response_body: response_str, } .fail() } @@ -98,11 +125,13 @@ impl AwsDataProvider { let desc = "user data"; let uri = Self::USER_DATA_ENDPOINT; - let user_data_str = match Self::fetch_imds(client, session_token, uri, desc) { + let user_data_raw = match Self::fetch_imds(client, session_token, uri, desc) { Err(e) => return Err(e), Ok(None) => return Ok(None), Ok(Some(s)) => s, }; + let user_data_str = expand_slice_maybe(&user_data_raw) + .context(error::Decompression { what: "user data" })?; trace!("Received user data: {}", user_data_str); // Remove outer "settings" layer before sending to API @@ -131,7 +160,9 @@ impl AwsDataProvider { match Self::fetch_imds(client, session_token, uri, desc) { Err(e) => return Err(e), Ok(None) => return Ok(None), - Ok(Some(s)) => s, + Ok(Some(raw)) => { + expand_slice_maybe(&raw).context(error::Decompression { what: "user data" })? + } } }; trace!("Received instance identity document: {}", iid_str); @@ -198,6 +229,9 @@ mod error { #[snafu(display("Response '{}' from '{}': {}", get_bad_status_code(&source), uri, source))] BadResponse { uri: String, source: reqwest::Error }, + #[snafu(display("Failed to decompress {}: {}", what, source))] + Decompression { what: String, source: io::Error }, + #[snafu(display("Error deserializing from JSON: {}", source))] DeserializeJson { source: serde_json::error::Error }, diff --git a/sources/api/early-boot-config/src/provider/cdrom.rs b/sources/api/early-boot-config/src/provider/cdrom.rs index d9927691e07..89a695d7efa 100644 --- a/sources/api/early-boot-config/src/provider/cdrom.rs +++ b/sources/api/early-boot-config/src/provider/cdrom.rs @@ -2,11 +2,13 @@ //! mounted CDRom. use super::{PlatformDataProvider, SettingsJson}; +use crate::compression::{expand_file_maybe, expand_slice_maybe, OptionalCompressionReader}; use serde::Deserialize; use snafu::{ensure, OptionExt, ResultExt}; use std::ffi::OsStr; -use std::fs::{self, File}; +use std::fs::File; use std::io::BufReader; +use std::iter::FromIterator; use std::path::Path; pub(crate) struct CdromDataProvider; @@ -52,15 +54,29 @@ impl CdromDataProvider { // Since we only look for a specific list of file names, we should never find a file // with an extension we don't understand. Some(_) => unreachable!(), - None => fs::read_to_string(&user_data_file).context(error::InputFileRead { - path: user_data_file, - })?, + None => { + // Read the file, decompressing it if compressed. + expand_file_maybe(&user_data_file).context(error::InputFileRead { + path: &user_data_file, + })? + } }; if user_data_str.is_empty() { return Ok(None); } - trace!("Received user data: {}", user_data_str); + + // User data could be 700MB compressed! Eek! :) + if user_data_str.len() <= 2048 { + trace!("Received user data: {}", user_data_str); + } else { + trace!( + "Received long user data, starts with: {}", + // (this isn't perfect because chars aren't grapheme clusters, but will error + // toward printing the whole input, which is fine) + String::from_iter(user_data_str.chars().take(2048)) + ); + } // Remove outer "settings" layer before sending to API let mut val: toml::Value = @@ -83,7 +99,7 @@ impl CdromDataProvider { fn ovf_user_data>(path: P) -> Result { let path = path.as_ref(); let file = File::open(path).context(error::InputFileRead { path })?; - let reader = BufReader::new(file); + let reader = OptionalCompressionReader::new(BufReader::new(file)); // Deserialize the OVF file, dropping everything we don't care about let ovf: Environment = @@ -109,12 +125,12 @@ impl CdromDataProvider { base64_string: base64_str.to_string(), })?; - // Create a valid utf8 str - let decoded = std::str::from_utf8(&decoded_bytes).context(error::InvalidUTF8 { - base64_string: base64_str.to_string(), + // Decompress the data if it's compressed + let decoded = expand_slice_maybe(&decoded_bytes).context(error::Decompression { + what: "OVF user data", })?; - Ok(decoded.to_string()) + Ok(decoded) } } @@ -168,19 +184,12 @@ mod error { source: base64::DecodeError, }, + #[snafu(display("Failed to decompress {}: {}", what, source))] + Decompression { what: String, source: io::Error }, + #[snafu(display("Unable to read input file '{}': {}", path.display(), source))] InputFileRead { path: PathBuf, source: io::Error }, - #[snafu(display( - "Invalid (non-utf8) output from base64 string '{}': {}", - base64_string, - source - ))] - InvalidUTF8 { - base64_string: String, - source: std::str::Utf8Error, - }, - #[snafu(display("Error serializing TOML to JSON: {}", source))] SettingsToJSON { source: serde_json::error::Error }, diff --git a/sources/api/early-boot-config/src/provider/local_file.rs b/sources/api/early-boot-config/src/provider/local_file.rs index 4ee63e72b2f..2604db97424 100644 --- a/sources/api/early-boot-config/src/provider/local_file.rs +++ b/sources/api/early-boot-config/src/provider/local_file.rs @@ -2,8 +2,8 @@ //! local file use super::{PlatformDataProvider, SettingsJson}; +use crate::compression::expand_file_maybe; use snafu::{OptionExt, ResultExt}; -use std::fs; pub(crate) struct LocalFileDataProvider; @@ -16,8 +16,9 @@ impl PlatformDataProvider for LocalFileDataProvider { let mut output = Vec::new(); info!("'{}' exists, using it", Self::USER_DATA_FILE); + // Read the file, decompressing it if compressed. let user_data_str = - fs::read_to_string(Self::USER_DATA_FILE).context(error::InputFileRead { + expand_file_maybe(Self::USER_DATA_FILE).context(error::InputFileRead { path: Self::USER_DATA_FILE, })?;