diff --git a/README.md b/README.md index 7559f6a..10d1453 100644 --- a/README.md +++ b/README.md @@ -14,16 +14,16 @@ processes requests. use async_trait::async_trait; use tokio::net::UnixListener; -use ssh_agent_lib::agent::Agent; +use ssh_agent_lib::agent::{Session, Agent}; +use ssh_agent_lib::error::AgentError; use ssh_agent_lib::proto::message::{Message, SignRequest}; +#[derive(Default)] struct MyAgent; #[async_trait] -impl Agent for MyAgent { - type Error = (); - - async fn handle(&self, message: Message) -> Result { +impl Session for MyAgent { + async fn handle(&mut self, message: Message) -> Result { match message { Message::SignRequest(request) => { // get the signature by signing `request.data` diff --git a/examples/key_storage.rs b/examples/key_storage.rs index 17b0962..88f455f 100644 --- a/examples/key_storage.rs +++ b/examples/key_storage.rs @@ -2,7 +2,8 @@ use async_trait::async_trait; use log::info; use tokio::net::UnixListener; -use ssh_agent_lib::agent::Agent; +use ssh_agent_lib::agent::{Agent, Session}; +use ssh_agent_lib::error::AgentError; use ssh_agent_lib::proto::message::{self, Message, SignRequest}; use ssh_agent_lib::proto::private_key::{PrivateKey, RsaPrivateKey}; use ssh_agent_lib::proto::public_key::PublicKey; @@ -146,10 +147,8 @@ impl KeyStorage { } #[async_trait] -impl Agent for KeyStorage { - type Error = (); - - async fn handle(&self, message: Message) -> Result { +impl Session for KeyStorage { + async fn handle(&mut self, message: Message) -> Result { self.handle_message(message).or_else(|error| { println!("Error handling message - {:?}", error); Ok(Message::Failure) @@ -157,6 +156,12 @@ impl Agent for KeyStorage { } } +impl Agent for KeyStorage { + fn new_session(&mut self) -> impl Session { + KeyStorage::new() + } +} + fn rsa_openssl_from_ssh(ssh_rsa: &RsaPrivateKey) -> Result, Box> { let n = BigNum::from_slice(&ssh_rsa.n)?; let e = BigNum::from_slice(&ssh_rsa.e)?; diff --git a/src/agent.rs b/src/agent.rs index 0244866..7cd2a33 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -7,18 +7,17 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream}; use tokio_util::codec::{Decoder, Encoder, Framed}; -use std::error::Error; use std::fmt; use std::io; use std::marker::Unpin; use std::mem::size_of; -use std::sync::Arc; use super::error::AgentError; use super::proto::message::Message; use super::proto::{from_bytes, to_bytes}; -struct MessageCodec; +#[derive(Debug)] +pub struct MessageCodec; impl Decoder for MessageCodec { type Item = Message; @@ -53,39 +52,6 @@ impl Encoder for MessageCodec { } } -struct Session { - agent: Arc, - adapter: Framed, -} - -impl Session -where - A: Agent, - S: AsyncRead + AsyncWrite + Unpin, -{ - fn new(agent: Arc, socket: S) -> Self { - let adapter = Framed::new(socket, MessageCodec); - Self { agent, adapter } - } - - async fn handle_socket(&mut self) -> Result<(), AgentError> { - loop { - if let Some(incoming_message) = self.adapter.try_next().await? { - let response = self.agent.handle(incoming_message).await.map_err(|e| { - error!("Error handling message; error = {:?}", e); - AgentError::User - })?; - - self.adapter.send(response).await?; - } else { - // Reached EOF of the stream (client disconnected), - // we can close the socket and exit the handler. - return Ok(()); - } - } - } -} - #[async_trait] pub trait ListeningSocket { type Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static; @@ -110,35 +76,66 @@ impl ListeningSocket for TcpListener { } #[async_trait] -pub trait Agent: 'static + Sync + Send + Sized { - type Error: fmt::Debug + Send + Sync; +pub trait Session: 'static + Sync + Send + Sized { + async fn handle(&mut self, message: Message) -> Result; + + async fn handle_socket( + &mut self, + mut adapter: Framed, + ) -> Result<(), AgentError> + where + S: ListeningSocket + fmt::Debug + Send, + { + loop { + if let Some(incoming_message) = adapter.try_next().await? { + let response = self.handle(incoming_message).await.map_err(|e| { + error!("Error handling message; error = {:?}", e); + AgentError::User + })?; - async fn handle(&self, message: Message) -> Result; + adapter.send(response).await?; + } else { + // Reached EOF of the stream (client disconnected), + // we can close the socket and exit the handler. + return Ok(()); + } + } + } +} - async fn listen(self, socket: S) -> Result<(), Box> +#[async_trait] +pub trait Agent: 'static + Sync + Send + Sized { + fn new_session(&mut self) -> impl Session; + async fn listen(mut self, socket: S) -> Result<(), AgentError> where S: ListeningSocket + fmt::Debug + Send, { info!("Listening; socket = {:?}", socket); - let arc_self = Arc::new(self); - loop { match socket.accept().await { Ok(socket) => { - let agent = arc_self.clone(); - let mut session = Session::new(agent, socket); - + let mut session = self.new_session(); tokio::spawn(async move { - if let Err(e) = session.handle_socket().await { + let adapter = Framed::new(socket, MessageCodec); + if let Err(e) = session.handle_socket::(adapter).await { error!("Agent protocol error; error = {:?}", e); } }); } Err(e) => { error!("Failed to accept socket; error = {:?}", e); - return Err(Box::new(e)); + return Err(AgentError::IO(e)); } } } } } + +impl Agent for T +where + T: Default + Session, +{ + fn new_session(&mut self) -> impl Session { + Self::default() + } +} diff --git a/src/error.rs b/src/error.rs index 4d0c313..5e2293a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -19,3 +19,15 @@ impl From for AgentError { AgentError::IO(e) } } + +impl std::fmt::Display for AgentError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AgentError::User => write!(f, "Agent: User error"), + AgentError::Proto(proto) => write!(f, "Agent: Protocol error: {}", proto), + AgentError::IO(error) => write!(f, "Agent: I/O error: {}", error), + } + } +} + +impl std::error::Error for AgentError {}