-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
141 additions
and
207 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +1,5 @@ | ||
use color_eyre::eyre; | ||
use std::path::PathBuf; | ||
|
||
// async fn open_ssh_tunnel( | ||
// username: impl AsRef<str>, | ||
// password: impl AsRef<str>, | ||
// local_port: impl Into<Option<u16>>, | ||
// ) -> eyre::Result< | ||
// ( | ||
// std::net::SocketAddr, | ||
// tokio::sync::oneshot::Receiver<ssh_jumper::model::SshForwarderEnd>, | ||
// ), | ||
// ssh_jumper::model::Error, | ||
// > { | ||
// use ssh_jumper::{ | ||
// model::{HostAddress, HostSocketParams, JumpHostAuthParams, SshTunnelParams}, | ||
// SshJumper, | ||
// }; | ||
// use std::borrow::Cow; | ||
// | ||
// // Similar to running: | ||
// // ssh -i ~/.ssh/id_rsa -L 1234:target_host:8080 [email protected] | ||
// let jump_host = HostAddress::HostName(Cow::Borrowed("bastion.com")); | ||
// let jump_host_auth_params = JumpHostAuthParams::password( | ||
// username.as_ref().into(), // Cow::Borrowed("my_user"), | ||
// password.as_ref().into(), // Cow::Borrowed("my_user"), | ||
// // Cow::Borrowed(Path::new("~/.ssh/id_rsa")), | ||
// ); | ||
// let target_socket = HostSocketParams { | ||
// address: HostAddress::HostName(Cow::Borrowed("target_host")), | ||
// port: 8080, | ||
// }; | ||
// let mut ssh_params = SshTunnelParams::new(jump_host, jump_host_auth_params, target_socket); | ||
// if let Some(local_port) = local_port.into() { | ||
// // os will allocate a port if this is left out | ||
// ssh_params = ssh_params.with_local_port(local_port); | ||
// } | ||
// | ||
// let tunnel = SshJumper::open_tunnel(&ssh_params).await?; | ||
// Ok(tunnel) | ||
// } | ||
// | ||
// /// Connect to remote host. | ||
// /// | ||
// /// # Errors | ||
// /// If connection fails. | ||
// pub async fn connect() -> eyre::Result<()> { | ||
// let ssh_username = std::env::var("ssh_user_name")?; | ||
// let ssh_password = std::env::var("ssh_password")?; | ||
// | ||
// let (_local_socket_addr, _ssh_forwarder_end_rx) = | ||
// open_ssh_tunnel(ssh_username, ssh_password, None).await?; | ||
// Ok(()) | ||
// } | ||
use std::path::Path; | ||
|
||
type AsyncSession = async_ssh2_lite::AsyncSession<async_ssh2_lite::TokioTcpStream>; | ||
|
||
|
@@ -89,6 +37,53 @@ impl SSHClient { | |
pub trait Remote { | ||
fn username(&self) -> &str; | ||
|
||
/// Wait for file to become available at remote path. | ||
async fn wait_for_file( | ||
&self, | ||
remote_path: &Path, | ||
interval: std::time::Duration, | ||
allow_empty: bool, | ||
attempts: Option<usize>, | ||
) -> eyre::Result<()> { | ||
let attempts = attempts.unwrap_or(1); | ||
let mut interval = tokio::time::interval(interval); | ||
|
||
for attempt in 1..=attempts { | ||
if attempt > 3 { | ||
log::warn!( | ||
"reading from {} (attempt {}/{})", | ||
remote_path.display(), | ||
attempt, | ||
attempts | ||
); | ||
} | ||
let cmd = format!(r#"stat -c "%s" {}"#, remote_path.display()); | ||
match self.run_command(&cmd).await { | ||
Err(err) => { | ||
log::warn!("failed to execute command {:?}: {}", &cmd, err); | ||
} | ||
Ok((exit_status, _, stderr)) if exit_status != 0 => { | ||
log::warn!( | ||
"command {:?} failed with exit code {}: {}", | ||
cmd, | ||
exit_status, | ||
stderr, | ||
); | ||
} | ||
Ok((_, stdout, _)) => { | ||
if allow_empty || stdout.parse::<usize>().unwrap_or(0) > 0 { | ||
return Ok(()); | ||
} | ||
} | ||
}; | ||
interval.tick().await; | ||
} | ||
Err(eyre::eyre!( | ||
"{} does not exist or is empty", | ||
remote_path.display() | ||
)) | ||
} | ||
|
||
async fn run_command( | ||
&self, | ||
command: impl AsRef<str> + Send + Sync, | ||
|
@@ -124,42 +119,14 @@ impl Remote for SSHClient { | |
stderr.read_to_string(&mut stderr_buffer), | ||
); | ||
let exit_status = channel.exit_status()?; | ||
Ok((exit_status, stdout_buffer, stderr_buffer)) | ||
Ok(( | ||
exit_status, | ||
stdout_buffer.trim().to_string(), | ||
stderr_buffer.trim().to_string(), | ||
)) | ||
} | ||
} | ||
|
||
// #[derive()] | ||
// pub struct DAS<R> | ||
// where | ||
// R: Remote + slurm::Client, | ||
// { | ||
// pub remote: R, | ||
// pub cuda_module: String, | ||
// pub remote_scratch_dir: PathBuf, | ||
// } | ||
// | ||
// impl<R> DAS<R> | ||
// where | ||
// R: Remote + slurm::Client, | ||
// { | ||
// pub fn new(remote: R) -> Self { | ||
// Self { | ||
// remote, | ||
// cuda_module: "".to_string(), | ||
// remote_scratch_dir: PathBuf::from("/var/scratch"), | ||
// } | ||
// } | ||
// // pub async fn connect<A>( | ||
// // address: A, | ||
// // username: impl AsRef<str>, | ||
// // password: impl AsRef<str>, | ||
// // ) -> eyre::Result<Self> | ||
// // where | ||
// // A: Into<std::net::SocketAddr>, | ||
// // { | ||
// // } | ||
// } | ||
|
||
pub mod slurm { | ||
use color_eyre::eyre; | ||
use itertools::Itertools; | ||
|
@@ -208,15 +175,6 @@ pub mod slurm { | |
confidence: Option<usize>, | ||
) -> eyre::Result<()>; | ||
|
||
/// Wait for file to become available at remote path. | ||
async fn wait_for_file( | ||
&self, | ||
remote_path: &Path, | ||
interval: std::time::Duration, | ||
allow_empty: bool, | ||
attempts: Option<usize>, | ||
) -> eyre::Result<()>; | ||
|
||
/// Submit job | ||
async fn submit_job(&self, job_path: &Path) -> eyre::Result<usize>; | ||
} | ||
|
@@ -240,7 +198,9 @@ pub mod slurm { | |
} | ||
let cmd = cmd.join(" "); | ||
let (exit_status, stdout, stderr) = self.run_command(&cmd).await?; | ||
log::error!("{}", stderr); | ||
if !stderr.is_empty() { | ||
log::error!("{}", stderr); | ||
} | ||
if exit_status != 0 { | ||
eyre::bail!("{} failed with exit code {}", cmd, exit_status); | ||
} | ||
|
@@ -268,7 +228,9 @@ pub mod slurm { | |
} | ||
let cmd = cmd.join(" "); | ||
let (exit_status, stdout, stderr) = self.run_command(&cmd).await?; | ||
log::error!("{}", stderr); | ||
if !stderr.is_empty() { | ||
log::error!("{}", stderr); | ||
} | ||
if exit_status != 0 { | ||
eyre::bail!("{} failed with exit code {}", cmd, exit_status); | ||
} | ||
|
@@ -289,7 +251,9 @@ pub mod slurm { | |
let cmd = vec!["squeue", r#"--format="%i""#, "--name", name.as_ref()]; | ||
let cmd = cmd.join(" "); | ||
let (exit_status, stdout, stderr) = self.run_command(&cmd).await?; | ||
log::error!("{}", stderr); | ||
if !stderr.is_empty() { | ||
log::error!("{}", stderr); | ||
} | ||
if exit_status != 0 { | ||
eyre::bail!("{} failed with exit code {}", cmd, exit_status); | ||
} | ||
|
@@ -310,8 +274,12 @@ pub mod slurm { | |
} | ||
let cmd = cmd.join(" "); | ||
let (exit_status, stdout, stderr) = self.run_command(&cmd).await?; | ||
log::debug!("{}", stdout); | ||
log::error!("{}", stderr); | ||
if !stdout.is_empty() { | ||
log::debug!("{}", stdout); | ||
} | ||
if !stderr.is_empty() { | ||
log::error!("{}", stderr); | ||
} | ||
if exit_status != 0 { | ||
eyre::bail!("{} failed with exit code {}", cmd, exit_status); | ||
} | ||
|
@@ -352,58 +320,16 @@ pub mod slurm { | |
Ok(()) | ||
} | ||
|
||
async fn wait_for_file( | ||
&self, | ||
remote_path: &Path, | ||
interval: std::time::Duration, | ||
allow_empty: bool, | ||
attempts: Option<usize>, | ||
) -> eyre::Result<()> { | ||
let attempts = attempts.unwrap_or(1); | ||
let mut interval = tokio::time::interval(interval); | ||
|
||
for attempt in 1..=attempts { | ||
if attempt > 3 { | ||
log::warn!( | ||
"reading from {} (attempt {}/{})", | ||
remote_path.display(), | ||
attempt, | ||
attempts | ||
); | ||
} | ||
let cmd = format!(r#"stat -c "%s" {}"#, remote_path.display()); | ||
match self.run_command(&cmd).await { | ||
Err(err) => { | ||
log::warn!("failed to execute command {:?}: {}", &cmd, err); | ||
} | ||
Ok((exit_status, _, stderr)) if exit_status != 0 => { | ||
log::warn!( | ||
"command {:?} failed with exit code {}: {}", | ||
cmd, | ||
exit_status, | ||
stderr, | ||
); | ||
} | ||
Ok((_, stdout, _)) => { | ||
if allow_empty || stdout.parse::<usize>().unwrap_or(0) > 0 { | ||
return Ok(()); | ||
} | ||
} | ||
}; | ||
interval.tick().await; | ||
} | ||
Err(eyre::eyre!( | ||
"{} does not exist or is empty", | ||
remote_path.display() | ||
)) | ||
} | ||
|
||
/// Submit job | ||
async fn submit_job(&self, remote_job_path: &Path) -> eyre::Result<usize> { | ||
let cmd = format!("sbatch {}", remote_job_path.display()); | ||
let (exit_status, stdout, stderr) = self.run_command(&cmd).await?; | ||
log::debug!("{}", stderr); | ||
log::error!("{}", stderr); | ||
if !stdout.is_empty() { | ||
log::debug!("{}", stdout); | ||
} | ||
if !stderr.is_empty() { | ||
log::error!("{}", stderr); | ||
} | ||
if exit_status != 0 { | ||
eyre::bail!("{} failed with exit code {}", cmd, exit_status); | ||
} | ||
|
@@ -448,10 +374,7 @@ pub mod scp { | |
use std::path::Path; | ||
|
||
#[async_trait::async_trait] | ||
pub trait Client // pub trait Client<C> | ||
// where | ||
// C: tokio::io::AsyncRead, | ||
{ | ||
pub trait Client { | ||
async fn upload_streamed<R>( | ||
&self, | ||
remote_path: impl AsRef<Path> + Send + Sync, | ||
|
@@ -479,7 +402,6 @@ pub mod scp { | |
} | ||
|
||
#[async_trait::async_trait] | ||
// impl Client<async_ssh2_lite::AsyncChannel<async_ssh2_lite::TokioTcpStream>> for crate::SSHClient { | ||
impl Client for crate::SSHClient { | ||
async fn upload_streamed<R>( | ||
&self, | ||
|
@@ -522,13 +444,13 @@ pub mod scp { | |
remote_path: impl AsRef<Path> + Send, | ||
) -> eyre::Result<( | ||
Box<dyn tokio::io::AsyncRead + Unpin + Send>, | ||
// async_ssh2_lite::AsyncChannel<async_ssh2_lite::TokioTcpStream>, | ||
ssh2::ScpFileStat, | ||
)> { | ||
let remote_path = remote_path.as_ref(); | ||
let (channel, stat) = self.session.scp_recv(remote_path.as_ref()).await?; | ||
log::debug!( | ||
"scp: download {} (mode {})", | ||
"downloading {} ({}, mode {})", | ||
remote_path.display(), | ||
human_bytes::human_bytes(stat.size() as f64), | ||
stat.mode() | ||
); | ||
|
Oops, something went wrong.