diff --git a/Cargo.toml b/Cargo.toml index 2510d50..2e8e0ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,4 +10,5 @@ keywords = ["tftp", "server"] categories = ["command-line-utilities"] [features] -integration = [] \ No newline at end of file +integration = [] +client = [] \ No newline at end of file diff --git a/README.md b/README.md index 32b8967..da469c0 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Since TFTP servers do not offer any type of login or access control mechanisms, Documentation for the project can be found in [docs.rs](https://docs.rs/tftpd/latest/tftpd/). -## Usage +## Usage (Server) To install the server using Cargo: @@ -32,6 +32,32 @@ To run the server on the IP address `0.0.0.0`, read-only, on port `1234` in the tftpd -i 0.0.0.0 -p 1234 -d "/home/user/tftp" -r ``` +## Usage (Client) + +To install the client and server using Cargo: + +```bash +cargo install --features client tftpd +tftpd client --help +tftpd server --help +``` + +To run the server on the IP address `0.0.0.0`, read-only, on port `1234` in the `/home/user/tftp` directory: + +```bash +tftpd server -i 0.0.0.0 -p 1234 -d "/home/user/tftp" -r +``` + +To connect the client to a tftp server running on IP address `127.0.0.1`, read-only, on port `1234` and download a file named `example.file` +```bash +tftpd client example.file -i 0.0.0.0 -p 1234 -d +``` + +To connect the client to a tftp server running on IP address `127.0.0.1`, read-only, on port `1234` and upload a file named `example.file` +```bash +tftpd client ./example.file -i 0.0.0.0 -p 1234 -u +``` + ## License This project is licensed under the [MIT License](https://opensource.org/license/mit/). diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..b404ed9 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,255 @@ +use crate::packet::{DEFAULT_BLOCKSIZE, DEFAULT_TIMEOUT, DEFAULT_WINDOWSIZE}; +use crate::{ClientConfig, OptionType, Packet, Socket, TransferOption, Worker}; +use std::cmp::PartialEq; +use std::error::Error; +use std::fs::File; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; +use std::path::PathBuf; +use std::time::Duration; + +/// Client `struct` is used for client sided TFTP requests. +/// +/// This `struct` is meant to be created by [`Client::new()`]. See its +/// documentation for more. +/// +/// # Example +/// +/// ```rust +/// // Create the TFTP server. +/// use tftpd::{ClientConfig, Client}; +/// +/// let args = ["test.file", "-u"].iter().map(|s| s.to_string()); +/// let config = ClientConfig::new(args).unwrap(); +/// let server = Client::new(&config).unwrap(); +/// ``` +pub struct Client { + remote_address: SocketAddr, + blocksize: usize, + windowsize: u16, + timeout: Duration, + mode: Mode, + filename: PathBuf, + save_path: PathBuf, +} + +/// Enum used to set the client either in Download Mode or Upload Mode +#[derive(PartialEq, Clone, Copy, Debug)] +pub enum Mode { + /// Upload Mode + Upload, + /// Download Mode + Download, +} + +impl Client { + /// Creates the TFTP Client with the supplied [`ClientConfig`]. + pub fn new(config: &ClientConfig) -> Result> { + Ok(Client { + remote_address: SocketAddr::from((config.remote_ip_address, config.port)), + blocksize: config.blocksize, + windowsize: config.windowsize, + timeout: config.timeout, + mode: config.mode, + filename: config.filename.clone(), + save_path: config.save_directory.clone(), + }) + } + + /// Starts the Client depending on the [`Mode`] the client is in + pub fn start(&mut self) -> Result<(), Box> { + match self.mode { + Mode::Upload => self.upload(), + Mode::Download => self.download(), + } + } + + fn upload(&mut self) -> Result<(), Box> { + if self.mode != Mode::Upload { + return Err(Box::from("Client mode is set to Download")); + } + + let socket = if self.remote_address.is_ipv4() { + UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0))? + } else { + UdpSocket::bind((Ipv6Addr::UNSPECIFIED, 0))? + }; + let file = self.filename.clone(); + + let size = File::open(self.filename.clone())?.metadata()?.len() as usize; + + Socket::send_to( + &socket, + &Packet::Wrq { + filename: file.into_os_string().into_string().unwrap(), + mode: "octet".into(), + options: vec![ + TransferOption { + option: OptionType::BlockSize, + value: self.blocksize, + }, + TransferOption { + option: OptionType::Windowsize, + value: self.windowsize as usize, + }, + TransferOption { + option: OptionType::Timeout, + value: self.timeout.as_secs() as usize, + }, + TransferOption { + option: OptionType::TransferSize, + value: size, + } + ], + }, + &self.remote_address, + )?; + + let received = Socket::recv_from(&socket); + + if let Ok((packet, from)) = received { + socket.connect(from)?; + match packet { + Packet::Oack(options) => { + self.verify_oack(&options)?; + let worker = self.configure_worker(socket)?; + let join_handle = worker.send(false)?; + let _ = join_handle.join(); + } + Packet::Ack(_) => { + self.blocksize = DEFAULT_BLOCKSIZE; + self.windowsize = DEFAULT_WINDOWSIZE; + self.timeout = DEFAULT_TIMEOUT; + let worker = self.configure_worker(socket)?; + let join_handle = worker.send(false)?; + let _ = join_handle.join(); + } + Packet::Error { code, msg } => { + return Err(Box::from(format!( + "Client received error from server: {code}: {msg}" + ))); + } + _ => { + return Err(Box::from(format!( + "Client received unexpected packet from server: {packet:#?}" + ))); + } + } + } else { + return Err(Box::from("Unexpected Error")); + } + + Ok(()) + } + + fn download(&mut self) -> Result<(), Box> { + if self.mode != Mode::Download { + return Err(Box::from("Client mode is set to Upload")); + } + + let socket = if self.remote_address.is_ipv4() { + UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0))? + } else { + UdpSocket::bind((Ipv6Addr::UNSPECIFIED, 0))? + }; + let file = self.filename.clone(); + + Socket::send_to( + &socket, + &Packet::Rrq { + filename: file.into_os_string().into_string().unwrap(), + mode: "octet".into(), + options: vec![ + TransferOption { + option: OptionType::BlockSize, + value: self.blocksize, + }, + TransferOption { + option: OptionType::Windowsize, + value: self.windowsize as usize, + }, + TransferOption { + option: OptionType::Timeout, + value: self.timeout.as_secs() as usize, + }, + TransferOption { + option: OptionType::TransferSize, + value: 0, + } + ], + }, + &self.remote_address, + )?; + + let received = Socket::recv_from(&socket); + + if let Ok((packet, from)) = received { + socket.connect(from)?; + match packet { + Packet::Oack(options) => { + self.verify_oack(&options)?; + Socket::send_to(&socket, &Packet::Ack(0), &from)?; + let worker = self.configure_worker(socket)?; + let join_handle = worker.receive()?; + let _ = join_handle.join(); + } + Packet::Error { code, msg } => { + return Err(Box::from(format!( + "Client received error from server: {code}: {msg}" + ))); + } + _ => { + return Err(Box::from(format!( + "Client received unexpected packet from server: {packet:#?}" + ))); + } + } + } else { + return Err(Box::from("Unexpected Error")); + } + + Ok(()) + } + + fn verify_oack(&mut self, options: &Vec) -> Result<(), Box> { + for option in options { + match option.option { + OptionType::BlockSize {} => self.blocksize = option.value, + OptionType::Windowsize => self.windowsize = option.value as u16, + _ => {} + } + } + + Ok(()) + } + + fn configure_worker(&self, socket: UdpSocket) -> Result, Box> { + let mut socket: Box = Box::new(socket); + + socket.set_read_timeout(self.timeout)?; + socket.set_write_timeout(self.timeout)?; + + let worker = if self.mode == Mode::Download { + let mut file = self.save_path.clone(); + file = file.join(self.filename.clone()); + Worker::new( + socket, + file, + self.blocksize, + DEFAULT_TIMEOUT, + self.windowsize, + 1, + ) + } else { + Worker::new( + socket, + PathBuf::from(self.filename.clone()), + self.blocksize, + DEFAULT_TIMEOUT, + self.windowsize, + 1, + ) + }; + + Ok(worker) + } +} diff --git a/src/client_config.rs b/src/client_config.rs new file mode 100644 index 0000000..179d225 --- /dev/null +++ b/src/client_config.rs @@ -0,0 +1,216 @@ +use crate::client::Mode; +use crate::packet::{DEFAULT_BLOCKSIZE, DEFAULT_TIMEOUT, DEFAULT_WINDOWSIZE}; +use std::error::Error; +use std::net::{IpAddr, Ipv4Addr}; +use std::path::{Path, PathBuf}; +use std::process; +use std::time::Duration; + +/// Configuration `struct` used for parsing TFTP Client options from user +/// input. +/// +/// This `struct` is meant to be created by [`ClientConfig::new()`]. See its +/// documentation for more. +/// +/// # Example +/// +/// ```rust +/// // Create TFTP configuration from user arguments. +/// use std::env; +/// use tftpd::ClientConfig; +/// +/// let client_config = ClientConfig::new(env::args()); +/// ``` +#[derive(Debug)] +pub struct ClientConfig { + /// Local IP address of the TFTP Server. (default: 127.0.0.1) + pub remote_ip_address: IpAddr, + /// Local Port number of the TFTP Server. (default: 69) + pub port: u16, + /// Blocksize to use during transfer. (default: 512) + pub blocksize: usize, + /// Windowsize to use during transfer. (default: 1) + pub windowsize: u16, + /// Timeout to use during transfer. (default: 5s) + pub timeout: Duration, + /// Upload or Download a file. (default: Download) + pub mode: Mode, + /// Directory where to save downloaded files. (default: Current Working Directory) + pub save_directory: PathBuf, + /// File to Upload or Download. + pub filename: PathBuf, +} + +impl Default for ClientConfig { + fn default() -> Self { + Self { + remote_ip_address: IpAddr::V4(Ipv4Addr::LOCALHOST), + port: 69, + blocksize: DEFAULT_BLOCKSIZE, + windowsize: DEFAULT_WINDOWSIZE, + timeout: DEFAULT_TIMEOUT, + mode: Mode::Download, + save_directory: Default::default(), + filename: Default::default(), + } + } +} + +impl ClientConfig { + /// Creates a new configuration by parsing the supplied arguments. It is + /// intended for use with [`env::args()`]. + pub fn new>(mut args: T) -> Result> { + let mut config = ClientConfig::default(); + + args.next(); + + if let Some(file_name) = args.next() { + config.filename = PathBuf::from(file_name); + } else { + return Err("Missing file to upload or download".into()); + } + + while let Some(arg) = args.next() { + match arg.as_str() { + "-i" | "--ip-address" => { + if let Some(ip_str) = args.next() { + let ip_addr: IpAddr = ip_str.parse()?; + config.remote_ip_address = ip_addr; + } else { + return Err("Missing ip address after flag".into()); + } + } + "-p" | "--port" => { + if let Some(port_str) = args.next() { + config.port = port_str.parse::()?; + } else { + return Err("Missing port number after flag".into()); + } + } + "-b" | "--blocksize" => { + if let Some(blocksize_str) = args.next() { + config.blocksize = blocksize_str.parse::()?; + } else { + return Err("Missing blocksize after flag".into()); + } + } + "-w" | "--windowsize" => { + if let Some(windowsize_str) = args.next() { + config.windowsize = windowsize_str.parse::()?; + } else { + return Err("Missing windowsize after flag".into()); + } + } + "-t" | "--timeout" => { + if let Some(timeout_str) = args.next() { + config.timeout = Duration::from_secs(timeout_str.parse::()?); + } else { + return Err("Missing timeout after flag".into()); + } + } + "-sd" | "--save-directory" => { + if let Some(dir_str) = args.next() { + if !Path::new(&dir_str).exists() { + return Err(format!("{dir_str} does not exist").into()); + } + config.save_directory = dir_str.into(); + } else { + return Err("Missing save directory after flag".into()); + } + } + "-u" | "--upload" => { + config.mode = Mode::Upload; + } + "-d" | "--download" => { + config.mode = Mode::Download; + } + "-h" | "--help" => { + println!("TFTP Client\n"); + println!("Usage: tftpd client [OPTIONS]\n"); + println!("Options:"); + println!(" -i, --ip-address \tIp address of the server (default: 127.0.0.1)"); + println!(" -p, --port \t\tPort of the server (default: 69)"); + println!(" -b, --blocksize \tSets the blocksize (default: 512)"); + println!(" -w, --windowsize \tSets the windowsize (default: 1)"); + println!(" -t, --timeout \tSets the timeout in seconds (default: 5)"); + println!(" -u, --upload\t\t\tSets the client to upload mode, Ignores all previous download flags"); + println!(" -d, --download\t\tSet the client to download mode, Invalidates all previous upload flags"); + println!(" -sd, --save-directory \tSet the directory to save files when in Download Mode (default: the directory setting)"); + println!(" -h, --help\t\t\tPrint help information"); + process::exit(0); + } + + invalid => return Err(format!("Invalid flag: {invalid}").into()), + } + } + + Ok(config) + } +} + +#[cfg(test)] +mod tests { + use ClientConfig; + + use super::*; + + #[test] + fn parses_full_config() { + let config = ClientConfig::new( + [ + "client", + "test.file", + "-i", + "0.0.0.0", + "-p", + "1234", + "-sd", + "/", + "-d", + "-u", + "-b", + "1024", + "-w", + "2", + "-t", + "4" + ] + .iter() + .map(|s| s.to_string()), + ) + .unwrap(); + + assert_eq!(config.remote_ip_address, Ipv4Addr::new(0, 0, 0, 0)); + assert_eq!(config.port, 1234); + assert_eq!(config.save_directory, PathBuf::from("/")); + assert_eq!(config.filename, PathBuf::from("test.file")); + assert_eq!(config.windowsize, 2); + assert_eq!(config.blocksize, 1024); + assert_eq!(config.mode, Mode::Upload); + assert_eq!(config.timeout, Duration::from_secs(4)); + } + + + #[test] + fn parses_partial_config() { + let config = ClientConfig::new( + [ + "client", + "test.file", + "-d", + "-b", + "2048", + "-p", + "2000", + ] + .iter() + .map(|s| s.to_string()), + ) + .unwrap(); + + assert_eq!(config.port, 2000); + assert_eq!(config.filename, PathBuf::from("test.file")); + assert_eq!(config.blocksize, 2048); + assert_eq!(config.mode, Mode::Download); + } +} diff --git a/src/config.rs b/src/config.rs index 55ad422..167cdca 100644 --- a/src/config.rs +++ b/src/config.rs @@ -118,6 +118,9 @@ impl Config { } "-h" | "--help" => { println!("TFTP Server Daemon\n"); + #[cfg(feature = "client")] + println!("Usage: tftpd server [OPTIONS]\n"); + #[cfg(not(feature = "client"))] println!("Usage: tftpd [OPTIONS]\n"); println!("Options:"); println!(" -i, --ip-address \tSet the ip address of the server (default: 127.0.0.1)"); diff --git a/src/lib.rs b/src/lib.rs index 51c6d51..fe0590c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,11 @@ //! Since TFTP servers do not offer any type of login or access control mechanisms, this server only allows //! transfer and receiving inside a chosen folder, and disallows external file access. +#[cfg(feature = "client")] +mod client; + +#[cfg(feature = "client")] +mod client_config; mod config; mod convert; mod packet; @@ -23,6 +28,12 @@ mod socket; mod window; mod worker; +#[cfg(feature = "client")] +pub use client::Client; +#[cfg(feature = "client")] +pub use client::Mode; +#[cfg(feature = "client")] +pub use client_config::ClientConfig; pub use config::Config; pub use convert::Convert; pub use packet::ErrorCode; diff --git a/src/main.rs b/src/main.rs index 4230348..ffa04d0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,67 @@ +#[cfg(feature = "client")] +use std::error::Error; use std::{env, net::SocketAddr, process}; +#[cfg(not(feature = "client"))] use tftpd::{Config, Server}; +#[cfg(feature = "client")] +use tftpd::{Client, ClientConfig, Config, Mode, Server}; +#[cfg(feature = "client")] fn main() { - let config = Config::new(env::args()).unwrap_or_else(|err| { + let args: Vec = env::args().collect(); + if args.len() < 2 { + eprintln!("{}: incorrect usage", args[0]); + eprintln!("{} [args]", args[0]); + } else if args[1] == "client" { + client(args[1..].iter().map(|s| s.to_string())).unwrap_or_else(|err| { + eprintln!("{err}"); + }) + } else if args[1] == "server" { + server(args[1..].iter().map(|s| s.to_string())); + } else { + eprintln!("{}: incorrect usage", args[0]); + eprintln!("{} (client | server) [args]", args[0]); + } +} + + +#[cfg(not(feature = "client"))] +fn main() { + let args: Vec = env::args().collect(); + server(args[0..].iter().map(|s| s.to_string())); +} + +#[cfg(feature = "client")] +fn client>(args: T) -> Result<(), Box> { + let config = ClientConfig::new(args).unwrap_or_else(|err| { + eprintln!("Problem parsing arguments: {err}"); + process::exit(1) + }); + + let mut server = Client::new(&config).unwrap_or_else(|err| { + eprintln!("Problem creating client: {err}"); + process::exit(1) + }); + + if config.mode == Mode::Upload { + println!( + "Starting TFTP Client, uploading {} to {}", + config.filename.display(), + SocketAddr::new(config.remote_ip_address, config.port), + ); + } else { + println!( + "Starting TFTP Client, downloading {} to {}", + config.filename.display(), + SocketAddr::new(config.remote_ip_address, config.port), + ); + } + + server.start() +} + +fn server>(args: T) { + let config = Config::new(args).unwrap_or_else(|err| { eprintln!("Problem parsing arguments: {err}"); process::exit(1) }); diff --git a/src/packet.rs b/src/packet.rs index 7004287..3f0764b 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -1,4 +1,6 @@ use crate::Convert; +#[cfg(feature = "client")] +use std::time::Duration; use std::{error::Error, fmt, str::FromStr}; /// Packet `enum` represents the valid TFTP packet types. @@ -64,19 +66,28 @@ impl Packet { Opcode::Rrq | Opcode::Wrq => parse_rq(buf, opcode), Opcode::Data => parse_data(buf), Opcode::Ack => parse_ack(buf), + Opcode::Oack => parse_oack(buf), Opcode::Error => parse_error(buf), - _ => Err("Invalid packet".into()), } } /// Serializes a [`Packet`] into a [`Vec`]. pub fn serialize(&self) -> Result, &'static str> { match self { + Packet::Rrq { + filename, + mode, + options, + } => Ok(serialize_rrq(filename, mode, options)), + Packet::Wrq { + filename, + mode, + options, + } => Ok(serialize_wrq(filename, mode, options)), Packet::Data { block_num, data } => Ok(serialize_data(block_num, data)), Packet::Ack(block_num) => Ok(serialize_ack(block_num)), Packet::Error { code, msg } => Ok(serialize_error(code, msg)), Packet::Oack(options) => Ok(serialize_oack(options)), - _ => Err("Invalid packet"), } } } @@ -192,6 +203,18 @@ pub enum OptionType { Windowsize, } +#[cfg(feature = "client")] +/// Default Timeout +pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5); + +#[cfg(feature = "client")] +/// Default Blocksize +pub const DEFAULT_BLOCKSIZE: usize = 512; + +#[cfg(feature = "client")] +/// Default Windowsize +pub const DEFAULT_WINDOWSIZE: u16 = 1; + impl OptionType { /// Converts an [`OptionType`] to a [`str`]. pub fn as_str(&self) -> &'static str { @@ -339,6 +362,26 @@ fn parse_ack(buf: &[u8]) -> Result> { Ok(Packet::Ack(Convert::to_u16(&buf[2..])?)) } +fn parse_oack(buf: &[u8]) -> Result> { + let mut options = vec![]; + let mut value: String; + let mut option; + let mut zero_index = 1usize; + + while zero_index < buf.len() - 1 { + (option, zero_index) = Convert::to_string(buf, zero_index + 1)?; + (value, zero_index) = Convert::to_string(buf, zero_index + 1)?; + if let Ok(option) = OptionType::from_str(option.to_lowercase().as_str()) { + options.push(TransferOption { + option, + value: value.parse()?, + }); + } + } + + Ok(Packet::Oack(options)) +} + fn parse_error(buf: &[u8]) -> Result> { let code = ErrorCode::from_u16(Convert::to_u16(&buf[2..])?)?; if let Ok((msg, _)) = Convert::to_string(buf, 4) { @@ -351,6 +394,38 @@ fn parse_error(buf: &[u8]) -> Result> { } } +fn serialize_rrq(file: &String, mode: &String, options: &Vec) -> Vec { + let mut buf = [ + &Opcode::Rrq.as_bytes(), + file.as_bytes(), + &[0x00], + mode.as_bytes(), + &[0x00], + ] + .concat(); + + for option in options { + buf = [buf, option.as_bytes()].concat(); + } + buf +} + +fn serialize_wrq(file: &String, mode: &String, options: &Vec) -> Vec { + let mut buf = [ + &Opcode::Wrq.as_bytes(), + file.as_bytes(), + &[0x00], + mode.as_bytes(), + &[0x00], + ] + .concat(); + + for option in options { + buf = [buf, option.as_bytes()].concat(); + } + buf +} + fn serialize_data(block_num: &u16, data: &Vec) -> Vec { [ &Opcode::Data.as_bytes(), @@ -576,6 +651,53 @@ mod tests { } } + #[test] + fn parses_oack() { + let buf = [ + &Opcode::Oack.as_bytes()[..], + (OptionType::TransferSize.as_str().as_bytes()), + &[0x00], + ("0".as_bytes()), + &[0x00], + (OptionType::Timeout.as_str().as_bytes()), + &[0x00], + ("5".as_bytes()), + &[0x00], + (OptionType::Windowsize.as_str().as_bytes()), + &[0x00], + ("4".as_bytes()), + &[0x00], + ] + .concat(); + + if let Ok(Packet::Oack(options)) = parse_oack(&buf) { + assert_eq!(options.len(), 3); + assert_eq!( + options[0], + TransferOption { + option: OptionType::TransferSize, + value: 0 + } + ); + assert_eq!( + options[1], + TransferOption { + option: OptionType::Timeout, + value: 5 + } + ); + assert_eq!( + options[2], + TransferOption { + option: OptionType::Windowsize, + value: 4 + } + ); + } else { + panic!("cannot parse read request with options") + } + } + #[test] fn parses_error() { let buf = [ @@ -611,6 +733,98 @@ mod tests { } } + #[test] + fn serializes_rrq() { + let serialized_data = vec![0x00, 0x01, 0x74, 0x65, 0x73, 0x74, 0x00, 0x6f, 0x63, 0x74, 0x65, 0x74, 0x00]; + + assert_eq!( + serialize_rrq( + &"test".into(), + &"octet".into(), + &vec![] + ), + serialized_data + ) + } + + #[test] + fn serializes_rrq_with_options() { + let serialized_data = vec![ + 0x00, 0x01, 0x74, 0x65, 0x73, 0x74, 0x00, 0x6f, 0x63, 0x74, 0x65, 0x74, 0x00, 0x62, + 0x6c, 0x6b, 0x73, 0x69, 0x7a, 0x65, 0x00, 0x31, 0x34, 0x36, 0x38, 0x00, 0x77, 0x69, + 0x6e, 0x64, 0x6f, 0x77, 0x73, 0x69, 0x7a, 0x65, 0x00, 0x31, 0x00, 0x74, 0x69, 0x6d, + 0x65, 0x6f, 0x75, 0x74, 0x00, 0x35, 0x00, + ]; + + assert_eq!( + serialize_rrq( + &"test".into(), + &"octet".into(), + &vec![ + TransferOption { + option: OptionType::BlockSize, + value: 1468, + }, + TransferOption { + option: OptionType::Windowsize, + value: 1, + }, + TransferOption { + option: OptionType::Timeout, + value: 5, + } + ] + ), + serialized_data + ) + } + + #[test] + fn serializes_wrq() { + let serialized_data = vec![0x00, 0x02, 0x74, 0x65, 0x73, 0x74, 0x00, 0x6f, 0x63, 0x74, 0x65, 0x74, 0x00]; + + assert_eq!( + serialize_wrq( + &"test".into(), + &"octet".into(), + &vec![] + ), + serialized_data + ) + } + + #[test] + fn serializes_wrq_with_options() { + let serialized_data = vec![ + 0x00, 0x02, 0x74, 0x65, 0x73, 0x74, 0x00, 0x6f, 0x63, 0x74, 0x65, 0x74, 0x00, 0x62, + 0x6c, 0x6b, 0x73, 0x69, 0x7a, 0x65, 0x00, 0x31, 0x34, 0x36, 0x38, 0x00, 0x77, 0x69, + 0x6e, 0x64, 0x6f, 0x77, 0x73, 0x69, 0x7a, 0x65, 0x00, 0x31, 0x00, 0x74, 0x69, 0x6d, + 0x65, 0x6f, 0x75, 0x74, 0x00, 0x35, 0x00, + ]; + + assert_eq!( + serialize_wrq( + &"test".into(), + &"octet".into(), + &vec![ + TransferOption { + option: OptionType::BlockSize, + value: 1468, + }, + TransferOption { + option: OptionType::Windowsize, + value: 1, + }, + TransferOption { + option: OptionType::Timeout, + value: 5, + } + ] + ), + serialized_data + ) + } + #[test] fn serializes_data() { let serialized_data = vec![0x00, 0x03, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04]; diff --git a/src/server.rs b/src/server.rs index 4705703..922117d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -184,7 +184,8 @@ impl Server { worker_options.window_size, self.duplicate_packets + 1, ); - worker.send(!options.is_empty()) + worker.send(!options.is_empty())?; + Ok(()) } _ => Err("Unexpected error code when checking file".into()), } @@ -224,7 +225,8 @@ impl Server { worker_options.window_size, self.duplicate_packets + 1, ); - worker.receive() + worker.receive()?; + Ok(()) }; match check_file_exists(file_path, &self.receive_directory) { diff --git a/src/worker.rs b/src/worker.rs index 82bd952..7ec0161 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,4 +1,5 @@ use crate::{ErrorCode, Packet, Socket, Window}; +use std::thread::JoinHandle; use std::{ error::Error, fs::{self, File}, @@ -70,11 +71,11 @@ impl Worker { /// Sends a file to the remote [`SocketAddr`] that has sent a read request using /// a random port, asynchronously. - pub fn send(self, check_response: bool) -> Result<(), Box> { + pub fn send(self, check_response: bool) -> Result, Box> { let file_name = self.file_name.clone(); let remote_addr = self.socket.remote_addr().unwrap(); - thread::spawn(move || { + let handle = thread::spawn(move || { let handle_send = || -> Result<(), Box> { self.send_file(File::open(&file_name)?, check_response)?; @@ -95,16 +96,16 @@ impl Worker { } }); - Ok(()) + Ok(handle) } /// Receives a file from the remote [`SocketAddr`] that has sent a write request using /// the supplied socket, asynchronously. - pub fn receive(self) -> Result<(), Box> { + pub fn receive(self) -> Result, Box> { let file_name = self.file_name.clone(); let remote_addr = self.socket.remote_addr().unwrap(); - thread::spawn(move || { + let handle = thread::spawn(move || { let handle_receive = || -> Result<(), Box> { self.receive_file(File::create(&file_name)?)?; @@ -128,7 +129,7 @@ impl Worker { } }); - Ok(()) + Ok(handle) } fn send_file(self, file: File, check_response: bool) -> Result<(), Box> {