diff --git a/Cargo.toml b/Cargo.toml index 171af9c..d4e6263 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ required-features = ["agent"] env_logger = "0.11.0" rand = "0.8.5" rsa = { version = "0.9.6", features = ["sha2", "sha1"] } -tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1", features = ["macros", "rt-multi-thread", "time"] } sha1 = { version = "0.10.5", default-features = false, features = ["oid"] } testresult = "0.4.0" hex-literal = "0.4.1" diff --git a/examples/ssh-agent-client.rs b/examples/ssh-agent-client.rs new file mode 100644 index 0000000..0230041 --- /dev/null +++ b/examples/ssh-agent-client.rs @@ -0,0 +1,40 @@ +use ssh_agent_lib::agent::Session; +use ssh_agent_lib::client::Client; +#[cfg(windows)] +use tokio::net::windows::named_pipe::ClientOptions; +#[cfg(unix)] +use tokio::net::UnixStream; + +#[tokio::main] +async fn main() -> Result<(), Box> { + #[cfg(unix)] + let mut client = { + let stream = UnixStream::connect(std::env::var("SSH_AUTH_SOCK")?).await?; + Client::new(stream) + }; + #[cfg(windows)] + let mut client = { + let stream = loop { + // https://docs.rs/windows-sys/latest/windows_sys/Win32/Foundation/constant.ERROR_PIPE_BUSY.html + const ERROR_PIPE_BUSY: u32 = 231u32; + + // correct way to do it taken from + // https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html + match ClientOptions::new().open(std::env::var("SSH_AUTH_SOCK")?) { + Ok(client) => break client, + Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (), + Err(e) => Err(e)?, + } + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + }; + Client::new(stream) + }; + + eprintln!( + "Identities that this agent knows of: {:#?}", + client.request_identities().await? + ); + + Ok(()) +} diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..f34672c --- /dev/null +++ b/src/client.rs @@ -0,0 +1,170 @@ +use std::fmt; + +use futures::{SinkExt, TryStreamExt}; +use ssh_key::Signature; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; + +use crate::{ + codec::Codec, + proto::{ + AddIdentity, AddIdentityConstrained, AddSmartcardKeyConstrained, Extension, Identity, + ProtoError, RemoveIdentity, Request, Response, SignRequest, SmartcardKey, + }, +}; + +#[derive(Debug)] +pub struct Client +where + Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + adapter: Framed>, +} + +impl Client +where + Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + pub fn new(socket: Stream) -> Self { + let adapter = Framed::new(socket, Codec::default()); + Self { adapter } + } +} + +#[async_trait::async_trait] +impl crate::agent::Session for Client +where + Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, +{ + async fn request_identities(&mut self) -> Result, Box> { + if let Response::IdentitiesAnswer(identities) = + self.handle(Request::RequestIdentities).await? + { + Ok(identities) + } else { + Err(ProtoError::UnexpectedResponse.into()) + } + } + + async fn sign( + &mut self, + request: SignRequest, + ) -> Result> { + if let Response::SignResponse(response) = self.handle(Request::SignRequest(request)).await? + { + Ok(response) + } else { + Err(ProtoError::UnexpectedResponse.into()) + } + } + + async fn add_identity( + &mut self, + identity: AddIdentity, + ) -> Result<(), Box> { + if let Response::Success = self.handle(Request::AddIdentity(identity)).await? { + Ok(()) + } else { + Err(ProtoError::UnexpectedResponse.into()) + } + } + + async fn add_identity_constrained( + &mut self, + identity: AddIdentityConstrained, + ) -> Result<(), Box> { + if let Response::Success = self.handle(Request::AddIdConstrained(identity)).await? { + Ok(()) + } else { + Err(ProtoError::UnexpectedResponse.into()) + } + } + + async fn remove_identity( + &mut self, + identity: RemoveIdentity, + ) -> Result<(), Box> { + if let Response::Success = self.handle(Request::RemoveIdentity(identity)).await? { + Ok(()) + } else { + Err(ProtoError::UnexpectedResponse.into()) + } + } + + async fn remove_all_identities(&mut self) -> Result<(), Box> { + if let Response::Success = self.handle(Request::RemoveAllIdentities).await? { + Ok(()) + } else { + Err(ProtoError::UnexpectedResponse.into()) + } + } + + async fn add_smartcard_key( + &mut self, + key: SmartcardKey, + ) -> Result<(), Box> { + if let Response::Success = self.handle(Request::AddSmartcardKey(key)).await? { + Ok(()) + } else { + Err(ProtoError::UnexpectedResponse.into()) + } + } + + async fn add_smartcard_key_constrained( + &mut self, + key: AddSmartcardKeyConstrained, + ) -> Result<(), Box> { + if let Response::Success = self + .handle(Request::AddSmartcardKeyConstrained(key)) + .await? + { + Ok(()) + } else { + Err(ProtoError::UnexpectedResponse.into()) + } + } + + async fn remove_smartcard_key( + &mut self, + key: SmartcardKey, + ) -> Result<(), Box> { + if let Response::Success = self.handle(Request::RemoveSmartcardKey(key)).await? { + Ok(()) + } else { + Err(ProtoError::UnexpectedResponse.into()) + } + } + + async fn lock(&mut self, key: String) -> Result<(), Box> { + if let Response::Success = self.handle(Request::Lock(key)).await? { + Ok(()) + } else { + Err(ProtoError::UnexpectedResponse.into()) + } + } + + async fn unlock(&mut self, key: String) -> Result<(), Box> { + if let Response::Success = self.handle(Request::Unlock(key)).await? { + Ok(()) + } else { + Err(ProtoError::UnexpectedResponse.into()) + } + } + + async fn extension(&mut self, extension: Extension) -> Result<(), Box> { + if let Response::Success = self.handle(Request::Extension(extension)).await? { + Ok(()) + } else { + Err(ProtoError::UnexpectedResponse.into()) + } + } + + async fn handle(&mut self, message: Request) -> Result> { + self.adapter.send(message).await?; + if let Some(response) = self.adapter.try_next().await? { + Ok(response) + } else { + Err(ProtoError::IO(std::io::Error::other("server disconnected")).into()) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 1a9528d..4a5695a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,8 @@ pub mod proto; #[cfg(feature = "agent")] pub mod agent; +#[cfg(feature = "agent")] +pub mod client; #[cfg(feature = "codec")] pub mod codec; pub mod error; diff --git a/src/proto/error.rs b/src/proto/error.rs index f99fffb..d19cb5d 100644 --- a/src/proto/error.rs +++ b/src/proto/error.rs @@ -14,6 +14,8 @@ pub enum ProtoError { SshKey(#[from] ssh_key::Error), #[error("Command not supported ({command})")] UnsupportedCommand { command: u8 }, + #[error("Unexpected response received")] + UnexpectedResponse, } pub type ProtoResult = Result;