Skip to content

Commit

Permalink
rust: add supervisor function to forward signals to child processes
Browse files Browse the repository at this point in the history
  • Loading branch information
rizsotto committed Jan 18, 2025
1 parent 8e9526c commit 7840e51
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 13 deletions.
9 changes: 9 additions & 0 deletions rust/bear/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ nom.workspace = true
regex.workspace = true
rand.workspace = true
tempfile.workspace = true
nix = { version = "0.29", optional = true, features = ["signal", "process"] }
winapi = { version = "0.3", optional = true, features = ["processthreadsapi", "winnt", "handleapi"] }
signal-hook = "0.3.17"

[target.'cfg(unix)'.dependencies]
nix = { version = "0.29", features = ["signal", "process"] }

[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["processthreadsapi", "winnt", "handleapi"] }

[profile.release]
strip = true
Expand Down
11 changes: 5 additions & 6 deletions rust/bear/src/bin/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
extern crate core;

use anyhow::{Context, Result};
use bear::intercept::supervise::supervise;
use bear::intercept::tcp::ReporterOnTcp;
use bear::intercept::Reporter;
use bear::intercept::KEY_DESTINATION;
Expand Down Expand Up @@ -46,13 +47,11 @@ fn main() -> Result<()> {
}

// Execute the real executable with the same arguments
// TODO: handle signals and forward them to the child process.
let status = std::process::Command::new(real_executable)
.args(std::env::args().skip(1))
.status()?;
log::info!("Execution finished with status: {:?}", status);
let mut command = std::process::Command::new(real_executable);
let exit_status = supervise(command.args(std::env::args().skip(1)))?;
log::info!("Execution finished with status: {:?}", exit_status);
// Return the child process status code
std::process::exit(status.code().unwrap_or(1));
std::process::exit(exit_status.code().unwrap_or(1));
}

/// Get the file name of the executable from the arguments.
Expand Down
16 changes: 9 additions & 7 deletions rust/bear/src/intercept/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//! The module provides abstractions for the reporter and the collector. And it also defines
//! the data structures that are used to represent the events.
use crate::intercept::supervise::supervise;
use crate::{args, config};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
Expand All @@ -19,6 +20,7 @@ use std::sync::Arc;
use std::{env, fmt, thread};

pub mod persistence;
pub mod supervise;
pub mod tcp;

/// Declare the environment variables used by the intercept mode.
Expand Down Expand Up @@ -267,17 +269,17 @@ impl InterceptEnvironment {
// TODO: record the execution of the build command

let environment = self.environment();
let mut child = Command::new(input.arguments[0].clone())
.args(input.arguments[1..].iter())
.envs(environment)
.spawn()?;
let process = input.arguments[0].clone();
let arguments = input.arguments[1..].to_vec();

// TODO: forward signals to the child process
let result = child.wait()?;
let mut child = Command::new(process);

let exit_status = supervise(child.args(arguments).envs(environment))?;
log::info!("Execution finished with status: {:?}", exit_status);

// The exit code is not always available. When the process is killed by a signal,
// the exit code is not available. In this case, we return the `FAILURE` exit code.
let exit_code = result
let exit_code = exit_status
.code()
.map(|code| ExitCode::from(code as u8))
.unwrap_or(ExitCode::FAILURE);
Expand Down
118 changes: 118 additions & 0 deletions rust/bear/src/intercept/supervise.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// SPDX-License-Identifier: GPL-3.0-or-later

use anyhow::Result;
use nix::libc::c_int;
#[cfg(unix)]
use nix::sys::signal::{kill, Signal};
#[cfg(unix)]
use nix::unistd::Pid;
use std::process::{Command, ExitStatus};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
#[cfg(windows)]
use winapi::shared::minwindef::FALSE;
#[cfg(windows)]
use winapi::um::processthreadsapi::{OpenProcess, TerminateProcess};
#[cfg(windows)]
use winapi::um::winnt::{PROCESS_TERMINATE, SYNCHRONIZE};

/// This method supervises the execution of a command.
///
/// It starts the command and waits for its completion. It also forwards
/// signals to the child process. The method returns the exit status of the
/// child process.
pub fn supervise(command: &mut Command) -> Result<ExitStatus> {
let mut child = command.spawn()?;

let child_pid = child.id();
let running = Arc::new(AtomicBool::new(true));
let running_in_thread = running.clone();

let mut signals = signal_hook::iterator::Signals::new([
signal_hook::consts::SIGINT,
signal_hook::consts::SIGTERM,
])?;

#[cfg(unix)]
{
signals.add_signal(signal_hook::consts::SIGHUP)?;
signals.add_signal(signal_hook::consts::SIGQUIT)?;
signals.add_signal(signal_hook::consts::SIGALRM)?;
signals.add_signal(signal_hook::consts::SIGUSR1)?;
signals.add_signal(signal_hook::consts::SIGUSR2)?;
signals.add_signal(signal_hook::consts::SIGCONT)?;
signals.add_signal(signal_hook::consts::SIGSTOP)?;
}

let handler = thread::spawn(move || {
for signal in signals.forever() {
log::debug!("Received signal: {:?}", signal);
if forward_signal(signal, child_pid) {
// If the signal caused termination, we should stop the process.
running_in_thread.store(false, Ordering::SeqCst);
break;
}
}
});

while running.load(Ordering::SeqCst) {
thread::sleep(Duration::from_millis(100));
}
handler.join().unwrap();

let exit_status = child.wait()?;

Ok(exit_status)
}

#[cfg(windows)]
fn forward_signal(_: c_int, child_pid: u32) -> bool {
let process_handle = unsafe { OpenProcess(PROCESS_TERMINATE | SYNCHRONIZE, FALSE, child_pid) };
if process_handle.is_null() {
let err = unsafe { winapi::um::errhandling::GetLastError() };
log::error!("Failed to open process: {}", err);
// If the process handle is not valid, presume the process is not running anymore.
return true;
}

let terminated = unsafe { TerminateProcess(process_handle, 1) };
if terminated == FALSE {
let err = unsafe { winapi::um::errhandling::GetLastError() };
log::error!("Failed to terminate process: {}", err);
}

// Ensure proper handle closure
unsafe { winapi::um::handleapi::CloseHandle(process_handle) };

// Return true if the process was terminated.
terminated == TRUE
}

#[cfg(unix)]
fn forward_signal(signal: c_int, child_pid: u32) -> bool {
// Forward the signal to the child process
if let Err(e) = kill(
Pid::from_raw(child_pid as i32),
Signal::try_from(signal).ok(),
) {
log::error!("Error forwarding signal: {}", e);
}

// Return true if the process was terminated.
match kill(Pid::from_raw(child_pid as i32), None) {
Ok(_) => {
log::debug!("Checking if the process is still running... yes");
false
}
Err(nix::Error::ESRCH) => {
log::debug!("Checking if the process is still running... no");
true
}
Err(_) => {
log::debug!("Checking if the process is still running... presume dead");
true
}
}
}

0 comments on commit 7840e51

Please sign in to comment.