Skip to content

Commit

Permalink
Merge pull request #15 from wiktor-k/wiktor/error-fixes
Browse files Browse the repository at this point in the history
Make error implement std::io::Error and remove associated Agent's Error type
  • Loading branch information
wiktor-k authored Feb 23, 2024
2 parents c815ce2 + f4bfda4 commit 2f84bdf
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 57 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ processes requests.
use async_trait::async_trait;
use tokio::net::UnixListener;
use ssh_agent_lib::agent::Agent;
use ssh_agent_lib::agent::{Session, Agent};
use ssh_agent_lib::error::AgentError;
use ssh_agent_lib::proto::message::{Message, SignRequest};
#[derive(Default)]
struct MyAgent;
#[async_trait]
impl Agent for MyAgent {
type Error = ();
async fn handle(&self, message: Message) -> Result<Message, ()> {
impl Session for MyAgent {
async fn handle(&mut self, message: Message) -> Result<Message, AgentError> {
match message {
Message::SignRequest(request) => {
// get the signature by signing `request.data`
Expand Down
15 changes: 10 additions & 5 deletions examples/key_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use async_trait::async_trait;
use log::info;
use tokio::net::UnixListener;

use ssh_agent_lib::agent::Agent;
use ssh_agent_lib::agent::{Agent, Session};
use ssh_agent_lib::error::AgentError;
use ssh_agent_lib::proto::message::{self, Message, SignRequest};
use ssh_agent_lib::proto::private_key::{PrivateKey, RsaPrivateKey};
use ssh_agent_lib::proto::public_key::PublicKey;
Expand Down Expand Up @@ -146,17 +147,21 @@ impl KeyStorage {
}

#[async_trait]
impl Agent for KeyStorage {
type Error = ();

async fn handle(&self, message: Message) -> Result<Message, ()> {
impl Session for KeyStorage {
async fn handle(&mut self, message: Message) -> Result<Message, AgentError> {
self.handle_message(message).or_else(|error| {
println!("Error handling message - {:?}", error);
Ok(Message::Failure)
})
}
}

impl Agent for KeyStorage {
fn new_session(&mut self) -> impl Session {
KeyStorage::new()
}
}

fn rsa_openssl_from_ssh(ssh_rsa: &RsaPrivateKey) -> Result<Rsa<Private>, Box<dyn Error>> {
let n = BigNum::from_slice(&ssh_rsa.n)?;
let e = BigNum::from_slice(&ssh_rsa.e)?;
Expand Down
91 changes: 44 additions & 47 deletions src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream};
use tokio_util::codec::{Decoder, Encoder, Framed};

use std::error::Error;
use std::fmt;
use std::io;
use std::marker::Unpin;
use std::mem::size_of;
use std::sync::Arc;

use super::error::AgentError;
use super::proto::message::Message;
use super::proto::{from_bytes, to_bytes};

struct MessageCodec;
#[derive(Debug)]
pub struct MessageCodec;

impl Decoder for MessageCodec {
type Item = Message;
Expand Down Expand Up @@ -53,39 +52,6 @@ impl Encoder<Message> for MessageCodec {
}
}

struct Session<A, S> {
agent: Arc<A>,
adapter: Framed<S, MessageCodec>,
}

impl<A, S> Session<A, S>
where
A: Agent,
S: AsyncRead + AsyncWrite + Unpin,
{
fn new(agent: Arc<A>, socket: S) -> Self {
let adapter = Framed::new(socket, MessageCodec);
Self { agent, adapter }
}

async fn handle_socket(&mut self) -> Result<(), AgentError> {
loop {
if let Some(incoming_message) = self.adapter.try_next().await? {
let response = self.agent.handle(incoming_message).await.map_err(|e| {
error!("Error handling message; error = {:?}", e);
AgentError::User
})?;

self.adapter.send(response).await?;
} else {
// Reached EOF of the stream (client disconnected),
// we can close the socket and exit the handler.
return Ok(());
}
}
}
}

#[async_trait]
pub trait ListeningSocket {
type Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static;
Expand All @@ -110,35 +76,66 @@ impl ListeningSocket for TcpListener {
}

#[async_trait]
pub trait Agent: 'static + Sync + Send + Sized {
type Error: fmt::Debug + Send + Sync;
pub trait Session: 'static + Sync + Send + Sized {
async fn handle(&mut self, message: Message) -> Result<Message, AgentError>;

async fn handle_socket<S>(
&mut self,
mut adapter: Framed<S::Stream, MessageCodec>,
) -> Result<(), AgentError>
where
S: ListeningSocket + fmt::Debug + Send,
{
loop {
if let Some(incoming_message) = adapter.try_next().await? {
let response = self.handle(incoming_message).await.map_err(|e| {
error!("Error handling message; error = {:?}", e);
AgentError::User
})?;

async fn handle(&self, message: Message) -> Result<Message, Self::Error>;
adapter.send(response).await?;
} else {
// Reached EOF of the stream (client disconnected),
// we can close the socket and exit the handler.
return Ok(());
}
}
}
}

async fn listen<S>(self, socket: S) -> Result<(), Box<dyn Error + Send + Sync>>
#[async_trait]
pub trait Agent: 'static + Sync + Send + Sized {
fn new_session(&mut self) -> impl Session;
async fn listen<S>(mut self, socket: S) -> Result<(), AgentError>
where
S: ListeningSocket + fmt::Debug + Send,
{
info!("Listening; socket = {:?}", socket);
let arc_self = Arc::new(self);

loop {
match socket.accept().await {
Ok(socket) => {
let agent = arc_self.clone();
let mut session = Session::new(agent, socket);

let mut session = self.new_session();
tokio::spawn(async move {
if let Err(e) = session.handle_socket().await {
let adapter = Framed::new(socket, MessageCodec);
if let Err(e) = session.handle_socket::<S>(adapter).await {
error!("Agent protocol error; error = {:?}", e);
}
});
}
Err(e) => {
error!("Failed to accept socket; error = {:?}", e);
return Err(Box::new(e));
return Err(AgentError::IO(e));
}
}
}
}
}

impl<T> Agent for T
where
T: Default + Session,
{
fn new_session(&mut self) -> impl Session {
Self::default()
}
}
12 changes: 12 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,15 @@ impl From<io::Error> for AgentError {
AgentError::IO(e)
}
}

impl std::fmt::Display for AgentError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AgentError::User => write!(f, "Agent: User error"),
AgentError::Proto(proto) => write!(f, "Agent: Protocol error: {}", proto),
AgentError::IO(error) => write!(f, "Agent: I/O error: {}", error),
}
}
}

impl std::error::Error for AgentError {}

0 comments on commit 2f84bdf

Please sign in to comment.