Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SSH agent client with an example #44

Merged
merged 3 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ byteorder = "1.4.3"
async-trait = { version = "0.1.77", optional = true }
futures = { version = "0.3.30", optional = true }
log = { version = "0.4.6", optional = true }
tokio = { version = "1", optional = true, features = ["rt", "net"] }
tokio = { version = "1", optional = true, features = ["rt", "net", "time"] }
tokio-util = { version = "0.7.1", optional = true, features = ["codec"] }
service-binding = { version = "^2" }
service-binding = { version = "^2.1" }
ssh-encoding = { version = "0.2.0" }
ssh-key = { version = "0.6.6", features = ["rsa", "alloc"] }
thiserror = "1.0.58"
Expand Down
20 changes: 20 additions & 0 deletions examples/ssh-agent-client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use service_binding::Binding;
use ssh_agent_lib::client::connect;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
#[cfg(unix)]
let mut client =
connect(Binding::FilePath(std::env::var("SSH_AUTH_SOCK")?.into()).try_into()?).await?;

#[cfg(windows)]
let mut client =
connect(Binding::NamedPipe(std::env::var("SSH_AUTH_SOCK")?.into()).try_into()?).await?;

eprintln!(
"Identities that this agent knows of: {:#?}",
client.request_identities().await?
);

Ok(())
}
63 changes: 31 additions & 32 deletions src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl ListeningSocket for NamedPipeListener {
/// This type is implemented by agents that want to handle incoming SSH agent
/// connections.
#[async_trait]
pub trait Session: 'static + Sync + Send + Sized {
pub trait Session: 'static + Sync + Send + Unpin {
/// Request a list of keys managed by this session.
async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
Err(AgentError::from(ProtoError::UnsupportedCommand {
Expand Down Expand Up @@ -215,37 +215,36 @@ pub trait Session: 'static + Sync + Send + Sized {
}
Ok(Response::Success)
}
}

#[doc(hidden)]
async fn handle_socket<S>(
&mut self,
mut adapter: Framed<S::Stream, Codec<Request, Response>>,
) -> Result<(), AgentError>
where
S: ListeningSocket + fmt::Debug + Send,
{
loop {
if let Some(incoming_message) = adapter.try_next().await? {
log::debug!("Request: {incoming_message:?}");
let response = match self.handle(incoming_message).await {
Ok(message) => message,
Err(AgentError::ExtensionFailure) => {
log::error!("Extension failure handling message");
Response::ExtensionFailure
}
Err(e) => {
log::error!("Error handling message: {:?}", e);
Response::Failure
}
};
log::debug!("Response: {response:?}");
async fn handle_socket<S>(
mut session: impl Session,
mut adapter: Framed<S::Stream, Codec<Request, Response>>,
) -> Result<(), AgentError>
where
S: ListeningSocket + fmt::Debug + Send,
{
loop {
if let Some(incoming_message) = adapter.try_next().await? {
log::debug!("Request: {incoming_message:?}");
let response = match session.handle(incoming_message).await {
Ok(message) => message,
Err(AgentError::ExtensionFailure) => {
log::error!("Extension failure handling message");
Response::ExtensionFailure
}
Err(e) => {
log::error!("Error handling message: {:?}", e);
Response::Failure
}
};
log::debug!("Response: {response:?}");

adapter.send(response).await?;
} else {
// Reached EOF of the stream (client disconnected),
// we can close the socket and exit the handler.
return Ok(());
}
adapter.send(response).await?;
} else {
// Reached EOF of the stream (client disconnected),
// we can close the socket and exit the handler.
return Ok(());
}
}
}
Expand All @@ -265,10 +264,10 @@ pub trait Agent: 'static + Sync + Send + Sized {
loop {
match socket.accept().await {
Ok(socket) => {
let mut session = self.new_session();
let session = self.new_session();
tokio::spawn(async move {
let adapter = Framed::new(socket, Codec::<Request, Response>::default());
if let Err(e) = session.handle_socket::<S>(adapter).await {
if let Err(e) = handle_socket::<S>(session, adapter).await {
log::error!("Agent protocol error: {:?}", e);
}
});
Expand Down
201 changes: 201 additions & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
//! SSH agent client support.

use std::fmt;

use futures::{SinkExt, TryStreamExt};
use ssh_key::Signature;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;

use crate::{
codec::Codec,
error::AgentError,
proto::{
AddIdentity, AddIdentityConstrained, AddSmartcardKeyConstrained, Extension, Identity,
ProtoError, RemoveIdentity, Request, Response, SignRequest, SmartcardKey,
},
};

/// SSH agent client
#[derive(Debug)]
pub struct Client<Stream>
where
Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
adapter: Framed<Stream, Codec<Response, Request>>,
}

impl<Stream> Client<Stream>
where
Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
/// Create a new SSH agent client wrapping a given socket.
pub fn new(socket: Stream) -> Self {
let adapter = Framed::new(socket, Codec::default());
Self { adapter }
}
}

/// Wrap a stream into an SSH agent client.
pub async fn connect(
stream: service_binding::Stream,
) -> Result<std::pin::Pin<Box<dyn crate::agent::Session>>, Box<dyn std::error::Error>> {
match stream {
#[cfg(unix)]
service_binding::Stream::Unix(stream) => {
let stream = tokio::net::UnixStream::from_std(stream)?;
Ok(Box::pin(Client::new(stream)))
}
service_binding::Stream::Tcp(stream) => {
let stream = tokio::net::TcpStream::from_std(stream)?;
Ok(Box::pin(Client::new(stream)))
}
#[cfg(windows)]
service_binding::Stream::NamedPipe(pipe) => {
use tokio::net::windows::named_pipe::ClientOptions;
let stream = loop {
// https://docs.rs/windows-sys/latest/windows_sys/Win32/Foundation/constant.ERROR_PIPE_BUSY.html
const ERROR_PIPE_BUSY: u32 = 231u32;

// correct way to do it taken from
// https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html
match ClientOptions::new().open(&pipe) {
Ok(client) => break client,
Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (),
Err(e) => Err(e)?,
}

tokio::time::sleep(std::time::Duration::from_millis(50)).await;
};
Ok(Box::pin(Client::new(stream)))
}
#[cfg(not(windows))]
service_binding::Stream::NamedPipe(_) => Err(ProtoError::IO(std::io::Error::other(
"Named pipes supported on Windows only",
))
.into()),
}
}

#[async_trait::async_trait]
impl<Stream> crate::agent::Session for Client<Stream>
where
Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
{
async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
if let Response::IdentitiesAnswer(identities) =
self.handle(Request::RequestIdentities).await?
{
Ok(identities)
} else {
Err(ProtoError::UnexpectedResponse.into())
jcspencer marked this conversation as resolved.
Show resolved Hide resolved
}
}

async fn sign(&mut self, request: SignRequest) -> Result<Signature, AgentError> {
if let Response::SignResponse(response) = self.handle(Request::SignRequest(request)).await?
{
Ok(response)
} else {
Err(ProtoError::UnexpectedResponse.into())
}
}

async fn add_identity(&mut self, identity: AddIdentity) -> Result<(), AgentError> {
if let Response::Success = self.handle(Request::AddIdentity(identity)).await? {
Ok(())
} else {
Err(ProtoError::UnexpectedResponse.into())
}
}

async fn add_identity_constrained(
&mut self,
identity: AddIdentityConstrained,
) -> Result<(), AgentError> {
if let Response::Success = self.handle(Request::AddIdConstrained(identity)).await? {
Ok(())
} else {
Err(ProtoError::UnexpectedResponse.into())
}
}

async fn remove_identity(&mut self, identity: RemoveIdentity) -> Result<(), AgentError> {
if let Response::Success = self.handle(Request::RemoveIdentity(identity)).await? {
Ok(())
} else {
Err(ProtoError::UnexpectedResponse.into())
}
}

async fn remove_all_identities(&mut self) -> Result<(), AgentError> {
if let Response::Success = self.handle(Request::RemoveAllIdentities).await? {
Ok(())
} else {
Err(ProtoError::UnexpectedResponse.into())
}
}

async fn add_smartcard_key(&mut self, key: SmartcardKey) -> Result<(), AgentError> {
if let Response::Success = self.handle(Request::AddSmartcardKey(key)).await? {
Ok(())
} else {
Err(ProtoError::UnexpectedResponse.into())
}
}

async fn add_smartcard_key_constrained(
&mut self,
key: AddSmartcardKeyConstrained,
) -> Result<(), AgentError> {
if let Response::Success = self
.handle(Request::AddSmartcardKeyConstrained(key))
.await?
{
Ok(())
} else {
Err(ProtoError::UnexpectedResponse.into())
}
}

async fn remove_smartcard_key(&mut self, key: SmartcardKey) -> Result<(), AgentError> {
if let Response::Success = self.handle(Request::RemoveSmartcardKey(key)).await? {
Ok(())
} else {
Err(ProtoError::UnexpectedResponse.into())
}
}

async fn lock(&mut self, key: String) -> Result<(), AgentError> {
if let Response::Success = self.handle(Request::Lock(key)).await? {
Ok(())
} else {
Err(ProtoError::UnexpectedResponse.into())
}
}

async fn unlock(&mut self, key: String) -> Result<(), AgentError> {
if let Response::Success = self.handle(Request::Unlock(key)).await? {
Ok(())
} else {
Err(ProtoError::UnexpectedResponse.into())
}
}

async fn extension(&mut self, extension: Extension) -> Result<Option<Extension>, AgentError> {
match self.handle(Request::Extension(extension)).await? {
Response::Success => Ok(None),
Response::ExtensionResponse(response) => Ok(Some(response)),
_ => Err(ProtoError::UnexpectedResponse.into()),
}
}

async fn handle(&mut self, message: Request) -> Result<Response, AgentError> {
self.adapter.send(message).await?;
if let Some(response) = self.adapter.try_next().await? {
Ok(response)
} else {
Err(ProtoError::IO(std::io::Error::other("server disconnected")).into())
}
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ pub mod proto;

#[cfg(feature = "agent")]
pub mod agent;
#[cfg(feature = "agent")]
pub mod client;
#[cfg(feature = "codec")]
pub mod codec;
pub mod error;
Expand Down
4 changes: 4 additions & 0 deletions src/proto/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ pub enum ProtoError {
/// Command code that was unsupported.
command: u8,
},

/// The client expected a different response.
#[error("Unexpected response received")]
UnexpectedResponse,
jcspencer marked this conversation as resolved.
Show resolved Hide resolved
}

/// Protocol result.
Expand Down
Loading