diff --git a/src/proto/extension.rs b/src/proto/extension.rs index 4c939df..606a5ba 100644 --- a/src/proto/extension.rs +++ b/src/proto/extension.rs @@ -1,6 +1,9 @@ -use ssh_encoding::{Decode, Reader}; +use ssh_encoding::{CheckedSum, Decode, Encode, Error as EncodingError, Reader, Writer}; use ssh_key::{public::KeyData, Signature}; +// Reserved fields are marked with an empty string +const RESERVED_FIELD: &str = ""; + /// session-bind@openssh.com extension /// /// This extension allows a ssh client to bind an agent connection to a @@ -32,6 +35,30 @@ impl Decode for SessionBind { } } +impl Encode for SessionBind { + fn encoded_len(&self) -> ssh_encoding::Result { + [ + self.host_key.encoded_len_prefixed()?, + self.session_id.encoded_len()?, + self.signature.encoded_len_prefixed()?, + 1u8.encoded_len()?, + ] + .checked_sum() + } + + fn encode(&self, writer: &mut impl Writer) -> ssh_encoding::Result<()> { + self.host_key.encode_prefixed(writer)?; + self.session_id.encode(writer)?; + self.signature.encode_prefixed(writer)?; + + if self.is_forwarding { + 1u8.encode(writer) + } else { + 0u8.encode(writer) + } + } +} + #[derive(Debug, Clone)] pub struct RestrictDestination { pub constraints: Vec, @@ -49,50 +76,111 @@ impl Decode for RestrictDestination { } } +impl Encode for RestrictDestination { + fn encoded_len(&self) -> ssh_encoding::Result { + self.constraints.iter().try_fold(0, |acc, e| { + let constraint_len = e.encoded_len()?; + usize::checked_add(acc, constraint_len).ok_or(EncodingError::Length) + }) + } + + fn encode(&self, writer: &mut impl Writer) -> ssh_encoding::Result<()> { + for constraint in &self.constraints { + constraint.encode(writer)?; + } + Ok(()) + } +} + #[derive(Debug, Clone)] -pub struct DestinationConstraint { - pub from_username: String, - pub from_hostname: String, - pub from_hostkeys: Vec, - pub to_username: String, - pub to_hostname: String, - pub to_hostkeys: Vec, +pub struct HostTuple { + pub username: String, + pub hostname: String, + pub keys: Vec, } -impl Decode for DestinationConstraint { +impl Decode for HostTuple { type Error = crate::proto::error::ProtoError; fn decode(reader: &mut impl Reader) -> Result { - fn read_user_host_keys( - reader: &mut impl Reader, - ) -> Result<(String, String, Vec), crate::proto::error::ProtoError> { - let username = String::decode(reader)?; - let hostname = String::decode(reader)?; - let _reserved = String::decode(reader)?; - - let mut keys = Vec::new(); - while !reader.is_finished() { - keys.push(KeySpec::decode(reader)?); - } - - Ok((username, hostname, keys)) + let username = String::decode(reader)?; + let hostname = String::decode(reader)?; + let _reserved = String::decode(reader)?; + + let mut keys = Vec::new(); + while !reader.is_finished() { + keys.push(KeySpec::decode(reader)?); } - let (from_username, from_hostname, from_hostkeys) = - reader.read_prefixed(read_user_host_keys)?; - let (to_username, to_hostname, to_hostkeys) = reader.read_prefixed(read_user_host_keys)?; - let _reserved = String::decode(reader)?; Ok(Self { - from_username, - from_hostname, - from_hostkeys, - to_username, - to_hostname, - to_hostkeys, + username, + hostname, + keys, }) } } +impl Encode for HostTuple { + fn encoded_len(&self) -> ssh_encoding::Result { + let prefix = [ + self.username.encoded_len()?, + self.hostname.encoded_len()?, + RESERVED_FIELD.encoded_len()?, + ] + .checked_sum()?; + self.keys.iter().try_fold(prefix, |acc, e| { + let key_len = e.encoded_len()?; + usize::checked_add(acc, key_len).ok_or(EncodingError::Length) + }) + } + + fn encode(&self, writer: &mut impl Writer) -> ssh_encoding::Result<()> { + self.username.encode(writer)?; + self.hostname.encode(writer)?; + RESERVED_FIELD.encode(writer)?; + for key in &self.keys { + key.encode(writer)?; + } + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct DestinationConstraint { + pub from: HostTuple, + pub to: HostTuple, +} + +impl Decode for DestinationConstraint { + type Error = crate::proto::error::ProtoError; + + fn decode(reader: &mut impl Reader) -> Result { + let from = reader.read_prefixed(HostTuple::decode)?; + let to = reader.read_prefixed(HostTuple::decode)?; + let _reserved = String::decode(reader)?; + + Ok(Self { from, to }) + } +} + +impl Encode for DestinationConstraint { + fn encoded_len(&self) -> ssh_encoding::Result { + [ + self.from.encoded_len_prefixed()?, + self.to.encoded_len_prefixed()?, + RESERVED_FIELD.encoded_len()?, + ] + .checked_sum() + } + + fn encode(&self, writer: &mut impl Writer) -> ssh_encoding::Result<()> { + self.from.encode_prefixed(writer)?; + self.to.encode_prefixed(writer)?; + RESERVED_FIELD.encode(writer)?; + Ok(()) + } +} + #[derive(Debug, Clone)] pub struct KeySpec { pub keyblob: KeyData, @@ -111,6 +199,23 @@ impl Decode for KeySpec { } } +impl Encode for KeySpec { + fn encoded_len(&self) -> ssh_encoding::Result { + [self.keyblob.encoded_len_prefixed()?, 1u8.encoded_len()?].checked_sum() + } + + fn encode(&self, writer: &mut impl Writer) -> ssh_encoding::Result<()> { + self.keyblob.encode_prefixed(writer)?; + // TODO: contribute `impl Encode for bool` in ssh-encoding + // + if self.is_ca { + 1u8.encode(writer) + } else { + 0u8.encode(writer) + } + } +} + #[cfg(test)] mod tests { use hex_literal::hex;