Skip to content

Commit

Permalink
Merge pull request #60 from wiktor-k/wiktor/add-socket-to-new-session
Browse files Browse the repository at this point in the history
Expose socket info in `new_session`
  • Loading branch information
wiktor-k authored May 13, 2024
2 parents f07a436 + 8c3fb5b commit bd36287
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 99 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ use tokio::net::UnixListener as Listener;
#[cfg(windows)]
use ssh_agent_lib::agent::NamedPipeListener as Listener;
use ssh_agent_lib::error::AgentError;
use ssh_agent_lib::agent::{Session, Agent};
use ssh_agent_lib::agent::{Session, listen};
use ssh_agent_lib::proto::{Identity, SignRequest};
use ssh_key::{Algorithm, Signature};
#[derive(Default)]
#[derive(Default, Clone)]
struct MyAgent;
#[ssh_agent_lib::async_trait]
Expand Down Expand Up @@ -50,7 +50,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let _ = std::fs::remove_file(socket); // remove the socket if exists
MyAgent.listen(Listener::bind(socket)?).await?;
listen(Listener::bind(socket)?, MyAgent::default()).await?;
Ok(())
}
```
Expand Down
82 changes: 82 additions & 0 deletions examples/agent-socket-info.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//! This example shows how to access the underlying socket info.
//! The socket info can be used to implement fine-grained access controls based on UID/GID.
//!
//! Run the example with: `cargo run --example agent-socket-info -- -H unix:///tmp/sock`
//! Then inspect the socket info with: `SSH_AUTH_SOCK=/tmp/sock ssh-add -L` which should display
//! something like this:
//!
//! ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA unix: addr: (unnamed) cred: UCred { pid: Some(68463), uid: 1000, gid: 1000 }
use clap::Parser;
use service_binding::Binding;
use ssh_agent_lib::{
agent::{bind, Agent, Session},
error::AgentError,
proto::Identity,
};
use ssh_key::public::KeyData;
use testresult::TestResult;

#[derive(Debug, Default)]
struct AgentSocketInfo {
comment: String,
}

#[ssh_agent_lib::async_trait]
impl Session for AgentSocketInfo {
async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
Ok(vec![Identity {
// this is just a dummy key, the comment is important
pubkey: KeyData::Ed25519(ssh_key::public::Ed25519PublicKey([0; 32])),
comment: self.comment.clone(),
}])
}
}

#[cfg(unix)]
impl Agent<tokio::net::UnixListener> for AgentSocketInfo {
fn new_session(&mut self, socket: &tokio::net::UnixStream) -> impl Session {
Self {
comment: format!(
"unix: addr: {:?} cred: {:?}",
socket.peer_addr().unwrap(),
socket.peer_cred().unwrap()
),
}
}
}

impl Agent<tokio::net::TcpListener> for AgentSocketInfo {
fn new_session(&mut self, _socket: &tokio::net::TcpStream) -> impl Session {
Self {
comment: "tcp".into(),
}
}
}

#[cfg(windows)]
impl Agent<ssh_agent_lib::agent::NamedPipeListener> for AgentSocketInfo {
fn new_session(
&mut self,
_socket: &tokio::net::windows::named_pipe::NamedPipeServer,
) -> impl Session {
Self {
comment: "pipe".into(),
}
}
}

#[derive(Debug, Parser)]
struct Args {
#[clap(short = 'H', long)]
host: Binding,
}

#[tokio::main]
async fn main() -> TestResult {
env_logger::init();

let args = Args::parse();
bind(args.host.try_into()?, AgentSocketInfo::default()).await?;
Ok(())
}
28 changes: 3 additions & 25 deletions examples/key_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@ use rsa::BigUint;
use sha1::Sha1;
#[cfg(windows)]
use ssh_agent_lib::agent::NamedPipeListener as Listener;
use ssh_agent_lib::agent::Session;
use ssh_agent_lib::agent::{listen, Session};
use ssh_agent_lib::error::AgentError;
use ssh_agent_lib::proto::extension::{QueryResponse, RestrictDestination, SessionBind};
use ssh_agent_lib::proto::{
message, signature, AddIdentity, AddIdentityConstrained, AddSmartcardKeyConstrained,
Credential, Extension, KeyConstraint, RemoveIdentity, SignRequest, SmartcardKey,
};
use ssh_agent_lib::Agent;
use ssh_key::{
private::{KeypairData, PrivateKey},
public::PublicKey,
Expand All @@ -32,6 +31,7 @@ struct Identity {
comment: String,
}

#[derive(Default, Clone)]
struct KeyStorage {
identities: Arc<Mutex<Vec<Identity>>>,
}
Expand Down Expand Up @@ -225,26 +225,6 @@ impl Session for KeyStorage {
}
}

struct KeyStorageAgent {
identities: Arc<Mutex<Vec<Identity>>>,
}

impl KeyStorageAgent {
fn new() -> Self {
Self {
identities: Arc::new(Mutex::new(vec![])),
}
}
}

impl Agent for KeyStorageAgent {
fn new_session(&mut self) -> impl Session {
KeyStorage {
identities: Arc::clone(&self.identities),
}
}
}

#[tokio::main]
async fn main() -> Result<(), AgentError> {
env_logger::init();
Expand All @@ -260,8 +240,6 @@ async fn main() -> Result<(), AgentError> {
#[cfg(windows)]
std::fs::File::create("server-started")?;

KeyStorageAgent::new()
.listen(Listener::bind(socket)?)
.await?;
listen(Listener::bind(socket)?, KeyStorage::default()).await?;
Ok(())
}
24 changes: 5 additions & 19 deletions examples/openpgp-card-agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,43 +27,29 @@ use retainer::{Cache, CacheExpiration};
use secrecy::{ExposeSecret, SecretString};
use service_binding::Binding;
use ssh_agent_lib::{
agent::Session,
agent::{bind, Session},
error::AgentError,
proto::{AddSmartcardKeyConstrained, Identity, KeyConstraint, SignRequest, SmartcardKey},
Agent,
};
use ssh_key::{
public::{Ed25519PublicKey, KeyData},
Algorithm, Signature,
};
use testresult::TestResult;

struct CardAgent {
#[derive(Clone)]
struct CardSession {
pwds: Arc<Cache<String, SecretString>>,
}

impl CardAgent {
impl CardSession {
pub fn new() -> Self {
let pwds: Arc<Cache<String, SecretString>> = Arc::new(Default::default());
let clone = Arc::clone(&pwds);
tokio::spawn(async move { clone.monitor(4, 0.25, Duration::from_secs(3)).await });
Self { pwds }
}
}

impl Agent for CardAgent {
fn new_session(&mut self) -> impl Session {
CardSession {
pwds: Arc::clone(&self.pwds),
}
}
}

struct CardSession {
pwds: Arc<Cache<String, SecretString>>,
}

impl CardSession {
async fn handle_sign(
&self,
request: SignRequest,
Expand Down Expand Up @@ -201,6 +187,6 @@ async fn main() -> TestResult {
env_logger::init();

let args = Args::parse();
CardAgent::new().bind(args.host.try_into()?).await?;
bind(args.host.try_into()?, CardSession::new()).await?;
Ok(())
}
139 changes: 90 additions & 49 deletions src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,64 +249,105 @@ where
}
}

/// Type representing an agent listening for incoming connections.
#[async_trait]
pub trait Agent: 'static + Sync + Send + Sized {
/// Create new session object when a new socket is accepted.
fn new_session(&mut self) -> impl Session;

/// Listen on a socket waiting for client connections.
async fn listen<S>(mut self, mut socket: S) -> Result<(), AgentError>
where
S: ListeningSocket + fmt::Debug + Send,
{
log::info!("Listening; socket = {:?}", socket);
loop {
match socket.accept().await {
Ok(socket) => {
let session = self.new_session();
tokio::spawn(async move {
let adapter = Framed::new(socket, Codec::<Request, Response>::default());
if let Err(e) = handle_socket::<S>(session, adapter).await {
log::error!("Agent protocol error: {:?}", e);
}
});
}
Err(e) => {
log::error!("Failed to accept socket: {:?}", e);
return Err(AgentError::IO(e));
}
/// Factory of sessions for the given type of sockets.
pub trait Agent<S>: 'static + Send + Sync
where
S: ListeningSocket + fmt::Debug + Send,
{
/// Create a [`Session`] object for a given `socket`.
fn new_session(&mut self, socket: &S::Stream) -> impl Session;
}

/// Listen for connections on a given socket and use session factory
/// to create new session for each accepted socket.
pub async fn listen<S>(mut socket: S, mut sf: impl Agent<S>) -> Result<(), AgentError>
where
S: ListeningSocket + fmt::Debug + Send,
{
log::info!("Listening; socket = {:?}", socket);
loop {
match socket.accept().await {
Ok(socket) => {
let session = sf.new_session(&socket);
tokio::spawn(async move {
let adapter = Framed::new(socket, Codec::<Request, Response>::default());
if let Err(e) = handle_socket::<S>(session, adapter).await {
log::error!("Agent protocol error: {:?}", e);
}
});
}
Err(e) => {
log::error!("Failed to accept socket: {:?}", e);
return Err(AgentError::IO(e));
}
}
}
}

/// Bind to a service binding listener.
async fn bind(mut self, listener: service_binding::Listener) -> Result<(), AgentError> {
match listener {
#[cfg(unix)]
service_binding::Listener::Unix(listener) => {
self.listen(UnixListener::from_std(listener)?).await
}
service_binding::Listener::Tcp(listener) => {
self.listen(TcpListener::from_std(listener)?).await
}
#[cfg(windows)]
service_binding::Listener::NamedPipe(pipe) => {
self.listen(NamedPipeListener::bind(pipe)?).await
}
#[cfg(not(windows))]
service_binding::Listener::NamedPipe(_) => Err(AgentError::IO(std::io::Error::other(
"Named pipes supported on Windows only",
))),
#[cfg(unix)]
impl<T> Agent<tokio::net::UnixListener> for T
where
T: Clone + Send + Sync + Session,
{
fn new_session(&mut self, _socket: &tokio::net::UnixStream) -> impl Session {
Self::clone(self)
}
}

impl<T> Agent<tokio::net::TcpListener> for T
where
T: Clone + Send + Sync + Session,
{
fn new_session(&mut self, _socket: &tokio::net::TcpStream) -> impl Session {
Self::clone(self)
}
}

#[cfg(windows)]
impl<T> Agent<NamedPipeListener> for T
where
T: Clone + Send + Sync + Session,
{
fn new_session(
&mut self,
_socket: &tokio::net::windows::named_pipe::NamedPipeServer,
) -> impl Session {
Self::clone(self)
}
}

/// Bind to a service binding listener.
#[cfg(unix)]
pub async fn bind<SF>(listener: service_binding::Listener, sf: SF) -> Result<(), AgentError>
where
SF: Agent<tokio::net::UnixListener> + Agent<tokio::net::TcpListener>,
{
match listener {
#[cfg(unix)]
service_binding::Listener::Unix(listener) => {
listen(UnixListener::from_std(listener)?, sf).await
}
service_binding::Listener::Tcp(listener) => {
listen(TcpListener::from_std(listener)?, sf).await
}
_ => Err(AgentError::IO(std::io::Error::other(
"Unsupported type of a listener.",
))),
}
}

impl<T> Agent for T
/// Bind to a service binding listener.
#[cfg(windows)]
pub async fn bind<SF>(listener: service_binding::Listener, sf: SF) -> Result<(), AgentError>
where
T: Default + Session,
SF: Agent<NamedPipeListener> + Agent<tokio::net::TcpListener>,
{
fn new_session(&mut self) -> impl Session {
Self::default()
match listener {
service_binding::Listener::Tcp(listener) => {
listen(TcpListener::from_std(listener)?, sf).await
}
service_binding::Listener::NamedPipe(pipe) => {
listen(NamedPipeListener::bind(pipe)?, sf).await
}
}
}
3 changes: 0 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,3 @@ pub mod error;

#[cfg(feature = "agent")]
pub use async_trait::async_trait;

#[cfg(feature = "agent")]
pub use self::agent::Agent;

0 comments on commit bd36287

Please sign in to comment.