Skip to content

Commit

Permalink
implement encoder for queries
Browse files Browse the repository at this point in the history
  • Loading branch information
BrendanBall committed Feb 4, 2024
1 parent aeeb69a commit 981382e
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 11 deletions.
41 changes: 41 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions dns_decode/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ use nom::{multi::count, IResult};
use thiserror::Error;

#[derive(Error, Debug)]
pub enum ParseError {
pub enum DecodeError {
#[error("unknown parse error")]
Unknown,
}

pub fn parse_message(input: &[u8]) -> Result<Message, ParseError> {
pub fn decode_message(input: &[u8]) -> Result<Message, DecodeError> {
// TODO improve error reporting
let (_, m) = message(input).map_err(|_op| ParseError::Unknown)?;
let (_, m) = message(input).map_err(|_op| DecodeError::Unknown)?;
Ok(m)
}

Expand Down
4 changes: 3 additions & 1 deletion dns_encode/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
dns_types = { path = "../dns_types" }
dns_types = { path = "../dns_types" }
hex = "0.4.3"
bitvec = "1.0.1"
8 changes: 3 additions & 5 deletions dns_encode/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use dns_types::*;
#[cfg(test)]
mod tests {
use super::*;
}
pub mod message;
pub mod message_header;
pub mod query;
51 changes: 51 additions & 0 deletions dns_encode/src/message.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use crate::message_header::encode_message_header;
use crate::query::encode_query;
use dns_types::*;
use std::io::{Error, Write};

pub fn encode_message<W: Write>(message: &Message, writer: &mut W) -> Result<(), Error> {
encode_message_header(&message.header, writer)?;
for query in message.queries.as_slice() {
encode_query(query, writer)?;
}
// TODO answers
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_encode_message() {
let message_bytes =
hex::decode("690601000001000000000000076578616d706c6503636f6d0000010001").unwrap();
let message = Message {
header: MessageHeader {
message_id: 0x6906,
flags: Flags {
qr: QR::Query,
opcode: Opcode::Query,
aa: AuthoritativeAnswer::NonAuthoritative,
truncated: Truncated::NotTruncated,
recursion_desired: RecursionDesired::Desired,
recursion_available: RecursionAvailable::NotAvailable,
rcode: Rcode::NoError,
},
query_count: 1,
answer_count: 0,
name_server_count: 0,
additional_count: 0,
},
queries: vec![Query {
name: vec![String::from("example"), String::from("com")],
query_type: QueryType::A,
query_class: QueryClass::Internet,
}],
answers: vec![],
};
let mut buffer: Vec<u8> = Vec::with_capacity(50);
encode_message(&message, &mut buffer).unwrap();
assert_eq!(buffer, message_bytes);
}
}
68 changes: 68 additions & 0 deletions dns_encode/src/message_header.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use bitvec::prelude::*;
use dns_types::message_header::*;
use std::io::{Error, Write};

fn encode_flags<W: Write>(flags: &Flags, writer: &mut W) -> Result<(), Error> {
let mut flags_buffer = [0u8, 0u8];
let qr: bool = match flags.qr {
QR::Query => false,
QR::Response => true,
};
let flags_bits = flags_buffer.view_bits_mut::<Msb0>();
flags_bits.set(0, qr);
// TODO opcode
let rd: bool = match flags.recursion_desired {
RecursionDesired::Desired => true,
RecursionDesired::NotDesired => false,
};
flags_bits.set(7, rd);
writer.write_all(&flags_buffer)?;
Ok(())
}

pub fn encode_message_header<W: Write>(
message_header: &MessageHeader,
writer: &mut W,
) -> Result<(), Error> {
let message_id_bytes = message_header.message_id.to_be_bytes();
writer.write_all(&message_id_bytes)?;
encode_flags(&message_header.flags, writer)?;
let query_count_bytes = message_header.query_count.to_be_bytes();
writer.write_all(&query_count_bytes)?;
let answer_count_bytes = message_header.answer_count.to_be_bytes();
writer.write_all(&answer_count_bytes)?;
let name_server_count_bytes = message_header.name_server_count.to_be_bytes();
writer.write_all(&name_server_count_bytes)?;
let additional_count_bytes = message_header.additional_count.to_be_bytes();
writer.write_all(&additional_count_bytes)?;
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_encode_message_header() {
let message_header_bytes = hex::decode("690601000001000000000000").unwrap();
let message_header = MessageHeader {
message_id: 0x6906,
flags: Flags {
qr: QR::Query,
opcode: Opcode::Query,
aa: AuthoritativeAnswer::NonAuthoritative,
truncated: Truncated::NotTruncated,
recursion_desired: RecursionDesired::Desired,
recursion_available: RecursionAvailable::NotAvailable,
rcode: Rcode::NoError,
},
query_count: 1,
answer_count: 0,
name_server_count: 0,
additional_count: 0,
};
let mut buffer: Vec<u8> = Vec::with_capacity(12);
encode_message_header(&message_header, &mut buffer).unwrap();
assert_eq!(buffer, message_header_bytes);
}
}
37 changes: 37 additions & 0 deletions dns_encode/src/query.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use dns_types::query::*;
use std::io::{Error, Write};

pub fn encode_query<W: Write>(query: &Query, writer: &mut W) -> Result<(), Error> {
for label in query.name.as_slice() {
let label_bytes = label.as_bytes();
let label_len = [label_bytes.len() as u8];
writer.write_all(&label_len)?;
writer.write_all(label_bytes)?;
}
writer.write_all(b"\x00")?;
let qtype: u16 = query.query_type.into();
let qtype_bytes = qtype.to_be_bytes();
writer.write_all(&qtype_bytes)?;
let qclass: u16 = query.query_class.into();
let qclass_bytes = qclass.to_be_bytes();
writer.write_all(&qclass_bytes)?;
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_encode_query() {
let query_bytes = hex::decode("076578616d706c6503636f6d0000010001").unwrap();
let query = Query {
name: vec![String::from("example"), String::from("com")],
query_type: QueryType::A,
query_class: QueryClass::Internet,
};
let mut buffer: Vec<u8> = Vec::with_capacity(12);
encode_query(&query, &mut buffer).unwrap();
assert_eq!(buffer, query_bytes);
}
}
31 changes: 29 additions & 2 deletions dns_types/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub struct Query {
pub query_class: QueryClass,
}

#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum QueryType {
A,
NS,
Expand Down Expand Up @@ -40,7 +40,25 @@ impl Into<QueryType> for u16 {
}
}

#[derive(Debug, PartialEq)]
impl Into<u16> for QueryType {
fn into(self) -> u16 {
match self {
QueryType::A => 1,
QueryType::NS => 2,
QueryType::CNAME => 5,
QueryType::SOA => 6,
QueryType::WKS => 11,
QueryType::PTR => 12,
QueryType::MX => 15,
QueryType::SRV => 33,
QueryType::AAAA => 28,
QueryType::ANY => 255,
QueryType::Unknown(u) => u,
}
}
}

#[derive(Debug, PartialEq, Clone, Copy)]
pub enum QueryClass {
Internet,
Unknown(u16),
Expand All @@ -54,3 +72,12 @@ impl Into<QueryClass> for u16 {
}
}
}

impl Into<u16> for QueryClass {
fn into(self) -> u16 {
match self {
QueryClass::Internet => 1,
QueryClass::Unknown(u) => u,
}
}
}

0 comments on commit 981382e

Please sign in to comment.