Skip to content

Commit

Permalink
Add ability to inherit an env var on the CLI (#6746)
Browse files Browse the repository at this point in the history
In addition to Wasmtime's preexisting support for `--env FOO=bar` this
commit additionally adds support for `--env FOO` which is inspired by
Docker to inherit the environment variable `FOO` from the calling
process as a shortcut for `--env FOO=$FOO`.
  • Loading branch information
alexcrichton authored Jul 19, 2023
1 parent 475d1ba commit 2f6cec9
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 11 deletions.
37 changes: 26 additions & 11 deletions src/commands/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ fn parse_module(s: OsString) -> anyhow::Result<PathBuf> {
}
}

fn parse_env_var(s: &str) -> Result<(String, String)> {
let parts: Vec<_> = s.splitn(2, '=').collect();
if parts.len() != 2 {
bail!("must be of the form `key=value`");
}
Ok((parts[0].to_owned(), parts[1].to_owned()))
fn parse_env_var(s: &str) -> Result<(String, Option<String>)> {
let mut parts = s.splitn(2, '=');
Ok((
parts.next().unwrap().to_string(),
parts.next().map(|s| s.to_string()),
))
}

fn parse_map_dirs(s: &str) -> Result<(String, String)> {
Expand Down Expand Up @@ -156,9 +156,15 @@ pub struct RunCommand {
#[clap(long = "dir", number_of_values = 1, value_name = "DIRECTORY")]
dirs: Vec<String>,

/// Pass an environment variable to the program
#[clap(long = "env", number_of_values = 1, value_name = "NAME=VAL", value_parser = parse_env_var)]
vars: Vec<(String, String)>,
/// Pass an environment variable to the program.
///
/// The `--env FOO=BAR` form will set the environment variable named `FOO`
/// to the value `BAR` for the guest program using WASI. The `--env FOO`
/// form will set the environment variable named `FOO` to the same value it
/// has in the calling process for the guest, or in other words it will
/// cause the environment variable `FOO` to be inherited.
#[clap(long = "env", number_of_values = 1, value_name = "NAME[=VAL]", value_parser = parse_env_var)]
vars: Vec<(String, Option<String>)>,

/// The name of the function to run
#[clap(long, value_name = "FUNCTION")]
Expand Down Expand Up @@ -689,7 +695,7 @@ fn populate_with_wasi(
module: Module,
preopen_dirs: Vec<(String, Dir)>,
argv: &[String],
vars: &[(String, String)],
vars: &[(String, Option<String>)],
wasi_modules: &WasiModules,
listenfd: bool,
mut tcplisten: Vec<TcpListener>,
Expand All @@ -698,7 +704,16 @@ fn populate_with_wasi(
wasmtime_wasi::add_to_linker(linker, |host| host.wasi.as_mut().unwrap())?;

let mut builder = WasiCtxBuilder::new();
builder = builder.inherit_stdio().args(argv)?.envs(vars)?;
builder = builder.inherit_stdio().args(argv)?;

for (key, value) in vars {
let value = match value {
Some(value) => value.clone(),
None => std::env::var(key)
.map_err(|_| anyhow!("environment varialbe `{key}` not found"))?,
};
builder = builder.env(key, &value)?;
}

let mut num_fd: usize = 3;

Expand Down
44 changes: 44 additions & 0 deletions tests/all/cli_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,50 @@ fn hello_wasi_snapshot0_from_stdin() -> Result<()> {
Ok(())
}

#[test]
fn specify_env() -> Result<()> {
// By default no env is inherited
let output = get_wasmtime_command()?
.args(&["run", "tests/all/cli_tests/print_env.wat"])
.env("THIS_WILL_NOT", "show up in the output")
.output()?;
assert!(output.status.success());
assert_eq!(String::from_utf8_lossy(&output.stdout), "");

// Specify a single env var
let output = get_wasmtime_command()?
.args(&[
"run",
"--env",
"FOO=bar",
"tests/all/cli_tests/print_env.wat",
])
.output()?;
assert!(output.status.success());
assert_eq!(String::from_utf8_lossy(&output.stdout), "FOO=bar\n");

// Inherit a single env var
let output = get_wasmtime_command()?
.args(&["run", "--env", "FOO", "tests/all/cli_tests/print_env.wat"])
.env("FOO", "bar")
.output()?;
assert!(output.status.success());
assert_eq!(String::from_utf8_lossy(&output.stdout), "FOO=bar\n");

// Inherit a nonexistent env var
let output = get_wasmtime_command()?
.args(&[
"run",
"--env",
"SURELY_THIS_ENV_VAR_DOES_NOT_EXIST_ANYWHERE_RIGHT",
"tests/all/cli_tests/print_env.wat",
])
.output()?;
assert!(!output.status.success());

Ok(())
}

#[cfg(unix)]
#[test]
fn run_cwasm_from_stdin() -> Result<()> {
Expand Down
76 changes: 76 additions & 0 deletions tests/all/cli_tests/print_env.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
(module
(import "wasi_snapshot_preview1" "fd_write"
(func $fd_write (param i32 i32 i32 i32) (result i32)))

(import "wasi_snapshot_preview1" "environ_get"
(func $environ_get (param i32 i32) (result i32)))

(memory (export "memory") 1)

(func (export "_start")
(local $envptrs i32)
(local $envmem i32)
(local $env i32)

(local.set $envptrs (i32.mul (memory.grow (i32.const 1)) (i32.const 65536)))
(local.set $envmem (i32.mul (memory.grow (i32.const 1)) (i32.const 65536)))

(if (i32.ne
(call $environ_get (local.get $envptrs) (local.get $envmem))
(i32.const 0))
(unreachable))

(loop
(local.set $env (i32.load (local.get $envptrs)))
(local.set $envptrs (i32.add (local.get $envptrs) (i32.const 4)))
(if (i32.eq (local.get $env) (i32.const 0)) (return))

(call $write_all (local.get $env) (call $strlen (local.get $env)))
(call $write_all (i32.const 10) (i32.const 1))
br 0
)
)

(func $write_all (param $ptr i32) (param $len i32)
(local $rc i32)
(local $iov i32)
(local $written i32)

(local.set $written (i32.const 80))
(local.set $iov (i32.const 100))

(loop
(local.get $len)
if
(i32.store (local.get $iov) (local.get $ptr))
(i32.store offset=4 (local.get $iov) (local.get $len))
(local.set $rc
(call $fd_write
(i32.const 1)
(local.get $iov)
(i32.const 1)
(local.get $written)))
(if (i32.ne (local.get $rc) (i32.const 0)) (unreachable))

(local.set $len (i32.sub (local.get $len) (i32.load (local.get $written))))
(local.set $ptr (i32.add (local.get $ptr) (i32.load (local.get $written))))
end
)
)

(func $strlen (param $ptr i32) (result i32)
(local $len i32)
(loop
(i32.load8_u (i32.add (local.get $ptr) (local.get $len)))
if
(local.set $len (i32.add (local.get $len) (i32.const 1)))
br 1
end
)
local.get $len
)

(data (i32.const 10) "\n")
)


0 comments on commit 2f6cec9

Please sign in to comment.