diff --git a/Cargo.lock b/Cargo.lock index ae0860a..14f7c8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -188,7 +188,7 @@ dependencies = [ "heck 0.4.0", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.90", ] [[package]] @@ -269,7 +269,7 @@ dependencies = [ "proc-macro2", "quote", "strsim 0.11.1", - "syn 2.0.48", + "syn 2.0.90", ] [[package]] @@ -280,7 +280,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.48", + "syn 2.0.90", ] [[package]] @@ -455,7 +455,7 @@ dependencies = [ "pretty_assertions", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.90", ] [[package]] @@ -614,7 +614,7 @@ dependencies = [ "libc", "redox_syscall 0.5.7", "smallvec", - "windows-targets 0.52.0", + "windows-targets 0.52.6", ] [[package]] @@ -664,7 +664,7 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.90", ] [[package]] @@ -849,6 +849,15 @@ dependencies = [ "libc", ] +[[package]] +name = "simple-macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "simpledns" version = "0.1.0" @@ -859,6 +868,7 @@ dependencies = [ "ratatui", "rusqlite", "shellexpand", + "simple-macros", "tabled", "yaml-rust", ] @@ -906,7 +916,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.48", + "syn 2.0.90", ] [[package]] @@ -922,9 +932,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.48" +version = "2.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" dependencies = [ "proc-macro2", "quote", @@ -971,7 +981,7 @@ checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.90", ] [[package]] @@ -1054,7 +1064,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.90", "wasm-bindgen-shared", ] @@ -1076,7 +1086,7 @@ checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.90", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1115,7 +1125,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.0", + "windows-targets 0.52.6", ] [[package]] @@ -1133,7 +1143,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.0", + "windows-targets 0.52.6", ] [[package]] @@ -1153,17 +1163,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.0", - "windows_aarch64_msvc 0.52.0", - "windows_i686_gnu 0.52.0", - "windows_i686_msvc 0.52.0", - "windows_x86_64_gnu 0.52.0", - "windows_x86_64_gnullvm 0.52.0", - "windows_x86_64_msvc 0.52.0", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -1174,9 +1185,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -1186,9 +1197,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -1198,9 +1209,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.0" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -1210,9 +1227,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -1222,9 +1239,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -1234,9 +1251,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -1246,9 +1263,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "yaml-rust" @@ -1282,5 +1299,5 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.90", ] diff --git a/Cargo.toml b/Cargo.toml index c030d2d..34c81e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,13 +14,14 @@ rand = "0.8.5" ratatui = { version = "0.29.0", optional = true } rusqlite = { version = "0.30.0", features = ["bundled"] } shellexpand = "3.1.0" +simple-macros = { path = "simple-macros" } tabled = "0.17.0" yaml-rust = "0.4" [features] -default = ["log_info", "log_warn", "log_error"] -log_info = [] -log_debug = [] -log_warn = [] +default = ["log_info"] +log_debug = ["log_info"] +log_info = ["log_warn"] +log_warn = ["log_error"] log_error = [] tui = ["dep:ratatui"] diff --git a/simple-macros/Cargo.toml b/simple-macros/Cargo.toml new file mode 100644 index 0000000..943a84d --- /dev/null +++ b/simple-macros/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "simple-macros" +version = "0.1.0" +edition = "2021" + +[lib] +name = "simple_macros" +path = "src/lib.rs" +proc-macro = true + +[dependencies] +quote = "*" +syn = { version = "2.0.55", features = ["full", "extra-traits"] } +proc-macro2 = "1.0.78" diff --git a/simple-macros/src/lib.rs b/simple-macros/src/lib.rs new file mode 100644 index 0000000..200d29c --- /dev/null +++ b/simple-macros/src/lib.rs @@ -0,0 +1,75 @@ +extern crate proc_macro; +use proc_macro::TokenStream; +use quote::quote; +use syn::{Attribute, Meta, ReturnType}; + +// TODO this shouldn't require the from type to derive clone +// TODO I want this to take an attr argument to control whether we generate FromIterator or not +#[proc_macro_attribute] +pub fn from(_attr: TokenStream, item: TokenStream) -> TokenStream { + let ast = syn::parse::(item).expect("The #[from] macro can only be applied to free-standing functions"); + + if ast.sig.inputs.len() != 1 || ast.sig.output == ReturnType::Default { + panic!("#[from] requires annotated function to have form fn (X) -> Y where X is any type and Y is a non-void type."); + } + + let attributes = ast.attrs.into_iter() + .filter(|attr| match &attr.meta { + Meta::Path(path) if path.is_ident("from") => false, + Meta::Path(path) if path.segments.len() == 1 => match path.segments.last() { + Some(segment) if segment.ident.to_string() == "from" => false, + _ => true + } + Meta::List(list) if list.path.is_ident("from") => false, + Meta::List(list) if list.path.segments.len() == 1 => match list.path.segments.last() { + Some(segment) if segment.ident.to_string() == "from" => false, + _ => true + } + _ => true + }) + .collect::>(); + + let to_type = match ast.sig.output { + ReturnType::Type(_, return_type) => return_type, + ReturnType::Default => panic!("We need a return type :(") + }; + let (from_arg_name, from_type) = match ast.sig.inputs.get(0) { + Some(syn::FnArg::Typed(base_arg)) => { + (base_arg.pat.clone(), base_arg.ty.clone()) + } + _ => panic!("Bad function argument!! Must be 1 non-receiver argument") + }; + + let function_body = ast.block.stmts; + + let generated = quote! { + #(#attributes)* + impl From<#from_type> for #to_type { + fn from(#from_arg_name: #from_type) -> Self { + #(#function_body)* + } + } + + #(#attributes)* + impl From<&#from_type> for #to_type { + fn from(#from_arg_name: &#from_type) -> Self { + #from_arg_name.clone().into() + } + } + + #(#attributes)* + impl FromIterator<#from_type> for Vec<#to_type> { + fn from_iter>(iter: T) -> Self { + iter.into_iter().collect() + } + } + + #(#attributes)* + impl<'from_iterator_lifetime> FromIterator<&'from_iterator_lifetime #from_type> for Vec<#to_type> { + fn from_iter>(iter: T) -> Self { + iter.into_iter().map(|x| <#from_type as Into<#to_type>>::into(x.clone())).collect() + } + } + }; + generated.into() +} diff --git a/src/cli.rs b/src/cli.rs index db86508..b6a09db 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -12,7 +12,7 @@ use crate::{dns_packet::{DnsQueryType, DnsRecord, DnsRecordA, DnsRecordAAAA, Dns pub fn add_record(args: RecordArgs, settings: DnsSettings) -> Result<(), Box> { let domain = args.domain.unwrap(); let query_type = args.query_type.unwrap().into(); - let preamble = DnsRecordPreamble::build(domain, query_type, args.class, args.ttl); + let preamble = DnsRecordPreamble::build(domain.clone(), query_type, args.class, args.ttl); let record = match query_type { DnsQueryType::Unknown(_) => panic!("Impossible state"), DnsQueryType::A => DnsRecord::A(DnsRecordA::new(preamble, Ipv4Addr::from_str(args.ip.unwrap().as_str()).expect("Couldn't parse ipv4 address"))), @@ -23,17 +23,18 @@ pub fn add_record(args: RecordArgs, settings: DnsSettings) -> Result<(), Box DnsRecord::DROP(DnsRecordDROP::new(preamble)), }; let database = SimpleDatabase::new(settings.database_file); - database.insert_record(record.clone(), false)?; - log_info!("Successfully added record: {:?}", record); + database.insert_record(record.clone())?; + log_debug!("Successfully added record: {:?}", record); + log_info!("Successfully added record [{:?}] {}", query_type, domain); Ok(()) } pub fn add_record_interactive(settings: DnsSettings) -> Result<(), Box> { let domain = get_input("Domain: ", None, "A domain is required.", |x| !x.is_empty()); - let query_type = DnsQueryType::from_string(get_input("Record Type: ", + let query_type = get_input("Record Type: ", None, "A record type is required [A, NS, CNAME, MX, AAAA, DROP]", - |x| ["A", "NS", "CNAME", "MX", "AAAA", "DROP"].contains(&x.to_uppercase().as_str())).as_str()); + |x| ["A", "NS", "CNAME", "MX", "AAAA", "DROP"].contains(&x.to_uppercase().as_str())).as_str().into(); let class = get_input("Class [default 1]: ", Some("1".to_string()), "A valid u16 must be supplied.", @@ -69,7 +70,7 @@ pub fn add_record_interactive(settings: DnsSettings) -> Result<(), Box DnsRecord::DROP(DnsRecordDROP::new(preamble)) }; let database = SimpleDatabase::new(settings.database_file); - database.insert_record(record.clone(), false)?; + database.insert_record(record.clone())?; log_info!("Successfully added record: {:?}", record); Ok(()) } diff --git a/src/dns_packet.rs b/src/dns_packet.rs index d115b60..1a607d3 100644 --- a/src/dns_packet.rs +++ b/src/dns_packet.rs @@ -1,8 +1,10 @@ use std::io::{Error, ErrorKind}; use std::net::Ipv4Addr; -#[cfg(feature = "tui")] -use ratatui::widgets::Row; +use chrono::{Local, DateTime, Duration}; +use simple_macros::from; + +use crate::utils::{domain_name_to_bytes, get_name_from_packet, get_u16, get_u32, u16_to_bytes, u32_to_bytes}; #[derive(Clone, Debug)] pub struct DnsPacket { @@ -41,13 +43,13 @@ impl DnsPacket { result.append(&mut q.to_bytes()); } for a in &self.answer_section { - result.append(&mut a.to_bytes()); + result.append(&mut a.into()); } for a in &self.authority_section { - result.append(&mut a.to_bytes()); + result.append(&mut a.into()); } for a in &self.additional_section { - result.append(&mut a.to_bytes()); + result.append(&mut a.into()); } result } @@ -183,9 +185,9 @@ pub enum DnsOpCode { DNSSO = 6, } -impl DnsOpCode { - pub fn to_num(&self) -> u8 { - match self { +impl From for u8 { + fn from(value: DnsOpCode) -> Self { + match value { DnsOpCode::IQUERY => 1, DnsOpCode::STATUS => 2, DnsOpCode::NOTIFY => 4, @@ -194,9 +196,11 @@ impl DnsOpCode { DnsOpCode::QUERY => 0, } } +} - pub fn from_num(num: u8) -> DnsOpCode { - match num { +impl From for DnsOpCode { + fn from(value: u8) -> Self { + match value { 1 => DnsOpCode::IQUERY, 2 => DnsOpCode::STATUS, 4 => DnsOpCode::NOTIFY, @@ -223,9 +227,9 @@ pub enum DnsResponseCode { DSOTYPENI = 11, } -impl DnsResponseCode { - pub fn to_num(&self) -> u8 { - match self { +impl From for u8 { + fn from(value: DnsResponseCode) -> Self { + match value { DnsResponseCode::NOERROR => 0, DnsResponseCode::FORMERR => 1, DnsResponseCode::SERVFAIL => 2, @@ -240,9 +244,11 @@ impl DnsResponseCode { DnsResponseCode::DSOTYPENI => 11, } } +} - pub fn from_num(num: u8) -> DnsResponseCode { - match num { +impl From for DnsResponseCode { + fn from(value: u8) -> Self { + match value { 0 => DnsResponseCode::NOERROR, 1 => DnsResponseCode::FORMERR, 2 => DnsResponseCode::SERVFAIL, @@ -308,11 +314,11 @@ impl DnsHeader { (self.recurse_desired as u8) | ((self.truncated_message as u8) << 1) | ((self.auth_answer as u8) << 2) - | ((self.op_code.to_num()) << 3) + | ((self.op_code as u8) << 3) | ((self.query_response as u8) << 7), ); result.push( - (self.response_code.to_num()) + (self.response_code as u8) | ((self.checking_disabled as u8) << 4) | ((self.authed_data as u8) << 5) | ((self.z as u8) << 6) @@ -331,7 +337,7 @@ impl DnsHeader { Ok(Self { id: get_u16(bytes, 0)?, query_response: ((bytes[2] >> 7) & 1) != 0, - op_code: DnsOpCode::from_num((bytes[2] >> 3) & 15), + op_code: DnsOpCode::from((bytes[2] >> 3) & 15), auth_answer: ((bytes[2] >> 2) & 1) != 0, truncated_message: ((bytes[2] >> 1) & 1) != 0, recurse_desired: (bytes[2] & 1) != 0, @@ -339,7 +345,7 @@ impl DnsHeader { checking_disabled: ((bytes[3] >> 6) & 1) != 0, authed_data: ((bytes[3] >> 7) & 5) != 0, z: ((bytes[3] >> 4) & 1) != 0, - response_code: DnsResponseCode::from_num(bytes[3] & 15), + response_code: DnsResponseCode::from(bytes[3] & 15), question_count: get_u16(bytes, 4)?, answer_count: get_u16(bytes, 6)?, authority_count: get_u16(bytes, 8)?, @@ -424,72 +430,32 @@ impl DnsRecord { DnsRecord::DROP(x) => x.preamble.clone(), } } +} - pub fn to_bytes(&self) -> Vec { - match self { - DnsRecord::Unknown(x) => x.to_bytes(), - DnsRecord::A(x) => x.to_bytes(), - DnsRecord::NS(x) => x.to_bytes(), - DnsRecord::CNAME(x) => x.to_bytes(), - DnsRecord::MX(x) => x.to_bytes(), - DnsRecord::AAAA(x) => x.to_bytes(), - DnsRecord::DROP(_) => Vec::new(), - } +#[from] +fn dns_record_to_vec_u8(value: DnsRecord) -> Vec { + match value { + DnsRecord::Unknown(x) => x.into(), + DnsRecord::A(x) => x.into(), + DnsRecord::NS(x) => x.into(), + DnsRecord::CNAME(x) => x.into(), + DnsRecord::MX(x) => x.into(), + DnsRecord::AAAA(x) => x.into(), + DnsRecord::DROP(_) => Vec::new(), } +} - #[cfg(feature = "tui")] - pub fn to_row(&self) -> Row<'_> { - match self { - DnsRecord::Unknown(dns_record_unknown) => todo!(), - DnsRecord::A(dns_record_a) => Row::new(vec![ - dns_record_a.preamble.query_type.into(), - dns_record_a.preamble.domain.to_string(), - dns_record_a.ip.to_string(), - "".to_owned(), - dns_record_a.preamble.ttl.to_string(), - dns_record_a.preamble.class.to_string(), - ]), - DnsRecord::NS(dns_record_ns) => Row::new(vec![ - dns_record_ns.preamble.query_type.into(), - dns_record_ns.preamble.domain.to_string(), - dns_record_ns.host.to_string(), - "".to_owned(), - dns_record_ns.preamble.ttl.to_string(), - dns_record_ns.preamble.class.to_string(), - ]), - DnsRecord::CNAME(dns_record_cname) => Row::new(vec![ - dns_record_cname.preamble.query_type.into(), - dns_record_cname.preamble.domain.to_string(), - dns_record_cname.host.to_string(), - "".to_owned(), - dns_record_cname.preamble.ttl.to_string(), - dns_record_cname.preamble.class.to_string(), - ]), - DnsRecord::MX(dns_record_mx) => Row::new(vec![ - dns_record_mx.preamble.query_type.into(), - dns_record_mx.preamble.domain.to_string(), - dns_record_mx.host.to_string(), - dns_record_mx.priority.to_string(), - dns_record_mx.preamble.ttl.to_string(), - dns_record_mx.preamble.class.to_string(), - ]), - DnsRecord::AAAA(dns_record_aaaa) => Row::new(vec![ - dns_record_aaaa.preamble.query_type.into(), - dns_record_aaaa.preamble.domain.to_string(), - dns_record_aaaa.ip.to_string(), - "".to_owned(), - dns_record_aaaa.preamble.ttl.to_string(), - dns_record_aaaa.preamble.class.to_string(), - ]), - DnsRecord::DROP(dns_record_drop) => Row::new(vec![ - dns_record_drop.preamble.query_type.into(), - dns_record_drop.preamble.domain.to_string(), - "".to_owned(), - "".to_owned(), - dns_record_drop.preamble.ttl.to_string(), - dns_record_drop.preamble.class.to_string(), - ]) - } +#[from] +#[cfg(feature = "tui")] +fn dns_record_to_ratatui_row(value: DnsRecord) -> ratatui::widgets::Row<'_> { + match value { + DnsRecord::Unknown(_) => todo!(), + DnsRecord::A(dns_record_a) => dns_record_a.into(), + DnsRecord::NS(dns_record_ns) => dns_record_ns.into(), + DnsRecord::CNAME(dns_record_cname) => dns_record_cname.into(), + DnsRecord::MX(dns_record_mx) => dns_record_mx.into(), + DnsRecord::AAAA(dns_record_aaaa) => dns_record_aaaa.into(), + DnsRecord::DROP(dns_record_drop) => dns_record_drop.into() } } @@ -528,20 +494,9 @@ impl DnsQueryType { x => DnsQueryType::Unknown(x), } } - - pub fn from_string(value: &str) -> DnsQueryType { - match value.to_uppercase().as_str() { - "A" => DnsQueryType::A, - "NS" => DnsQueryType::NS, - "CNAME" => DnsQueryType::CNAME, - "MX" => DnsQueryType::MX, - "AAAA" => DnsQueryType::AAAA, - "DROP" => DnsQueryType::DROP, - _ => DnsQueryType::Unknown(0), - } - } } +// FIXME from macro doesn't work :( impl From for DnsQueryType { fn from(value: String) -> Self { match value.to_uppercase().as_str() { @@ -556,6 +511,7 @@ impl From for DnsQueryType { } } +// FIXME from macro doesn't work :( impl From<&str> for DnsQueryType { fn from(value: &str) -> Self { match value.to_uppercase().as_str() { @@ -613,18 +569,17 @@ impl DnsRecordPreamble { len: 0 } } +} - pub fn to_bytes(&self) -> Vec { - let mut result = Vec::new(); - - result.append(&mut domain_name_to_bytes(self.domain.as_str())); - result.append(&mut u16_to_bytes(self.query_type.to_num())); - result.append(&mut u16_to_bytes(self.class)); - result.append(&mut u32_to_bytes(self.ttl)); - result.append(&mut u16_to_bytes(self.len)); - - result - } +#[from] +fn dns_record_preamble_to_vec_u8(value: DnsRecordPreamble) -> Vec { + let mut result = Vec::new(); + result.append(&mut domain_name_to_bytes(value.domain.as_str())); + result.append(&mut u16_to_bytes(value.query_type.to_num())); + result.append(&mut u16_to_bytes(value.class)); + result.append(&mut u32_to_bytes(value.ttl)); + result.append(&mut u16_to_bytes(value.len)); + result } #[derive(Clone, Debug)] @@ -638,16 +593,20 @@ impl DnsRecordUnknown { preamble.len = body.len() as u16; Self { preamble, body } } +} - pub fn to_bytes(&self) -> Vec { - let mut result = Vec::new(); - - result.append(&mut self.preamble.to_bytes()); - let mut body_bytes = self.body.clone(); - result.append(&mut body_bytes); +#[from] +fn dns_record_unknown_to_vec_u8(value: DnsRecordUnknown) -> Vec { + let mut result: Vec = value.preamble.into(); + let mut body_bytes = value.body; + result.append(&mut body_bytes); + result +} - result - } +#[from] +#[cfg(feature = "tui")] +fn dns_record_unknown_to_ratatui_row(_dns_record_unknown: DnsRecordUnknown) -> ratatui::widgets::Row<'_> { + todo!() } #[derive(Clone, Debug)] @@ -661,6 +620,19 @@ impl DnsRecordDROP { } } +#[from] +#[cfg(feature = "tui")] +fn dns_record_drop_to_ratatui_row(dns_record_drop: DnsRecordDROP) -> ratatui::widgets::Row<'_> { + ratatui::widgets::Row::new(vec![ + dns_record_drop.preamble.query_type.into(), + dns_record_drop.preamble.domain.to_string(), + "".to_owned(), + dns_record_drop.preamble.ttl.to_string(), + "".to_owned(), + dns_record_drop.preamble.class.to_string(), + ]) +} + #[derive(Clone, Debug)] pub struct DnsRecordA { pub preamble: DnsRecordPreamble, @@ -672,18 +644,29 @@ impl DnsRecordA { preamble.len = 4; Self { preamble, ip } } +} - pub fn to_bytes(&self) -> Vec { - let mut result = Vec::new(); - - result.append(&mut self.preamble.to_bytes()); - result.push(self.ip.octets()[0]); - result.push(self.ip.octets()[1]); - result.push(self.ip.octets()[2]); - result.push(self.ip.octets()[3]); +#[from] +fn dns_record_a_to_vec_u8(value: DnsRecordA) -> Vec { + let mut result: Vec = value.preamble.into(); + result.push(value.ip.octets()[0]); + result.push(value.ip.octets()[1]); + result.push(value.ip.octets()[2]); + result.push(value.ip.octets()[3]); + result +} - result - } +#[from] +#[cfg(feature = "tui")] +fn dns_record_a_to_ratatui_row(dns_record_a: DnsRecordA) -> ratatui::widgets::Row<'_> { + ratatui::widgets::Row::new(vec![ + dns_record_a.preamble.query_type.into(), + dns_record_a.preamble.domain.to_string(), + dns_record_a.ip.to_string(), + dns_record_a.preamble.ttl.to_string(), + "".to_owned(), + dns_record_a.preamble.class.to_string(), + ]) } #[derive(Clone, Debug)] @@ -698,15 +681,27 @@ impl DnsRecordNS { preamble.len = len as u16; Self { preamble, host } } +} - pub fn to_bytes(&self) -> Vec { - let mut result = Vec::new(); - - result.append(&mut self.preamble.to_bytes()); - result.append(&mut domain_name_to_bytes(self.host.as_str())); +#[from] +fn dns_record_ns_to_vec_u8(dns_record_ns: DnsRecordNS) -> Vec { + let mut result: Vec = dns_record_ns.preamble.into(); + let mut domain_bytes = domain_name_to_bytes(dns_record_ns.host.as_str()); + result.append(&mut domain_bytes); + result +} - result - } +#[from] +#[cfg(feature = "tui")] +fn dns_record_ns_to_ratatui_row(dns_record_ns: DnsRecordNS) -> ratatui::widgets::Row<'_> { + ratatui::widgets::Row::new(vec![ + dns_record_ns.preamble.query_type.into(), + dns_record_ns.preamble.domain.to_string(), + dns_record_ns.host.to_string(), + dns_record_ns.preamble.ttl.to_string(), + "".to_owned(), + dns_record_ns.preamble.class.to_string(), + ]) } #[derive(Clone, Debug)] @@ -721,15 +716,27 @@ impl DnsRecordCNAME { preamble.len = len; Self { preamble, host } } +} - pub fn to_bytes(&self) -> Vec { - let mut result = Vec::new(); - - result.append(&mut self.preamble.to_bytes()); - result.append(&mut domain_name_to_bytes(self.host.as_str())); +#[from] +fn dns_record_cname_to_vec_u8(dns_record_cname: DnsRecordCNAME) -> Vec { + let mut result: Vec = dns_record_cname.preamble.into(); + let mut domain_bytes = domain_name_to_bytes(dns_record_cname.host.as_str()); + result.append(&mut domain_bytes); + result +} - result - } +#[from] +#[cfg(feature = "tui")] +fn dns_record_cname_to_ratatui_row(dns_record_cname: DnsRecordCNAME) -> ratatui::widgets::Row<'_> { + ratatui::widgets::Row::new(vec![ + dns_record_cname.preamble.query_type.into(), + dns_record_cname.preamble.domain.to_string(), + dns_record_cname.host.to_string(), + dns_record_cname.preamble.ttl.to_string(), + "".to_owned(), + dns_record_cname.preamble.class.to_string(), + ]) } #[derive(Clone, Debug)] @@ -749,16 +756,28 @@ impl DnsRecordMX { host, } } +} - pub fn to_bytes(&self) -> Vec { - let mut result = Vec::new(); - - result.append(&mut self.preamble.to_bytes()); - result.append(&mut u16_to_bytes(self.priority)); - result.append(&mut domain_name_to_bytes(self.host.as_str())); +#[from] +fn dns_record_mx_to_vec_u8(dns_record_mx: DnsRecordMX) -> Vec { + let mut result: Vec = dns_record_mx.preamble.into(); + result.append(&mut u16_to_bytes(dns_record_mx.priority)); + let mut domain_bytes = domain_name_to_bytes(dns_record_mx.host.as_str()); + result.append(&mut domain_bytes); + result +} - result - } +#[from] +#[cfg(feature = "tui")] +fn dns_record_mx_to_ratatui_row(dns_record_mx: DnsRecordMX) -> ratatui::widgets::Row<'_> { + ratatui::widgets::Row::new(vec![ + dns_record_mx.preamble.query_type.into(), + dns_record_mx.preamble.domain.to_string(), + dns_record_mx.host.to_string(), + dns_record_mx.preamble.ttl.to_string(), + dns_record_mx.priority.to_string(), + dns_record_mx.preamble.class.to_string(), + ]) } #[derive(Clone, Debug)] @@ -772,127 +791,68 @@ impl DnsRecordAAAA { preamble.len = 4; Self { preamble, ip } } - - pub fn to_bytes(&self) -> Vec { - let mut result = Vec::new(); - - result.append(&mut self.preamble.to_bytes()); - result.push(self.ip.octets()[0]); - result.push(self.ip.octets()[1]); - result.push(self.ip.octets()[2]); - result.push(self.ip.octets()[3]); - - result - } } -pub fn domain_name_to_bytes(value: &str) -> Vec { - let splits = value.split('.'); - let mut result = Vec::new(); - for s in splits { - let length = s.len(); - result.push((length & 0xFF) as u8); - for b in s.as_bytes() { - result.push(b.clone()); - } - } - result.push(0x00); +#[from] +fn dns_record_aaaa_to_vec_u8(dns_record_aaaa: DnsRecordAAAA) -> Vec { + let mut result: Vec = dns_record_aaaa.preamble.into(); + result.push(dns_record_aaaa.ip.octets()[0]); + result.push(dns_record_aaaa.ip.octets()[1]); + result.push(dns_record_aaaa.ip.octets()[2]); + result.push(dns_record_aaaa.ip.octets()[3]); result } -pub fn get_name_from_packet( - bytes: &[u8], - start: usize, - depth: i32, -) -> Result<(String, usize), Error> { - if depth == 20 { - return Err(Error::new(ErrorKind::InvalidData, "Loop limit exceeded")); - } - - let mut result = "".to_string(); - let mut index = start; - let mut delim = ""; - loop { - let length_byte = bytes[index]; - if (length_byte & 0xC0) == 0xC0 { - let offset_byte = bytes[index + 1] as u16; - index += 2; - - let jump_index = (((length_byte as u16) ^ 0xC0) << 8) | offset_byte; - let (part, _) = get_name_from_packet(bytes, jump_index as usize, depth + 1)?; - result.push_str(part.as_str()); - break; - } else { - index += 1; - if length_byte == 0 { - break; - } - - result.push_str(delim); - delim = "."; - let end = index + (length_byte as usize); - result.push_str( - String::from_utf8(bytes[index..end].to_vec()) - .unwrap() - .to_lowercase() - .as_str(), - ); - index = end; - } - } - Ok((result, index)) -} - -pub fn u16_to_bytes(num: u16) -> Vec { - vec![((num >> 8) & 0xFF) as u8, (num & 0xFF) as u8] -} - -pub fn u32_to_bytes(num: u32) -> Vec { - vec![ - ((num >> 24) & 0xFF) as u8, - ((num >> 16) & 0xFF) as u8, - ((num >> 8) & 0xFF) as u8, - (num & 0xFF) as u8, - ] -} - -pub fn get_u16(bytes: &[u8], index: usize) -> Result { - if index <= bytes.len() - 2 { - Ok((bytes[index] as u16) << 8 | (bytes[index + 1] as u16)) - } else { - Err(Error::new( - ErrorKind::InvalidData, - "Not enough bytes to get a u16", - )) - } +#[from] +#[cfg(feature = "tui")] +fn from(dns_record_aaaa: DnsRecordAAAA) -> ratatui::widgets::Row<'_> { + ratatui::widgets::Row::new(vec![ + dns_record_aaaa.preamble.query_type.into(), + dns_record_aaaa.preamble.domain.to_string(), + dns_record_aaaa.ip.to_string(), + dns_record_aaaa.preamble.ttl.to_string(), + "".to_owned(), + dns_record_aaaa.preamble.class.to_string(), + ]) } -pub fn get_u32(bytes: &[u8], index: usize) -> Result { - if index <= bytes.len() - 4 { - Ok( - (bytes[index] as u32) << 24 - | (bytes[index + 1] as u32) << 16 - | (bytes[index + 2] as u32) << 8 - | (bytes[index + 3] as u32), - ) - } else { - Err(Error::new( - ErrorKind::InvalidData, - "Not enough bytes to get a u32", - )) - } +#[derive(Clone)] +pub struct CachedDnsRecord { + pub cached_time: DateTime, + pub record: DnsRecord } -pub fn print_hex(bytes: String) { - for i in bytes.as_bytes() { - print!("{:02X} ", i); +impl CachedDnsRecord { + pub fn new(record: DnsRecord, cached_time: DateTime) -> Self { + Self { + cached_time, + record + } } - println!(); } -pub fn print_hex_bytes(bytes: &Vec) { - for i in bytes { - print!("{:02X} ", i); - } - println!(); +#[from] +#[cfg(feature = "tui")] +fn cached_dns_record_to_ratatui_row(cached_dns_record: CachedDnsRecord) -> ratatui::widgets::Row<'_> { + let preamble = cached_dns_record.record.clone().get_preamble(); + let since_insert = Local::now() - cached_dns_record.cached_time; + let expires_in = Duration::seconds(preamble.ttl.into()) - since_insert; + ratatui::widgets::Row::new(vec![ + preamble.query_type.into(), + preamble.domain.to_string(), + match &cached_dns_record.record { + DnsRecord::A(dns_record_a) => dns_record_a.ip.to_string(), + DnsRecord::NS(dns_record_ns) => dns_record_ns.host.to_string(), + DnsRecord::CNAME(dns_record_cname) => dns_record_cname.host.to_string(), + DnsRecord::MX(dns_record_mx) => dns_record_mx.host.to_string(), + DnsRecord::AAAA(dns_record_aaaa) => dns_record_aaaa.ip.to_string(), + _ => String::new() + }, + match &cached_dns_record.record { + DnsRecord::MX(dns_record_mx) => dns_record_mx.priority.to_string(), + _ => String::new() + }, + format!("{} sec", expires_in.num_seconds()), + preamble.class.to_string(), + ]) } diff --git a/src/dns_resolver.rs b/src/dns_resolver.rs index d285aad..53c6af7 100644 --- a/src/dns_resolver.rs +++ b/src/dns_resolver.rs @@ -62,19 +62,23 @@ impl DnsResolver { } fn do_remote_lookup(&self, question: &DnsQuestion, packet: &mut DnsPacket) -> Result<(), Box> { + log_debug!("Doing remote lookup {:?} {:?}", question, packet); let server = (self.database.get_random_remote_lookup_server().unwrap(), 53); - let socket = UdpSocket::bind(("0.0.0.0:0"))?; + let socket = UdpSocket::bind("0.0.0.0:0")?; let mut remote_packet = DnsPacket::new(); remote_packet.header.recurse_desired = true; remote_packet.add_question(DnsQuestion::new(question.name.clone(), question.query_type)); - let remote_packet_bytes = packet.to_bytes(); - - socket.send_to(&remote_packet_bytes, server)?; + let remote_packet_bytes = remote_packet.to_bytes(); let mut res: [u8; 512] = [0; 512]; - socket.recv_from(&mut res)?; + + log_debug!("Sending {:?} to {:?}", packet, server); + let sent = socket.send_to(&remote_packet_bytes, server)?; + log_debug!("Sent {} bytes", sent); + let (received, source_addr) = socket.recv_from(&mut res)?; + log_info!("Received {} bytes from {:?}", received, source_addr); match DnsPacket::from_bytes(&res) { Ok(result) => { @@ -85,21 +89,21 @@ impl DnsResolver { for ans in result.answer_section { log_debug!("Answer: {:?}", ans); packet.answer_section.push(ans.clone()); - ignore_result_and_log_error!(self.database.insert_record(ans, true)); + ignore_result_and_log_error!(self.database.insert_cache_record(ans)); packet.header.answer_count += 1; } for auth in result.authority_section { log_debug!("Authority: {:?}", auth); packet.authority_section.push(auth.clone()); - ignore_result_and_log_error!(self.database.insert_record(auth, true)); + ignore_result_and_log_error!(self.database.insert_cache_record(auth)); packet.header.authority_count += 1; } for add in result.additional_section { log_debug!("Resource: {:?}", add); packet.additional_section.push(add.clone()); - ignore_result_and_log_error!(self.database.insert_record(add, true)); + ignore_result_and_log_error!(self.database.insert_cache_record(add)); packet.header.additional_count += 1; } } @@ -108,6 +112,7 @@ impl DnsResolver { packet.header.response_code = DnsResponseCode::SERVFAIL; } } + log_debug!("Exiting do_remote_lookup"); Ok(()) } @@ -134,15 +139,8 @@ impl DnsResolver { fn any_record_type(records: &Vec, record_type: DnsQueryType) -> bool { for r in records { - match r { - DnsRecord::Unknown(x) if x.preamble.query_type == record_type => return true, - DnsRecord::A(x) if x.preamble.query_type == record_type => return true, - DnsRecord::NS(x) if x.preamble.query_type == record_type => return true, - DnsRecord::CNAME(x) if x.preamble.query_type == record_type => return true, - DnsRecord::MX(x) if x.preamble.query_type == record_type => return true, - DnsRecord::AAAA(x) if x.preamble.query_type == record_type => return true, - DnsRecord::DROP(x) if x.preamble.query_type == record_type => return true, - _ => {} + if r.get_preamble().query_type == record_type { + return true } } false diff --git a/src/dns_server.rs b/src/dns_server.rs index e02846f..cf51ed3 100644 --- a/src/dns_server.rs +++ b/src/dns_server.rs @@ -7,7 +7,8 @@ use std::sync::mpsc::{channel, Sender}; use std::thread::Builder; use rand::random; -use crate::dns_packet::*; +use crate::utils::{get_u16, u16_to_bytes}; +use crate::{dns_packet::*, log_debug}; use crate::dns_resolver::DnsResolver; use crate::settings::DnsSettings; use crate::{ignore_result_and_log_error, ignore_result_or_log_error_continue, log_error, log_warn, return_result_or_log_error_continue}; @@ -34,7 +35,9 @@ impl DnsUdpServer { impl DnsServer for DnsUdpServer { fn run(self) -> Result<(), Error> { - let socket = UdpSocket::bind(("0.0.0.0", self.settings.listening_port))?; + let bind_addr = ("0.0.0.0", self.settings.listening_port); + log_debug!("UDP server listening at {:?}:{}", bind_addr.0, bind_addr.1); + let socket = UdpSocket::bind(bind_addr)?; for thread_num in 0..self.settings.thread_count { let request_queue = self.request_queue.clone(); @@ -133,7 +136,9 @@ impl DnsTcpServer { impl DnsServer for DnsTcpServer { fn run(mut self) -> Result<(), Error> { - let socket = TcpListener::bind(("0.0.0.0", self.settings.listening_port))?; + let bind_addr = ("0.0.0.0", self.settings.listening_port); + log_debug!("TCP server listening at {:?}:{}", bind_addr.0, bind_addr.1); + let socket = TcpListener::bind(bind_addr)?; for thread_id in 0..self.settings.thread_count { let (sender, receiver) = channel(); @@ -144,26 +149,46 @@ impl DnsServer for DnsTcpServer { let _ = Builder::new() .name(format!("DnsTcpServer-request-handler-{}", thread_id)) - .spawn(move || loop { - let mut stream = return_result_or_log_error_continue!(receiver.recv(), "Failed to receive the tcp stream"); - - - let mut packet_length_buffer = [0; 2]; - ignore_result_or_log_error_continue!(stream.read(&mut packet_length_buffer), "Failed to read the packet length from the stream"); + .spawn(move || { + let stream_receiver = receiver; + loop { + let mut stream = return_result_or_log_error_continue!(stream_receiver.recv(), "Failed to receive the tcp stream"); + log_debug!("TCP stream received!!!!!"); + + let mut packet_length_buffer = [0; 2]; + ignore_result_or_log_error_continue!(stream.read(&mut packet_length_buffer), "Failed to read the packet length from the stream"); + let packet_length: usize = match get_u16(&packet_length_buffer, 0) { + Ok(x) => x as usize, + Err(err) => { + log_error!("Failed packet length conversion: {}", err); + 0 // TODO this feels hacky but the return_result_or_log_error_continue macro wasn't working for some reason also this is very likely to not happen + } + }; - let mut packet_buffer = Vec::new(); - ignore_result_or_log_error_continue!(stream.read_to_end(&mut packet_buffer), "Failed to read the packet into a buffer"); + log_debug!("Read packet length: {:?}", packet_length); + let mut packet_buffer = vec![0; packet_length]; + log_debug!("Made packet buffer of size {}", packet_buffer.len()); + ignore_result_or_log_error_continue!(stream.read(&mut packet_buffer), "Failed to read the packet into a buffer"); - let request = return_result_or_log_error_continue!(DnsPacket::from_bytes(packet_buffer.as_slice()), "Failed to parse packet from buffer"); - let resolver = DnsResolver::new(settings.database_file.clone()); + log_debug!("Done reading to end of the stream"); + let request = return_result_or_log_error_continue!(DnsPacket::from_bytes(&packet_buffer), "Failed to parse packet from buffer"); + let resolver = DnsResolver::new(settings.database_file.clone()); - match resolver.answer_question(request) { - Ok(result) => { - ignore_result_or_log_error_continue!(stream.write(result.to_bytes().as_slice()), "Failed writing result back to buffer"); - ignore_result_or_log_error_continue!(stream.shutdown(Shutdown::Both), "Failed shutting down tcp connection"); - } - Err(error) => { - log_error!("Resolver error {}", error) + match resolver.answer_question(request) { + Ok(result) => { + log_debug!("Sending response packet: {:#?}", result); + let response_bytes = result.to_bytes(); + let response_length = response_bytes.len() as u16; // TODO this is a sketchy cast + ignore_result_or_log_error_continue!(stream.write(u16_to_bytes(response_length).as_slice()), "Failed writing result back to buffer"); + ignore_result_or_log_error_continue!(stream.write(response_bytes.as_slice()), "Failed writing result back to buffer"); + log_debug!("Flushing Stream Buffer..."); + ignore_result_or_log_error_continue!(stream.flush(), "Failed flushing tcp buffer"); + log_debug!("Shutting down stream..."); + ignore_result_or_log_error_continue!(stream.shutdown(Shutdown::Both), "Failed shutting down tcp connection"); + } + Err(error) => { + log_error!("Resolver error {:#?}", error) + } } } })?; @@ -172,7 +197,11 @@ impl DnsServer for DnsTcpServer { let _ = Builder::new().name("DnsTcpServer-incoming-requests".to_string()) .spawn(move || for incoming in socket.incoming() { match incoming { - Ok(stream) => ignore_result_and_log_error!(self.request_handlers[random::() % self.settings.thread_count as usize].send(stream)), + Ok(stream) => { + let idx = random::() % self.settings.thread_count as usize; + log_debug!("Picked handler {}/{}", idx, self.settings.thread_count); + ignore_result_and_log_error!(self.request_handlers[idx].send(stream)) + } Err(error) => log_error!("Failed to accept incoming TCP connection: {}", error), } })?; diff --git a/src/main.rs b/src/main.rs index b8d897f..e5a151a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ pub mod dns_server; mod macros; mod settings; mod simple_database; +mod utils; #[cfg(feature = "tui")] mod tui; @@ -128,7 +129,7 @@ fn main() -> Result<(), Box> { let database = SimpleDatabase::new(settings.database_file); match database.initialize() { - Ok(_) => log_debug!("Successfully initialized the database :)"), + Ok(_) => log_info!("Successfully initialized the database :)"), Err(error) => log_error!("There was an error while initializing the database :( | {}", error), } } @@ -173,7 +174,7 @@ fn main() -> Result<(), Box> { tui_start(&settings)?; } #[cfg(not(feature = "tui"))] - Commands::Tui { config } => { + Commands::Tui { .. } => { log_error!("simpledns was not built with the TUI feature :( please rebuild with `cargo build --features \"tui\"`...") } Commands::Add { config, interactive, .. } if interactive => { diff --git a/src/settings.rs b/src/settings.rs index 8b50f77..e835cb3 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -2,9 +2,6 @@ use std::error::Error; use std::fs; use std::io::ErrorKind; use yaml_rust::YamlLoader; -use std::path::Path; - -use crate::log_info; use crate::log_debug; diff --git a/src/simple_database.rs b/src/simple_database.rs index b3f23ab..86038b7 100644 --- a/src/simple_database.rs +++ b/src/simple_database.rs @@ -1,8 +1,8 @@ use crate::dns_packet::{ - DnsQueryType, DnsRecord, DnsRecordA, DnsRecordAAAA, DnsRecordCNAME, DnsRecordDROP, DnsRecordMX, - DnsRecordNS, DnsRecordPreamble, DnsRecordUnknown, + CachedDnsRecord, DnsQueryType, DnsRecord, DnsRecordA, DnsRecordAAAA, DnsRecordCNAME, DnsRecordDROP, DnsRecordMX, DnsRecordNS, DnsRecordPreamble, DnsRecordUnknown }; -use rusqlite::{params, Connection, Params, Result, Statement}; +use chrono::{Local, TimeZone}; +use rusqlite::{params, Connection, Params, Result, Statement, Row}; use std::net::Ipv4Addr; use std::str; use std::str::FromStr; @@ -22,44 +22,65 @@ impl SimpleDatabase { self.connection.execute("CREATE TABLE IF NOT EXISTS remote_lookup_servers(ip TEXT PRIMARY KEY)", [])?; self.connection.execute("INSERT INTO remote_lookup_servers VALUES (\"8.8.8.8\")", [])?; self.connection.execute("INSERT INTO remote_lookup_servers VALUES (\"75.75.75.75\")", [])?; - self.connection.execute("CREATE TABLE IF NOT EXISTS records(domain TEXT, query_type INTEGER, class INTEGER, ttl INTEGER, len INTEGER, hostipbody TEXT, priority INTEGER, cached INTEGER, insert_time INTEGER)", [])?; - self.connection.execute("CREATE UNIQUE INDEX IF NOT EXISTS record_unique_idx ON records(domain, query_type, hostipbody, priority, cached)", [])?; + self.connection.execute("CREATE TABLE IF NOT EXISTS cached_records(domain TEXT, query_type INTEGER, class INTEGER, ttl INTEGER, len INTEGER, hostipbody TEXT, priority INTEGER, insert_time INTEGER)", [])?; + self.connection.execute("CREATE UNIQUE INDEX IF NOT EXISTS cached_record_unique_idx ON cached_records(domain, query_type, hostipbody, priority)", [])?; + self.connection.execute("CREATE TABLE IF NOT EXISTS records(domain TEXT, query_type INTEGER, class INTEGER, ttl INTEGER, len INTEGER, hostipbody TEXT, priority INTEGER)", [])?; + self.connection.execute("CREATE UNIQUE INDEX IF NOT EXISTS record_unique_idx ON records(domain, query_type, hostipbody, priority)", [])?; Ok(()) } + fn row_to_dns_record(&self, row: &Row<'_>) -> Result { + let mut preamble = DnsRecordPreamble::new(); + preamble.domain = row.get(0)?; + preamble.query_type = DnsQueryType::from_num(row.get(1)?); + preamble.class = row.get(2)?; + preamble.ttl = row.get(3)?; + preamble.len = row.get(4)?; + Ok(match preamble.query_type { + DnsQueryType::Unknown(_) => DnsRecord::Unknown(DnsRecordUnknown::new( + preamble, + row.get::(5)?.into_bytes(), + )), + DnsQueryType::A => DnsRecord::A(DnsRecordA::new( + preamble, + Ipv4Addr::from_str(row.get::(5)?.as_str()).unwrap(), + )), + DnsQueryType::NS => DnsRecord::NS(DnsRecordNS::new(preamble, row.get::(5)?)), + DnsQueryType::CNAME => { + DnsRecord::CNAME(DnsRecordCNAME::new(preamble, row.get::(5)?)) + } + DnsQueryType::MX => DnsRecord::MX(DnsRecordMX::new( + preamble, + row.get::(5)?, + row.get::(5)?, + )), + DnsQueryType::AAAA => DnsRecord::AAAA(DnsRecordAAAA::new( + preamble, + Ipv4Addr::from_str(row.get::(5)?.as_str()).unwrap(), + )), + DnsQueryType::DROP => DnsRecord::DROP(DnsRecordDROP::new(preamble)), + }) + } + + fn row_to_cached_dns_record(&self, row: &Row<'_>) -> Result { + let record = self.row_to_dns_record(row)?; + let insert_timestamp = row.get(7)?; + let insert_time = Local.timestamp_opt(insert_timestamp, 0).unwrap(); + Ok(CachedDnsRecord::new(record, insert_time)) + } + fn run_dns_record_query(&self, mut statement: Statement<'_>, params: P) -> Result> { - let query_results = statement.query_map(params, |row| { - let mut preamble = DnsRecordPreamble::new(); - preamble.domain = row.get(0)?; - preamble.query_type = DnsQueryType::from_num(row.get(1)?); - preamble.class = row.get(2)?; - preamble.ttl = row.get(3)?; - preamble.len = row.get(4)?; - Ok(match preamble.query_type { - DnsQueryType::Unknown(_) => DnsRecord::Unknown(DnsRecordUnknown::new( - preamble, - row.get::(5)?.into_bytes(), - )), - DnsQueryType::A => DnsRecord::A(DnsRecordA::new( - preamble, - Ipv4Addr::from_str(row.get::(5)?.as_str()).unwrap(), - )), - DnsQueryType::NS => DnsRecord::NS(DnsRecordNS::new(preamble, row.get::(5)?)), - DnsQueryType::CNAME => { - DnsRecord::CNAME(DnsRecordCNAME::new(preamble, row.get::(5)?)) - } - DnsQueryType::MX => DnsRecord::MX(DnsRecordMX::new( - preamble, - row.get::(5)?, - row.get::(5)?, - )), - DnsQueryType::AAAA => DnsRecord::AAAA(DnsRecordAAAA::new( - preamble, - Ipv4Addr::from_str(row.get::(5)?.as_str()).unwrap(), - )), - DnsQueryType::DROP => DnsRecord::DROP(DnsRecordDROP::new(preamble)), - }) - })?; + let query_results = statement.query_map(params, |row| self.row_to_dns_record(row))?; + + let mut results = Vec::new(); + for record in query_results { + results.push(record?); + } + Ok(results) + } + + fn run_cached_dns_record_query(&self, mut statement: Statement<'_>, params: P) -> Result> { + let query_results = statement.query_map(params, |row| self.row_to_cached_dns_record(row))?; let mut results = Vec::new(); for record in query_results { @@ -69,7 +90,7 @@ impl SimpleDatabase { } fn clean_up_cache(&self) -> Result<()> { - self.connection.execute("DELETE FROM records WHERE records.cached AND records.ttl < unixepoch() - records.insert_time;", [])?; + self.connection.execute("DELETE FROM cached_records WHERE cached_records.ttl < unixepoch() - cached_records.insert_time;", [])?; Ok(()) } @@ -88,10 +109,55 @@ impl SimpleDatabase { pub fn get_records(&self, domain: String) -> Result> { self.clean_up_cache()?; let stmt = self.connection.prepare("SELECT domain, query_type, class, ttl, len, hostipbody, priority FROM records WHERE domain = ?1;")?; - self.run_dns_record_query(stmt, params![domain]) + let mut records = self.run_dns_record_query(stmt, params![domain])?; + let stmt = self.connection.prepare("SELECT domain, query_type, class, ttl, len, hostipbody, priority FROM cached_records WHERE domain = ?1;")?; + let mut cached_records = self.run_dns_record_query(stmt, params![domain])?; + records.append(&mut cached_records); + Ok(records) + } + + pub fn get_all_cached_records(&self) -> Result> { + self.clean_up_cache()?; + let stmt = self.connection.prepare("SELECT domain, query_type, class, ttl, len, hostipbody, priority, insert_time FROM cached_records;")?; + self.run_cached_dns_record_query(stmt, params![]) + } + + pub fn insert_record(&self, record: DnsRecord) -> Result<()> { + let preamble = record.get_preamble(); + let domain = preamble.domain; + let query_type = preamble.query_type.to_num().to_string(); + let class = preamble.class.to_string(); + let ttl = preamble.ttl.to_string(); + let len = preamble.len.to_string(); + let priority = match &record { + DnsRecord::Unknown(_) => 0, + DnsRecord::A(_) => 0, + DnsRecord::NS(_) => 0, + DnsRecord::CNAME(_) => 0, + DnsRecord::MX(mx) => mx.priority, + DnsRecord::AAAA(_) => 0, + DnsRecord::DROP(_) => 0, + } + .to_string(); + + let hostipbody = match &record { + DnsRecord::Unknown(record) => str::from_utf8(&*record.body).unwrap().to_string(), + DnsRecord::A(record) => record.ip.to_string(), + DnsRecord::NS(record) => record.host.clone(), + DnsRecord::CNAME(record) => record.host.clone(), + DnsRecord::MX(record) => record.host.clone(), + DnsRecord::AAAA(record) => record.ip.to_string(), + DnsRecord::DROP(_) => "".to_string(), + }; + + self.connection.execute( + "INSERT OR REPLACE INTO records VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7);", + (&domain, &query_type, &class, &ttl, &len, &hostipbody, &priority), + )?; + Ok(()) } - pub fn insert_record(&self, record: DnsRecord, cached_record: bool) -> Result<()> { + pub fn insert_cache_record(&self, record: DnsRecord) -> Result<()> { let preamble = record.get_preamble(); let domain = preamble.domain; let query_type = preamble.query_type.to_num().to_string(); @@ -120,8 +186,8 @@ impl SimpleDatabase { }; self.connection.execute( - "INSERT OR REPLACE INTO records VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, unixepoch());", - (&domain, &query_type, &class, &ttl, &len, &hostipbody, &priority, &cached_record), + "INSERT OR REPLACE INTO cached_records VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, unixepoch());", + (&domain, &query_type, &class, &ttl, &len, &hostipbody, &priority), )?; Ok(()) } diff --git a/src/tui/base.rs b/src/tui/base.rs index 8797a5d..02a8bb3 100644 --- a/src/tui/base.rs +++ b/src/tui/base.rs @@ -1,9 +1,7 @@ -use std::borrow::Borrow; use std::io::Result; -use std::thread::current; use ratatui::buffer::Buffer; -use ratatui::crossterm::event::{self, Event, KeyCode, KeyEvent, KeyEventKind, MouseEvent}; +use ratatui::crossterm::event::{self, KeyCode, KeyEventKind}; use ratatui::layout::{Constraint, Direction, Layout, Rect}; use ratatui::prelude::Stylize; use ratatui::style::Style; @@ -14,8 +12,8 @@ use ratatui::{DefaultTerminal, Frame}; use crate::settings::DnsSettings; use crate::log_debug; -use crate::simple_database::SimpleDatabase; +use super::cache_list_view::CacheListView; use super::event::{SimpleEvent, SimpleEventResult}; use super::record_list_view::RecordListView; use super::view::View; @@ -50,7 +48,6 @@ impl AppState { } struct App { - //simple_connection: SimpleDatabase, views: Vec>, exit: bool } @@ -58,8 +55,10 @@ struct App { impl App { pub fn new(settings: &DnsSettings) -> Self { Self { - //simple_connection: SimpleDatabase::new(settings.database_file.clone()), - views: vec![RecordListView::new_boxed(settings)], + views: vec![ + RecordListView::new_boxed(settings), + CacheListView::new_boxed(settings) + ], exit: false } } @@ -76,8 +75,8 @@ impl App { frame.render_stateful_widget(self, frame.area(), state); } - pub fn handle_events(&mut self, state: &AppState) -> Result<()> { - let mut current_view = &mut self.views[state.current_view()]; + pub fn handle_events(&mut self, state: &mut AppState) -> Result<()> { + let current_view = &mut self.views[state.current_view()]; match event::poll(current_view.poll_rate()) { Ok(true) => { let simple_event: SimpleEvent = event::read()?.into(); @@ -87,12 +86,19 @@ impl App { SimpleEvent::Key(key) if key.kind == KeyEventKind::Press && key.code == KeyCode::Esc => { self.exit = true; } + SimpleEvent::Key(key) if key.kind == KeyEventKind::Press => { + for (idx, view) in self.views.iter().enumerate() { + if key.code == view.open_view_control() { + state.selected_view = ListState::default().with_selected(Some(idx)) + } + } + } _ => {} } } } Ok(false) => { current_view.handle_event(SimpleEvent::Tick); } - Err(error) => {} // WHAT TO DO??? + Err(_) => {} // WHAT TO DO??? } Ok(()) } diff --git a/src/tui/cache_list_view.rs b/src/tui/cache_list_view.rs new file mode 100644 index 0000000..906dc12 --- /dev/null +++ b/src/tui/cache_list_view.rs @@ -0,0 +1,89 @@ +use std::time::Duration; + +use ratatui::{buffer::Buffer, crossterm::event::KeyCode, layout::{Constraint, Rect}, text::{Line, Text}, widgets::{Block, Paragraph, Row, Table, Widget}}; +use ratatui::prelude::Stylize; +use ratatui::prelude::Style; + +use crate::{settings::DnsSettings, simple_database::SimpleDatabase}; + +use super::{event::{SimpleEvent, SimpleEventResult}, view::View}; + +pub struct CacheListView { + simple_database: SimpleDatabase +} + +impl CacheListView { + pub fn new(settings: &DnsSettings) -> Self { + Self { + simple_database: SimpleDatabase::new(settings.database_file.clone()) + } + } + + pub fn new_boxed(settings: &DnsSettings) -> Box { + Box::new(Self::new(settings)) + } +} + +impl View for CacheListView { + fn draw(&self, block: Block, area: Rect, buf: &mut Buffer) { + match self.simple_database.get_all_cached_records() { + Ok(records) => { + Table::default() + .rows(records.iter().collect::>>()) + .header(Row::new(vec!["Query Type", "Domain", "Host/IP", "Expires In", "Priority", "Class"]).underlined().cyan()) + .widths([ + Constraint::Length(12), + Constraint::Fill(1), + Constraint::Fill(1), + Constraint::Length(12), + Constraint::Length(10), + Constraint::Length(7) + ]) + .row_highlight_style(Style::new().underlined()) + .highlight_symbol("->") + .block(block) + .render(area, buf); + } + Err(err) => { + let text = vec![ + "ERROR GETTING LIST OF CACHED RECORDS FROM DB".into(), + err.to_string().into() + ]; + Paragraph::new(text) + .centered() + .red() + .bold() + .italic() + .block(block) + .render(area, buf); + } + } + } + + fn handle_event(&mut self, _: SimpleEvent) -> SimpleEventResult { + SimpleEventResult::Bubble + } + + fn open_view_control(&self) -> KeyCode { + KeyCode::Char('c') + } + + fn name(&self) -> Line { + Line::from(vec![ + " ".into(), + "C".red().bold(), + "ached Records".blue(), + " ".into() + ]) + } + + fn help(&self) -> Text { + Text::from(vec![ + "[ESC] - Exit SimpleDNS".into() + ]) + } + + fn poll_rate(&self) -> Duration { + Duration::from_secs(1) + } +} diff --git a/src/tui/event.rs b/src/tui/event.rs index 159bafb..b97fce3 100644 --- a/src/tui/event.rs +++ b/src/tui/event.rs @@ -1,4 +1,6 @@ use ratatui::crossterm::event::{Event, KeyEvent, MouseEvent}; +// TODO remove this as we start using more of these +#[allow(unused)] #[derive(Clone)] pub enum SimpleEvent { @@ -24,6 +26,7 @@ impl From for SimpleEvent { } pub enum SimpleEventResult { + #[allow(dead_code)] Consume, Bubble -} \ No newline at end of file +} diff --git a/src/tui/mod.rs b/src/tui/mod.rs index f966102..f61fa2a 100644 --- a/src/tui/mod.rs +++ b/src/tui/mod.rs @@ -3,4 +3,5 @@ extern crate ratatui; pub mod base; mod view; mod event; -mod record_list_view; \ No newline at end of file +mod record_list_view; +mod cache_list_view; diff --git a/src/tui/record_list_view.rs b/src/tui/record_list_view.rs index c47dc68..5a4f9b1 100644 --- a/src/tui/record_list_view.rs +++ b/src/tui/record_list_view.rs @@ -1,10 +1,10 @@ use std::time::Duration; -use ratatui::{buffer::Buffer, layout::Rect, text::{Line, Text}, widgets::{Block, Chart, Dataset, Paragraph, Row, Table, Widget}}; +use ratatui::{buffer::Buffer, crossterm::event::KeyCode, layout::{Constraint, Rect}, text::{Line, Text}, widgets::{Block, Paragraph, Row, Table, Widget}}; use ratatui::prelude::Stylize; use ratatui::prelude::Style; -use crate::{dns_packet::DnsRecord, settings::DnsSettings, simple_database::SimpleDatabase}; +use crate::{settings::DnsSettings, simple_database::SimpleDatabase}; use super::{event::{SimpleEvent, SimpleEventResult}, view::View}; @@ -29,14 +29,22 @@ impl View for RecordListView { match self.simple_database.get_all_records() { Ok(records) => { Table::default() - .rows(records.iter().map(|x| x.to_row()).collect::>>()) - .header(Row::new(vec!["Query Type", "Domain", "Host/IP", "Priority", "TTL", "Class"]).underlined().cyan()) + .rows(records.iter().collect::>>()) // TODO There has to be a better way + .header(Row::new(vec!["Query Type", "Domain", "Host/IP", "TTL", "Priority", "Class"]).underlined().cyan()) + .widths([ + Constraint::Length(12), + Constraint::Fill(1), + Constraint::Fill(1), + Constraint::Length(12), + Constraint::Length(10), + Constraint::Length(7) + ]) .row_highlight_style(Style::new().underlined()) .highlight_symbol("->") .block(block) .render(area, buf); } - Err(error) => { + Err(_) => { Paragraph::new("ERROR GETTING LIST OF RECORDS FROM DB") .centered() .red() @@ -49,10 +57,14 @@ impl View for RecordListView { } - fn handle_event(&mut self, event: SimpleEvent) -> SimpleEventResult { + fn handle_event(&mut self, _: SimpleEvent) -> SimpleEventResult { SimpleEventResult::Bubble } + fn open_view_control(&self) -> KeyCode { + KeyCode::Char('r') + } + fn name(&self) -> Line { Line::from(vec![ " ".into(), diff --git a/src/tui/view.rs b/src/tui/view.rs index e2810ec..82acb72 100644 --- a/src/tui/view.rs +++ b/src/tui/view.rs @@ -1,12 +1,13 @@ use std::time::Duration; -use ratatui::{buffer::Buffer, layout::Rect, text::{Line, Text}, widgets::{Block, Widget}}; +use ratatui::{buffer::Buffer, crossterm::event::KeyCode, layout::Rect, text::{Line, Text}, widgets::Block}; use super::event::{SimpleEvent, SimpleEventResult}; pub trait View { fn draw(&self, block: Block, area: Rect, buf: &mut Buffer); fn handle_event(&mut self, event: SimpleEvent) -> SimpleEventResult; + fn open_view_control(&self) -> KeyCode; fn name(&self) -> Line; fn help(&self) -> Text; fn poll_rate(&self) -> Duration; diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..fadb647 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,98 @@ +use std::io::{Error, ErrorKind}; + +pub fn domain_name_to_bytes(value: &str) -> Vec { + let splits = value.split('.'); + let mut result = Vec::new(); + for s in splits { + let length = s.len(); + result.push((length & 0xFF) as u8); + for b in s.as_bytes() { + result.push(b.clone()); + } + } + result.push(0x00); + result +} + +pub fn get_name_from_packet( + bytes: &[u8], + start: usize, + depth: i32, +) -> Result<(String, usize), Error> { + if depth == 20 { + return Err(Error::new(ErrorKind::InvalidData, "Loop limit exceeded")); + } + + let mut result = "".to_string(); + let mut index = start; + let mut delim = ""; + loop { + let length_byte = bytes[index]; + if (length_byte & 0xC0) == 0xC0 { + let offset_byte = bytes[index + 1] as u16; + index += 2; + + let jump_index = (((length_byte as u16) ^ 0xC0) << 8) | offset_byte; + let (part, _) = get_name_from_packet(bytes, jump_index as usize, depth + 1)?; + result.push_str(part.as_str()); + break; + } else { + index += 1; + if length_byte == 0 { + break; + } + + result.push_str(delim); + delim = "."; + let end = index + (length_byte as usize); + result.push_str( + String::from_utf8(bytes[index..end].to_vec()) + .unwrap() + .to_lowercase() + .as_str(), + ); + index = end; + } + } + Ok((result, index)) +} + +pub fn u16_to_bytes(num: u16) -> Vec { + vec![((num >> 8) & 0xFF) as u8, (num & 0xFF) as u8] +} + +pub fn u32_to_bytes(num: u32) -> Vec { + vec![ + ((num >> 24) & 0xFF) as u8, + ((num >> 16) & 0xFF) as u8, + ((num >> 8) & 0xFF) as u8, + (num & 0xFF) as u8, + ] +} + +pub fn get_u16(bytes: &[u8], index: usize) -> Result { + if index <= bytes.len() - 2 { + Ok((bytes[index] as u16) << 8 | (bytes[index + 1] as u16)) + } else { + Err(Error::new( + ErrorKind::InvalidData, + "Not enough bytes to get a u16", + )) + } +} + +pub fn get_u32(bytes: &[u8], index: usize) -> Result { + if index <= bytes.len() - 4 { + Ok( + (bytes[index] as u32) << 24 + | (bytes[index + 1] as u32) << 16 + | (bytes[index + 2] as u32) << 8 + | (bytes[index + 3] as u32), + ) + } else { + Err(Error::new( + ErrorKind::InvalidData, + "Not enough bytes to get a u32", + )) + } +}