diff --git a/Cargo.lock b/Cargo.lock index 5cb2083..4e7700c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,11 +2,25 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + [[package]] name = "dns_encode" version = "0.1.0" dependencies = [ + "bitvec", "dns_types", + "hex", ] [[package]] @@ -23,6 +37,12 @@ dependencies = [ name = "dns_types" version = "0.1.0" +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "hex" version = "0.4.3" @@ -69,6 +89,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "syn" version = "2.0.48" @@ -80,6 +106,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "thiserror" version = "1.0.56" @@ -105,3 +137,12 @@ name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] diff --git a/dns_decode/src/message.rs b/dns_decode/src/message.rs index 9a3084d..ba1ff92 100644 --- a/dns_decode/src/message.rs +++ b/dns_decode/src/message.rs @@ -4,14 +4,14 @@ use nom::{multi::count, IResult}; use thiserror::Error; #[derive(Error, Debug)] -pub enum ParseError { +pub enum DecodeError { #[error("unknown parse error")] Unknown, } -pub fn parse_message(input: &[u8]) -> Result { +pub fn decode_message(input: &[u8]) -> Result { // TODO improve error reporting - let (_, m) = message(input).map_err(|_op| ParseError::Unknown)?; + let (_, m) = message(input).map_err(|_op| DecodeError::Unknown)?; Ok(m) } diff --git a/dns_encode/Cargo.toml b/dns_encode/Cargo.toml index 43e8ef0..b828b4e 100644 --- a/dns_encode/Cargo.toml +++ b/dns_encode/Cargo.toml @@ -6,4 +6,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -dns_types = { path = "../dns_types" } \ No newline at end of file +dns_types = { path = "../dns_types" } +hex = "0.4.3" +bitvec = "1.0.1" diff --git a/dns_encode/src/lib.rs b/dns_encode/src/lib.rs index ea1787a..abc2dbe 100644 --- a/dns_encode/src/lib.rs +++ b/dns_encode/src/lib.rs @@ -1,5 +1,3 @@ -use dns_types::*; -#[cfg(test)] -mod tests { - use super::*; -} +pub mod message; +pub mod message_header; +pub mod query; diff --git a/dns_encode/src/message.rs b/dns_encode/src/message.rs new file mode 100644 index 0000000..e6b77f1 --- /dev/null +++ b/dns_encode/src/message.rs @@ -0,0 +1,51 @@ +use crate::message_header::encode_message_header; +use crate::query::encode_query; +use dns_types::*; +use std::io::{Error, Write}; + +pub fn encode_message(message: &Message, writer: &mut W) -> Result<(), Error> { + encode_message_header(&message.header, writer)?; + for query in message.queries.as_slice() { + encode_query(query, writer)?; + } + // TODO answers + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encode_message() { + let message_bytes = + hex::decode("690601000001000000000000076578616d706c6503636f6d0000010001").unwrap(); + let message = Message { + header: MessageHeader { + message_id: 0x6906, + flags: Flags { + qr: QR::Query, + opcode: Opcode::Query, + aa: AuthoritativeAnswer::NonAuthoritative, + truncated: Truncated::NotTruncated, + recursion_desired: RecursionDesired::Desired, + recursion_available: RecursionAvailable::NotAvailable, + rcode: Rcode::NoError, + }, + query_count: 1, + answer_count: 0, + name_server_count: 0, + additional_count: 0, + }, + queries: vec![Query { + name: vec![String::from("example"), String::from("com")], + query_type: QueryType::A, + query_class: QueryClass::Internet, + }], + answers: vec![], + }; + let mut buffer: Vec = Vec::with_capacity(50); + encode_message(&message, &mut buffer).unwrap(); + assert_eq!(buffer, message_bytes); + } +} diff --git a/dns_encode/src/message_header.rs b/dns_encode/src/message_header.rs new file mode 100644 index 0000000..55f55ad --- /dev/null +++ b/dns_encode/src/message_header.rs @@ -0,0 +1,68 @@ +use bitvec::prelude::*; +use dns_types::message_header::*; +use std::io::{Error, Write}; + +fn encode_flags(flags: &Flags, writer: &mut W) -> Result<(), Error> { + let mut flags_buffer = [0u8, 0u8]; + let qr: bool = match flags.qr { + QR::Query => false, + QR::Response => true, + }; + let flags_bits = flags_buffer.view_bits_mut::(); + flags_bits.set(0, qr); + // TODO opcode + let rd: bool = match flags.recursion_desired { + RecursionDesired::Desired => true, + RecursionDesired::NotDesired => false, + }; + flags_bits.set(7, rd); + writer.write_all(&flags_buffer)?; + Ok(()) +} + +pub fn encode_message_header( + message_header: &MessageHeader, + writer: &mut W, +) -> Result<(), Error> { + let message_id_bytes = message_header.message_id.to_be_bytes(); + writer.write_all(&message_id_bytes)?; + encode_flags(&message_header.flags, writer)?; + let query_count_bytes = message_header.query_count.to_be_bytes(); + writer.write_all(&query_count_bytes)?; + let answer_count_bytes = message_header.answer_count.to_be_bytes(); + writer.write_all(&answer_count_bytes)?; + let name_server_count_bytes = message_header.name_server_count.to_be_bytes(); + writer.write_all(&name_server_count_bytes)?; + let additional_count_bytes = message_header.additional_count.to_be_bytes(); + writer.write_all(&additional_count_bytes)?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encode_message_header() { + let message_header_bytes = hex::decode("690601000001000000000000").unwrap(); + let message_header = MessageHeader { + message_id: 0x6906, + flags: Flags { + qr: QR::Query, + opcode: Opcode::Query, + aa: AuthoritativeAnswer::NonAuthoritative, + truncated: Truncated::NotTruncated, + recursion_desired: RecursionDesired::Desired, + recursion_available: RecursionAvailable::NotAvailable, + rcode: Rcode::NoError, + }, + query_count: 1, + answer_count: 0, + name_server_count: 0, + additional_count: 0, + }; + let mut buffer: Vec = Vec::with_capacity(12); + encode_message_header(&message_header, &mut buffer).unwrap(); + assert_eq!(buffer, message_header_bytes); + } +} diff --git a/dns_encode/src/query.rs b/dns_encode/src/query.rs new file mode 100644 index 0000000..728fba7 --- /dev/null +++ b/dns_encode/src/query.rs @@ -0,0 +1,37 @@ +use dns_types::query::*; +use std::io::{Error, Write}; + +pub fn encode_query(query: &Query, writer: &mut W) -> Result<(), Error> { + for label in query.name.as_slice() { + let label_bytes = label.as_bytes(); + let label_len = [label_bytes.len() as u8]; + writer.write_all(&label_len)?; + writer.write_all(label_bytes)?; + } + writer.write_all(b"\x00")?; + let qtype: u16 = query.query_type.into(); + let qtype_bytes = qtype.to_be_bytes(); + writer.write_all(&qtype_bytes)?; + let qclass: u16 = query.query_class.into(); + let qclass_bytes = qclass.to_be_bytes(); + writer.write_all(&qclass_bytes)?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encode_query() { + let query_bytes = hex::decode("076578616d706c6503636f6d0000010001").unwrap(); + let query = Query { + name: vec![String::from("example"), String::from("com")], + query_type: QueryType::A, + query_class: QueryClass::Internet, + }; + let mut buffer: Vec = Vec::with_capacity(12); + encode_query(&query, &mut buffer).unwrap(); + assert_eq!(buffer, query_bytes); + } +} diff --git a/dns_types/src/query.rs b/dns_types/src/query.rs index ea7799f..9fad2c0 100644 --- a/dns_types/src/query.rs +++ b/dns_types/src/query.rs @@ -7,7 +7,7 @@ pub struct Query { pub query_class: QueryClass, } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone, Copy)] pub enum QueryType { A, NS, @@ -40,7 +40,25 @@ impl Into for u16 { } } -#[derive(Debug, PartialEq)] +impl Into for QueryType { + fn into(self) -> u16 { + match self { + QueryType::A => 1, + QueryType::NS => 2, + QueryType::CNAME => 5, + QueryType::SOA => 6, + QueryType::WKS => 11, + QueryType::PTR => 12, + QueryType::MX => 15, + QueryType::SRV => 33, + QueryType::AAAA => 28, + QueryType::ANY => 255, + QueryType::Unknown(u) => u, + } + } +} + +#[derive(Debug, PartialEq, Clone, Copy)] pub enum QueryClass { Internet, Unknown(u16), @@ -54,3 +72,12 @@ impl Into for u16 { } } } + +impl Into for QueryClass { + fn into(self) -> u16 { + match self { + QueryClass::Internet => 1, + QueryClass::Unknown(u) => u, + } + } +}