Skip to content

Commit

Permalink
Get rid of the Agent and introduce SessionFactory
Browse files Browse the repository at this point in the history
Signed-off-by: Wiktor Kwapisiewicz <[email protected]>
  • Loading branch information
wiktor-k committed Apr 29, 2024
1 parent 4240937 commit 9d04c7e
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 67 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use tokio::net::UnixListener as Listener;
#[cfg(windows)]
use ssh_agent_lib::agent::NamedPipeListener as Listener;
use ssh_agent_lib::error::AgentError;
use ssh_agent_lib::agent::{Session, Agent};
use ssh_agent_lib::agent::{Session, listen};
use ssh_agent_lib::proto::{Identity, SignRequest};
use ssh_key::{Algorithm, Signature};
Expand Down Expand Up @@ -50,7 +50,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let _ = std::fs::remove_file(socket); // remove the socket if exists
MyAgent.listen(Listener::bind(socket)?).await?;
listen(Listener::bind(socket)?, MyAgent::default()).await?;
Ok(())
}
```
Expand Down
27 changes: 17 additions & 10 deletions examples/key_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@ use rsa::BigUint;
use sha1::Sha1;
#[cfg(windows)]
use ssh_agent_lib::agent::NamedPipeListener as Listener;
use ssh_agent_lib::agent::{ListeningSocket, Session};
use ssh_agent_lib::agent::{listen, Session, SessionFactory};
use ssh_agent_lib::error::AgentError;
use ssh_agent_lib::proto::extension::{QueryResponse, RestrictDestination, SessionBind};
use ssh_agent_lib::proto::{
message, signature, AddIdentity, AddIdentityConstrained, AddSmartcardKeyConstrained,
Credential, Extension, KeyConstraint, RemoveIdentity, SignRequest, SmartcardKey,
};
use ssh_agent_lib::Agent;
use ssh_key::{
private::{KeypairData, PrivateKey},
public::PublicKey,
Expand Down Expand Up @@ -234,11 +233,21 @@ impl KeyStorageAgent {
}
}

impl Agent for KeyStorageAgent {
fn new_session<S>(&mut self, _socket: &S::Stream) -> impl Session
where
S: ListeningSocket + std::fmt::Debug + Send,
{
#[cfg(unix)]
impl SessionFactory<tokio::net::UnixListener> for KeyStorageAgent {
fn new_session(&mut self, _socket: &tokio::net::UnixStream) -> impl Session {
KeyStorage {
identities: Arc::clone(&self.identities),
}
}
}

#[cfg(windows)]
impl SessionFactory<ssh_agent::agent::NamedPipeListener> for KeyStorageAgent {
fn new_session(
&mut self,
_socket: &tokio::net::windows::named_pipe::NamedPipeServer,
) -> impl Session {
KeyStorage {
identities: Arc::clone(&self.identities),
}
Expand All @@ -260,8 +269,6 @@ async fn main() -> Result<(), AgentError> {
#[cfg(windows)]
std::fs::File::create("server-started")?;

KeyStorageAgent::new()
.listen(Listener::bind(socket)?)
.await?;
listen(Listener::bind(socket)?, KeyStorageAgent::new()).await?;
Ok(())
}
133 changes: 80 additions & 53 deletions src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,69 +249,96 @@ where
}
}

/// Type representing an agent listening for incoming connections.
#[async_trait]
pub trait Agent: 'static + Sync + Send + Sized {
/// Create new session object when a new socket is accepted.
fn new_session<S>(&mut self, socket: &S::Stream) -> impl Session
where
S: ListeningSocket + fmt::Debug + Send;
/// DOC
pub trait SessionFactory<S>: 'static + Send + Sync
where
S: ListeningSocket + fmt::Debug + Send,
{
/// DOC
fn new_session(&mut self, socket: &S::Stream) -> impl Session;
}

/// Listen on a socket waiting for client connections.
async fn listen<S>(mut self, mut socket: S) -> Result<(), AgentError>
where
S: ListeningSocket + fmt::Debug + Send,
{
log::info!("Listening; socket = {:?}", socket);
loop {
match socket.accept().await {
Ok(socket) => {
let session = self.new_session::<S>(&socket);
tokio::spawn(async move {
let adapter = Framed::new(socket, Codec::<Request, Response>::default());
if let Err(e) = handle_socket::<S>(session, adapter).await {
log::error!("Agent protocol error: {:?}", e);
}
});
}
Err(e) => {
log::error!("Failed to accept socket: {:?}", e);
return Err(AgentError::IO(e));
}
/// Type representing an agent listening for incoming connections.
pub async fn listen<S>(mut socket: S, mut sf: impl SessionFactory<S>) -> Result<(), AgentError>
where
S: ListeningSocket + fmt::Debug + Send,
{
log::info!("Listening; socket = {:?}", socket);
loop {
match socket.accept().await {
Ok(socket) => {
let session = sf.new_session(&socket);
tokio::spawn(async move {
let adapter = Framed::new(socket, Codec::<Request, Response>::default());
if let Err(e) = handle_socket::<S>(session, adapter).await {
log::error!("Agent protocol error: {:?}", e);
}
});
}
Err(e) => {
log::error!("Failed to accept socket: {:?}", e);
return Err(AgentError::IO(e));
}
}
}
}

/// Bind to a service binding listener.
async fn bind(mut self, listener: service_binding::Listener) -> Result<(), AgentError> {
match listener {
#[cfg(unix)]
service_binding::Listener::Unix(listener) => {
self.listen(UnixListener::from_std(listener)?).await
}
service_binding::Listener::Tcp(listener) => {
self.listen(TcpListener::from_std(listener)?).await
}
#[cfg(windows)]
service_binding::Listener::NamedPipe(pipe) => {
self.listen(NamedPipeListener::bind(pipe)?).await
}
#[cfg(not(windows))]
service_binding::Listener::NamedPipe(_) => Err(AgentError::IO(std::io::Error::other(
"Named pipes supported on Windows only",
))),
}
#[cfg(unix)]
impl<T> SessionFactory<tokio::net::UnixListener> for T
where
T: Default + Send + Sync + Session,
{
fn new_session(&mut self, _socket: &tokio::net::UnixStream) -> impl Session {
Self::default()
}
}

impl<T> SessionFactory<tokio::net::TcpListener> for T
where
T: Default + Send + Sync + Session,
{
fn new_session(&mut self, _socket: &tokio::net::TcpStream) -> impl Session {
Self::default()
}
}

impl<T> Agent for T
#[cfg(windows)]
impl<T> SessionFactory<ssh_agent_lib::agent::NamedPipeListener> for T
where
T: Default + Session,
T: Default + Send + Sync + Session,
{
fn new_session<S>(&mut self, _socket: &S::Stream) -> impl Session
where
S: ListeningSocket + fmt::Debug + Send,
{
fn new_session(
&mut self,
_socket: &tokio::net::windows::named_pipe::NamedPipeServer,
) -> impl Session {
Self::default()
}
}

/*
/// Bind to a service binding listener.
pub async fn bind<SF>(listener: service_binding::Listener, sf: SF) -> Result<(), AgentError>
where
SF: SessionFactory<tokio::net::UnixListener> + SessionFactory<tokio::net::TcpListener>,
#[cfg(windows)]
SF: SessionFactory<ssh_agent::agent::NamedPipeListener>,
{
match listener {
#[cfg(unix)]
service_binding::Listener::Unix(listener) => {
listen(UnixListener::from_std(listener)?, sf).await
}
service_binding::Listener::Tcp(listener) => {
listen(TcpListener::from_std(listener)?, sf).await
}
#[cfg(windows)]
service_binding::Listener::NamedPipe(pipe) => {
self.listen(NamedPipeListener::bind(pipe)?, sf).await
}
#[cfg(not(windows))]
service_binding::Listener::NamedPipe(_) => Err(AgentError::IO(std::io::Error::other(
"Named pipes supported on Windows only",
))),
}
}
*/
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ pub mod error;
#[cfg(feature = "agent")]
pub use async_trait::async_trait;

#[cfg(feature = "agent")]
pub use self::agent::Agent;
//#[cfg(feature = "agent")]
//pub use self::agent::Agent;

0 comments on commit 9d04c7e

Please sign in to comment.