From c5bf4f0f7063cf31bd9cbfa24ec1c5c21674421e Mon Sep 17 00:00:00 2001 From: Iulian Barbu Date: Mon, 24 Apr 2023 23:59:27 +0300 Subject: [PATCH] cargo-shuttle: added sigint and sigterm handlers --- cargo-shuttle/Cargo.toml | 2 +- cargo-shuttle/src/lib.rs | 541 ++++++++++++++++++++++++++------------- runtime/src/alpha/mod.rs | 17 +- 3 files changed, 373 insertions(+), 187 deletions(-) diff --git a/cargo-shuttle/Cargo.toml b/cargo-shuttle/Cargo.toml index dd60f3e6d3..6408eed723 100644 --- a/cargo-shuttle/Cargo.toml +++ b/cargo-shuttle/Cargo.toml @@ -41,7 +41,7 @@ serde_json = { workspace = true } sqlx = { workspace = true, features = ["runtime-tokio-native-tls", "postgres"] } strum = { workspace = true } tar = { workspace = true } -tokio = { workspace = true, features = ["macros"] } +tokio = { workspace = true, features = ["macros", "signal"] } tokio-tungstenite = { version = "0.18.0", features = ["native-tls"] } toml = { workspace = true } toml_edit = { workspace = true } diff --git a/cargo-shuttle/src/lib.rs b/cargo-shuttle/src/lib.rs index 81ecaff814..3185b3ad75 100644 --- a/cargo-shuttle/src/lib.rs +++ b/cargo-shuttle/src/lib.rs @@ -5,14 +5,18 @@ mod init; mod provisioner_server; use indicatif::ProgressBar; +use shuttle_common::claims::{ClaimService, InjectPropagation}; use shuttle_common::models::deployment::get_deployments_table; use shuttle_common::models::project::{State, IDLE_MINUTES}; use shuttle_common::models::resource::get_resources_table; use shuttle_common::project::ProjectName; use shuttle_common::resource; -use shuttle_proto::runtime::{self, LoadRequest, StartRequest, SubscribeLogsRequest}; +use shuttle_proto::runtime::runtime_client::RuntimeClient; +use shuttle_proto::runtime::{self, LoadRequest, StartRequest, StopRequest, SubscribeLogsRequest}; -use tokio::task::JoinSet; +use tokio::process::Child; +use tokio::task::JoinHandle; +use tonic::transport::Channel; use std::collections::HashMap; use std::ffi::OsString; @@ -449,6 +453,199 @@ impl Shuttle { Ok(()) } + async fn spin_local_runtime( + &self, + run_args: &RunArgs, + service: &BuiltService, + provisioner_server: &JoinHandle>, + i: u16, + provisioner_port: u16, + ) -> Result<( + Child, + RuntimeClient>>, + )> { + let BuiltService { + executable_path, + is_wasm, + working_directory, + .. + } = service.clone(); + + trace!("loading secrets"); + let secrets_path = if working_directory.join("Secrets.dev.toml").exists() { + working_directory.join("Secrets.dev.toml") + } else { + working_directory.join("Secrets.toml") + }; + + let secrets: HashMap = if let Ok(secrets_str) = read_to_string(secrets_path) + { + let secrets: HashMap = + secrets_str.parse::()?.try_into()?; + + trace!(keys = ?secrets.keys(), "available secrets"); + + secrets + } else { + trace!("no Secrets.toml was found"); + Default::default() + }; + + let runtime_path = || { + if is_wasm { + let runtime_path = home::cargo_home() + .expect("failed to find cargo home dir") + .join("bin/shuttle-next"); + + println!("Installing shuttle-next runtime. This can take a while..."); + + if cfg!(debug_assertions) { + // Canonicalized path to shuttle-runtime for dev to work on windows + + let path = std::fs::canonicalize(format!("{MANIFEST_DIR}/../runtime")) + .expect("path to shuttle-runtime does not exist or is invalid"); + + std::process::Command::new("cargo") + .arg("install") + .arg("shuttle-runtime") + .arg("--path") + .arg(path) + .arg("--bin") + .arg("shuttle-next") + .arg("--features") + .arg("next") + .output() + .expect("failed to install the shuttle runtime"); + } else { + // If the version of cargo-shuttle is different from shuttle-runtime, + // or it isn't installed, try to install shuttle-runtime from crates.io. + if let Err(err) = check_version(&runtime_path) { + warn!("{}", err); + + trace!("installing shuttle-runtime"); + std::process::Command::new("cargo") + .arg("install") + .arg("shuttle-runtime") + .arg("--bin") + .arg("shuttle-next") + .arg("--features") + .arg("next") + .output() + .expect("failed to install the shuttle runtime"); + }; + }; + + runtime_path + } else { + trace!(path = ?executable_path, "using alpha runtime"); + executable_path.clone() + } + }; + + let (mut runtime, mut runtime_client) = runtime::start( + is_wasm, + runtime::StorageManagerType::WorkingDir(working_directory.to_path_buf()), + &format!("http://localhost:{provisioner_port}"), + None, + run_args.port - (1 + i), + runtime_path, + ) + .await + .map_err(|err| { + provisioner_server.abort(); + err + })?; + + let service_name = service.service_name()?; + let load_request = tonic::Request::new(LoadRequest { + path: executable_path + .into_os_string() + .into_string() + .expect("to convert path to string"), + service_name: service_name.to_string(), + resources: Default::default(), + secrets, + }); + + trace!("loading service"); + let response = runtime_client + .load(load_request) + .or_else(|err| async { + provisioner_server.abort(); + runtime.kill().await?; + Err(err) + }) + .await? + .into_inner(); + + if !response.success { + error!(error = response.message, "failed to load your service"); + provisioner_server.abort(); + runtime.kill().await?; + exit(1); + } + + let resources = response + .resources + .into_iter() + .map(resource::Response::from_bytes) + .collect(); + + println!("{}", get_resources_table(&resources, service_name.as_str())); + + let mut stream = runtime_client + .subscribe_logs(tonic::Request::new(SubscribeLogsRequest {})) + .or_else(|err| async { + provisioner_server.abort(); + runtime.kill().await?; + Err(err) + }) + .await? + .into_inner(); + + tokio::spawn(async move { + while let Ok(Some(log)) = stream.message().await { + let log: shuttle_common::LogItem = log.try_into().expect("to convert log"); + println!("{log}"); + } + }); + + let addr = SocketAddr::new( + if run_args.external { + Ipv4Addr::new(0, 0, 0, 0) + } else { + Ipv4Addr::LOCALHOST + } + .into(), + run_args.port + i, + ); + + println!( + " {} {} on http://{}\n", + "Starting".bold().green(), + service_name, + addr + ); + + let start_request = StartRequest { + ip: addr.to_string(), + }; + + trace!(?start_request, "starting service"); + let response = runtime_client + .start(tonic::Request::new(start_request)) + .or_else(|err| async { + provisioner_server.abort(); + runtime.kill().await?; + Err(err) + }) + .await? + .into_inner(); + + trace!(response = ?response, "client response: "); + Ok((runtime, runtime_client)) + } + async fn local_run(&self, run_args: RunArgs) -> Result<()> { trace!("starting a local run for a service: {run_args:?}"); @@ -487,201 +684,175 @@ impl Shuttle { // Compile all the alpha or shuttle-next services in the workspace. let services = build_workspace(working_directory, run_args.release, tx).await?; - let mut runtime_handles = JoinSet::new(); + // TODO: figure out how best to handle the runtime handles, and what to do if + // one completes. + let (mut sigterm_notif, mut sigint_notif) = if cfg!(target_family = "unix") { + ( + Some( + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("Can not get the SIGTERM signal receptor"), + ), + Some( + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt()) + .expect("Can not get the SIGINT signal receptor"), + ), + ) + } else { + (None, None) + }; // Start all the services. + let mut runtimes: Vec<( + Child, + RuntimeClient>>, + )> = Vec::new(); + let mut signal_received = false; for (i, service) in services.iter().enumerate() { - let BuiltService { - executable_path, - is_wasm, - working_directory, - .. - } = service.clone(); - - trace!("loading secrets"); - let secrets_path = if working_directory.join("Secrets.dev.toml").exists() { - working_directory.join("Secrets.dev.toml") - } else { - working_directory.join("Secrets.toml") - }; - - let secrets: HashMap = - if let Ok(secrets_str) = read_to_string(secrets_path) { - let secrets: HashMap = - secrets_str.parse::()?.try_into()?; - - trace!(keys = ?secrets.keys(), "available secrets"); - - secrets - } else { - trace!("no Secrets.toml was found"); - Default::default() + // We must cover the case of starting multiple workspace services and receiving a signal in parallel. + // This must stop all the existing runtimes and stop creating new ones. + if cfg!(target_family = "unix") { + let sigterm = sigterm_notif.as_mut().expect("SIGTERM reactor failure"); + let sigint = sigint_notif.as_mut().expect("SIGINT reactor failure"); + signal_received = tokio::select! { + runtime_info = self.spin_local_runtime(&run_args, service, &provisioner_server, i as u16, provisioner_port) => { + runtimes.push(runtime_info.unwrap()); + false + }, + _ = sigterm.recv() => { + println!( + "cargo-shuttle received SIGTERM. Killing all the runtimes..." + ); + true + }, + _ = sigint.recv() => { + println!( + "cargo-shuttle received SIGINT. Killing all the runtimes..." + ); + true + } }; - let runtime_path = || { - if is_wasm { - let runtime_path = home::cargo_home() - .expect("failed to find cargo home dir") - .join("bin/shuttle-next"); - - println!("Installing shuttle-next runtime. This can take a while..."); - - if cfg!(debug_assertions) { - // Canonicalized path to shuttle-runtime for dev to work on windows - - let path = std::fs::canonicalize(format!("{MANIFEST_DIR}/../runtime")) - .expect("path to shuttle-runtime does not exist or is invalid"); - - std::process::Command::new("cargo") - .arg("install") - .arg("shuttle-runtime") - .arg("--path") - .arg(path) - .arg("--bin") - .arg("shuttle-next") - .arg("--features") - .arg("next") - .output() - .expect("failed to install the shuttle runtime"); - } else { - // If the version of cargo-shuttle is different from shuttle-runtime, - // or it isn't installed, try to install shuttle-runtime from crates.io. - if let Err(err) = check_version(&runtime_path) { - warn!(error = ?err, "failed to check installed runtime version"); - - trace!("installing shuttle-runtime"); - std::process::Command::new("cargo") - .arg("install") - .arg("shuttle-runtime") - .arg("--bin") - .arg("shuttle-next") - .arg("--features") - .arg("next") - .output() - .expect("failed to install the shuttle runtime"); - }; - }; - - runtime_path - } else { - trace!(path = ?executable_path, "using alpha runtime"); - executable_path.clone() + if signal_received { + break; } - }; + } else { + runtimes.push( + self.spin_local_runtime( + &run_args, + service, + &provisioner_server, + i as u16, + provisioner_port, + ) + .await?, + ); + } + } - let (mut runtime, mut runtime_client) = runtime::start( - is_wasm, - runtime::StorageManagerType::WorkingDir(working_directory.to_path_buf()), - &format!("http://localhost:{provisioner_port}"), - None, - run_args.port - (1 + i) as u16, - runtime_path, - ) - .await - .map_err(|err| { - provisioner_server.abort(); - err - })?; - let service_name = service.service_name()?; - let load_request = tonic::Request::new(LoadRequest { - path: executable_path - .into_os_string() - .into_string() - .expect("to convert path to string"), - service_name: service_name.to_string(), - resources: Default::default(), - secrets, - }); - - trace!("loading service"); - let response = runtime_client - .load(load_request) - .or_else(|err| async { - provisioner_server.abort(); - runtime.kill().await?; - - Err(err) - }) - .await? - .into_inner(); - - if !response.success { - error!(error = response.message, "failed to load your service"); - provisioner_server.abort(); - runtime.kill().await?; - runtime_handles.abort_all(); - exit(1); + // If prior signal received is set to true we must stop all the existing runtimes and + // exit the `local_run`. + if signal_received { + for mut rt in runtimes { + let stop_request = StopRequest {}; + trace!( + ?stop_request, + "stopping service because it received a signal" + ); + let _response = + rt.1.stop(tonic::Request::new(stop_request)) + .or_else(|err| async { + provisioner_server.abort(); + rt.0.kill().await?; + Err(err) + }) + .await? + .into_inner(); } - let resources = response - .resources - .into_iter() - .map(resource::Response::from_bytes) - .collect(); - - println!("{}", get_resources_table(&resources, service_name.as_str())); - - let mut stream = runtime_client - .subscribe_logs(tonic::Request::new(SubscribeLogsRequest {})) - .or_else(|err| async { - provisioner_server.abort(); - runtime.kill().await?; - - Err(err) - }) - .await? - .into_inner(); - - tokio::spawn(async move { - while let Ok(Some(log)) = stream.message().await { - let log: shuttle_common::LogItem = log.try_into().expect("to convert log"); - println!("{log}"); - } - }); + return Ok(()); + } - let addr = SocketAddr::new( - if run_args.external { - Ipv4Addr::new(0, 0, 0, 0) - } else { - Ipv4Addr::LOCALHOST + // If no signal was received during runtimes initialization, then we must handle each runtime until + // comletion and handle the signals during this time. + if cfg!(target_family = "unix") { + let sigterm = sigterm_notif.as_mut().expect("SIGTERM reactor failure"); + let sigint = sigint_notif.as_mut().expect("SIGINT reactor failure"); + for mut rt in runtimes { + // If we received a signal while waiting for any runtime we must stop the rest and exit + // the waiting loop. + if signal_received { + let stop_request = StopRequest {}; + trace!( + ?stop_request, + "stopping service because it received a signal" + ); + let _response = + rt.1.stop(tonic::Request::new(stop_request)) + .or_else(|err| async { + provisioner_server.abort(); + rt.0.kill().await?; + Err(err) + }) + .await? + .into_inner(); + continue; } - .into(), - run_args.port + i as u16, - ); - - println!( - " {} {} on http://{}\n", - "Starting".bold().green(), - service_name, - addr - ); - let start_request = StartRequest { - ip: addr.to_string(), - }; - - trace!(?start_request, "starting service"); - let response = runtime_client - .start(tonic::Request::new(start_request)) - .or_else(|err| async { - provisioner_server.abort(); - runtime.kill().await?; - Err(err) - }) - .await? - .into_inner(); - - trace!(response = ?response, "client response: "); - runtime_handles.spawn(async move { runtime.wait().await }); - } - - // TODO: figure out how best to handle the runtime handles, and what to do if - // one completes. - while let Some(res) = runtime_handles.join_next().await { - println!( - "a service future completed with exit status: {:?}", - res.unwrap().unwrap().code() - ); + // Receiving a signal will stop the current runtime we're waiting for. + signal_received = tokio::select! { + res = rt.0.wait() => { + println!( + "a service future completed with exit status: {:?}", + res.unwrap().code() + ); + false + }, + _ = sigterm.recv() => { + println!( + "cargo-shuttle received SIGTERM. Killing all the runtimes..." + ); + + let stop_request = StopRequest {}; + trace!(?stop_request, "stopping service because of SIGTERM"); + let _response = rt.1 + .stop(tonic::Request::new(stop_request)) + .or_else(|err| async { + provisioner_server.abort(); + rt.0.kill().await?; + Err(err) + }) + .await? + .into_inner(); + true + }, + _ = sigint.recv() => { + println!( + "cargo-shuttle received SIGINT. Killing all the runtimes..." + ); + let stop_request = StopRequest {}; + trace!(?stop_request, "stopping service because of SIGINT"); + let _response = rt.1 + .stop(tonic::Request::new(stop_request)) + .or_else(|err| async { + provisioner_server.abort(); + rt.0.kill().await?; + Err(err) + }) + .await? + .into_inner(); + true + } + }; + } + } else { + // In case we're not on an unix family OS, we're simply waiting for the runtimes + // to end. + for (mut rt, _) in runtimes { + println!( + "a service future completed with exit status: {:?}", + rt.wait().await?.code() + ); + } } Ok(()) diff --git a/runtime/src/alpha/mod.rs b/runtime/src/alpha/mod.rs index 2c90124dea..742d42702e 100644 --- a/runtime/src/alpha/mod.rs +++ b/runtime/src/alpha/mod.rs @@ -315,6 +315,21 @@ where tokio::spawn(async move { let mut background = handle.spawn(service.bind(service_address)); + let (_sigterm_notif, _sigint_notif) = if cfg!(target_family = "unix") { + ( + Some( + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("Can not get the SIGTERM signal receptor"), + ), + Some( + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt()) + .expect("Can not get the SIGINT signal receptor"), + ), + ) + } else { + (None, None) + }; + tokio::select! { res = &mut background => { match res { @@ -354,7 +369,7 @@ where info!("will now abort the service"); background.abort(); background.await.unwrap().expect("to stop service"); - } + }, } });