Skip to content

Commit

Permalink
Finished the database / TUI setup for viewing cached records
Browse files Browse the repository at this point in the history
  • Loading branch information
jmeaster30 committed Dec 12, 2024
1 parent 3acf555 commit 5e7612d
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 57 deletions.
10 changes: 3 additions & 7 deletions simple-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
extern crate proc_macro;
use proc_macro::TokenStream;
use std::any::Any;
use std::cmp::min;
use std::str::FromStr;
use quote::{quote, ToTokens, TokenStreamExt};
use syn::{parse_macro_input, Attribute, Data, Expr, Fields, Lit, Meta, ReturnType, Token};
use syn::parse::Parse;
use quote::quote;
use syn::{Attribute, Meta, ReturnType};

// TODO this shouldn't require the from type to derive clone
// TODO I want this to take an attr argument to control whether we generate FromIterator or not
#[proc_macro_attribute]
pub fn from(attr: TokenStream, item: TokenStream) -> TokenStream {
pub fn from(_attr: TokenStream, item: TokenStream) -> TokenStream {
let ast = syn::parse::<syn::ItemFn>(item).expect("The #[from] macro can only be applied to free-standing functions");

if ast.sig.inputs.len() != 1 || ast.sig.output == ReturnType::Default {
Expand Down
46 changes: 32 additions & 14 deletions src/dns_packet.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::io::{Error, ErrorKind};
use std::net::Ipv4Addr;

use rusqlite::Row;

use chrono::{Local, DateTime, Duration};
use simple_macros::from;

use crate::utils::{domain_name_to_bytes, get_name_from_packet, get_u16, get_u32, u16_to_bytes, u32_to_bytes};
Expand Down Expand Up @@ -450,7 +449,7 @@ fn dns_record_to_vec_u8(value: DnsRecord) -> Vec<u8> {
#[cfg(feature = "tui")]
fn dns_record_to_ratatui_row(value: DnsRecord) -> ratatui::widgets::Row<'_> {
match value {
DnsRecord::Unknown(dns_record_unknown) => todo!(),
DnsRecord::Unknown(_) => todo!(),
DnsRecord::A(dns_record_a) => dns_record_a.into(),
DnsRecord::NS(dns_record_ns) => dns_record_ns.into(),
DnsRecord::CNAME(dns_record_cname) => dns_record_cname.into(),
Expand Down Expand Up @@ -606,7 +605,7 @@ fn dns_record_unknown_to_vec_u8(value: DnsRecordUnknown) -> Vec<u8> {

#[from]
#[cfg(feature = "tui")]
fn dns_record_unknown_to_ratatui_row(value: DnsRecordUnknown) -> ratatui::widgets::Row<'_> {
fn dns_record_unknown_to_ratatui_row(_dns_record_unknown: DnsRecordUnknown) -> ratatui::widgets::Row<'_> {
todo!()
}

Expand All @@ -628,8 +627,8 @@ fn dns_record_drop_to_ratatui_row(dns_record_drop: DnsRecordDROP) -> ratatui::wi
dns_record_drop.preamble.query_type.into(),
dns_record_drop.preamble.domain.to_string(),
"".to_owned(),
"".to_owned(),
dns_record_drop.preamble.ttl.to_string(),
"".to_owned(),
dns_record_drop.preamble.class.to_string(),
])
}
Expand Down Expand Up @@ -664,8 +663,8 @@ fn dns_record_a_to_ratatui_row(dns_record_a: DnsRecordA) -> ratatui::widgets::Ro
dns_record_a.preamble.query_type.into(),
dns_record_a.preamble.domain.to_string(),
dns_record_a.ip.to_string(),
"".to_owned(),
dns_record_a.preamble.ttl.to_string(),
"".to_owned(),
dns_record_a.preamble.class.to_string(),
])
}
Expand Down Expand Up @@ -699,8 +698,8 @@ fn dns_record_ns_to_ratatui_row(dns_record_ns: DnsRecordNS) -> ratatui::widgets:
dns_record_ns.preamble.query_type.into(),
dns_record_ns.preamble.domain.to_string(),
dns_record_ns.host.to_string(),
"".to_owned(),
dns_record_ns.preamble.ttl.to_string(),
"".to_owned(),
dns_record_ns.preamble.class.to_string(),
])
}
Expand Down Expand Up @@ -734,8 +733,8 @@ fn dns_record_cname_to_ratatui_row(dns_record_cname: DnsRecordCNAME) -> ratatui:
dns_record_cname.preamble.query_type.into(),
dns_record_cname.preamble.domain.to_string(),
dns_record_cname.host.to_string(),
"".to_owned(),
dns_record_cname.preamble.ttl.to_string(),
"".to_owned(),
dns_record_cname.preamble.class.to_string(),
])
}
Expand Down Expand Up @@ -775,8 +774,8 @@ fn dns_record_mx_to_ratatui_row(dns_record_mx: DnsRecordMX) -> ratatui::widgets:
dns_record_mx.preamble.query_type.into(),
dns_record_mx.preamble.domain.to_string(),
dns_record_mx.host.to_string(),
dns_record_mx.priority.to_string(),
dns_record_mx.preamble.ttl.to_string(),
dns_record_mx.priority.to_string(),
dns_record_mx.preamble.class.to_string(),
])
}
Expand Down Expand Up @@ -811,20 +810,20 @@ fn from(dns_record_aaaa: DnsRecordAAAA) -> ratatui::widgets::Row<'_> {
dns_record_aaaa.preamble.query_type.into(),
dns_record_aaaa.preamble.domain.to_string(),
dns_record_aaaa.ip.to_string(),
"".to_owned(),
dns_record_aaaa.preamble.ttl.to_string(),
"".to_owned(),
dns_record_aaaa.preamble.class.to_string(),
])
}

#[derive(Clone)]
pub struct CachedDnsRecord {
pub cached_time: u32,
pub cached_time: DateTime<Local>,
pub record: DnsRecord
}

impl CachedDnsRecord {
pub fn new(record: DnsRecord, cached_time: u32) -> Self {
pub fn new(record: DnsRecord, cached_time: DateTime<Local>) -> Self {
Self {
cached_time,
record
Expand All @@ -835,6 +834,25 @@ impl CachedDnsRecord {
#[from]
#[cfg(feature = "tui")]
fn cached_dns_record_to_ratatui_row(cached_dns_record: CachedDnsRecord) -> ratatui::widgets::Row<'_> {
// TODO implement this properly to calculate the "Expires In" field
cached_dns_record.record.into()
let preamble = cached_dns_record.record.clone().get_preamble();
let since_insert = Local::now() - cached_dns_record.cached_time;
let expires_in = Duration::seconds(preamble.ttl.into()) - since_insert;
ratatui::widgets::Row::new(vec![
preamble.query_type.into(),
preamble.domain.to_string(),
match &cached_dns_record.record {
DnsRecord::A(dns_record_a) => dns_record_a.ip.to_string(),
DnsRecord::NS(dns_record_ns) => dns_record_ns.host.to_string(),
DnsRecord::CNAME(dns_record_cname) => dns_record_cname.host.to_string(),
DnsRecord::MX(dns_record_mx) => dns_record_mx.host.to_string(),
DnsRecord::AAAA(dns_record_aaaa) => dns_record_aaaa.ip.to_string(),
_ => String::new()
},
match &cached_dns_record.record {
DnsRecord::MX(dns_record_mx) => dns_record_mx.priority.to_string(),
_ => String::new()
},
format!("{} sec", expires_in.num_seconds()),
preamble.class.to_string(),
])
}
23 changes: 21 additions & 2 deletions src/simple_database.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::dns_packet::{
CachedDnsRecord, DnsQueryType, DnsRecord, DnsRecordA, DnsRecordAAAA, DnsRecordCNAME, DnsRecordDROP, DnsRecordMX, DnsRecordNS, DnsRecordPreamble, DnsRecordUnknown
};
use chrono::{Local, TimeZone};
use rusqlite::{params, Connection, Params, Result, Statement, Row};
use std::net::Ipv4Addr;
use std::str;
Expand Down Expand Up @@ -61,6 +62,13 @@ impl SimpleDatabase {
})
}

fn row_to_cached_dns_record(&self, row: &Row<'_>) -> Result<CachedDnsRecord> {
let record = self.row_to_dns_record(row)?;
let insert_timestamp = row.get(7)?;
let insert_time = Local.timestamp_opt(insert_timestamp, 0).unwrap();
Ok(CachedDnsRecord::new(record, insert_time))
}

fn run_dns_record_query<P: Params>(&self, mut statement: Statement<'_>, params: P) -> Result<Vec<DnsRecord>> {
let query_results = statement.query_map(params, |row| self.row_to_dns_record(row))?;

Expand All @@ -71,6 +79,16 @@ impl SimpleDatabase {
Ok(results)
}

fn run_cached_dns_record_query<P: Params>(&self, mut statement: Statement<'_>, params: P) -> Result<Vec<CachedDnsRecord>> {
let query_results = statement.query_map(params, |row| self.row_to_cached_dns_record(row))?;

let mut results = Vec::new();
for record in query_results {
results.push(record?);
}
Ok(results)
}

fn clean_up_cache(&self) -> Result<()> {
self.connection.execute("DELETE FROM cached_records WHERE cached_records.ttl < unixepoch() - cached_records.insert_time;", [])?;
Ok(())
Expand Down Expand Up @@ -99,8 +117,9 @@ impl SimpleDatabase {
}

pub fn get_all_cached_records(&self) -> Result<Vec<CachedDnsRecord>> {
//todo!();
Ok(Vec::new())
self.clean_up_cache()?;
let stmt = self.connection.prepare("SELECT domain, query_type, class, ttl, len, hostipbody, priority, insert_time FROM cached_records;")?;
self.run_cached_dns_record_query(stmt, params![])
}

pub fn insert_record(&self, record: DnsRecord) -> Result<()> {
Expand Down
9 changes: 3 additions & 6 deletions src/tui/base.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::borrow::Borrow;
use std::io::Result;
use std::thread::current;

use ratatui::buffer::Buffer;
use ratatui::crossterm::event::{self, Event, KeyCode, KeyEvent, KeyEventKind, MouseEvent};
use ratatui::crossterm::event::{self, KeyCode, KeyEventKind};
use ratatui::layout::{Constraint, Direction, Layout, Rect};
use ratatui::prelude::Stylize;
use ratatui::style::Style;
Expand All @@ -14,7 +12,6 @@ use ratatui::{DefaultTerminal, Frame};

use crate::settings::DnsSettings;
use crate::log_debug;
use crate::simple_database::SimpleDatabase;

use super::cache_list_view::CacheListView;
use super::event::{SimpleEvent, SimpleEventResult};
Expand Down Expand Up @@ -79,7 +76,7 @@ impl App {
}

pub fn handle_events(&mut self, state: &mut AppState) -> Result<()> {
let mut current_view = &mut self.views[state.current_view()];
let current_view = &mut self.views[state.current_view()];
match event::poll(current_view.poll_rate()) {
Ok(true) => {
let simple_event: SimpleEvent = event::read()?.into();
Expand All @@ -101,7 +98,7 @@ impl App {
}
}
Ok(false) => { current_view.handle_event(SimpleEvent::Tick); }
Err(error) => {} // WHAT TO DO???
Err(_) => {} // WHAT TO DO???
}
Ok(())
}
Expand Down
25 changes: 18 additions & 7 deletions src/tui/cache_list_view.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::time::Duration;

use ratatui::{buffer::Buffer, crossterm::event::KeyCode, layout::Rect, text::{Line, Text}, widgets::{Block, Chart, Dataset, Paragraph, Row, Table, Widget}};
use ratatui::{buffer::Buffer, crossterm::event::KeyCode, layout::{Constraint, Rect}, text::{Line, Text}, widgets::{Block, Paragraph, Row, Table, Widget}};
use ratatui::prelude::Stylize;
use ratatui::prelude::Style;

use crate::{dns_packet::DnsRecord, settings::DnsSettings, simple_database::SimpleDatabase};
use crate::{settings::DnsSettings, simple_database::SimpleDatabase};

use super::{event::{SimpleEvent, SimpleEventResult}, view::View};

Expand All @@ -30,14 +30,26 @@ impl View for CacheListView {
Ok(records) => {
Table::default()
.rows(records.iter().collect::<Vec<Row<'_>>>())
.header(Row::new(vec!["Query Type", "Domain", "Host/IP", "Priority", "Expires In", "Class"]).underlined().cyan())
.header(Row::new(vec!["Query Type", "Domain", "Host/IP", "Expires In", "Priority", "Class"]).underlined().cyan())
.widths([
Constraint::Length(12),
Constraint::Fill(1),
Constraint::Fill(1),
Constraint::Length(12),
Constraint::Length(10),
Constraint::Length(7)
])
.row_highlight_style(Style::new().underlined())
.highlight_symbol("->")
.block(block)
.render(area, buf);
}
Err(error) => {
Paragraph::new("ERROR GETTING LIST OF CACHED RECORDS FROM DB")
Err(err) => {
let text = vec![
"ERROR GETTING LIST OF CACHED RECORDS FROM DB".into(),
err.to_string().into()
];
Paragraph::new(text)
.centered()
.red()
.bold()
Expand All @@ -46,10 +58,9 @@ impl View for CacheListView {
.render(area, buf);
}
}

}

fn handle_event(&mut self, event: SimpleEvent) -> SimpleEventResult {
fn handle_event(&mut self, _: SimpleEvent) -> SimpleEventResult {
SimpleEventResult::Bubble
}

Expand Down
3 changes: 3 additions & 0 deletions src/tui/event.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use ratatui::crossterm::event::{Event, KeyEvent, MouseEvent};
// TODO remove this as we start using more of these
#[allow(unused)]

#[derive(Clone)]
pub enum SimpleEvent {
Expand All @@ -24,6 +26,7 @@ impl From<Event> for SimpleEvent {
}

pub enum SimpleEventResult {
#[allow(dead_code)]
Consume,
Bubble
}
18 changes: 13 additions & 5 deletions src/tui/record_list_view.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::time::Duration;

use ratatui::{buffer::Buffer, crossterm::event::KeyCode, layout::Rect, text::{Line, Text}, widgets::{Block, Chart, Dataset, Paragraph, Row, Table, Widget}};
use ratatui::{buffer::Buffer, crossterm::event::KeyCode, layout::{Constraint, Rect}, text::{Line, Text}, widgets::{Block, Paragraph, Row, Table, Widget}};
use ratatui::prelude::Stylize;
use ratatui::prelude::Style;

use crate::{dns_packet::DnsRecord, settings::DnsSettings, simple_database::SimpleDatabase};
use crate::{settings::DnsSettings, simple_database::SimpleDatabase};

use super::{event::{SimpleEvent, SimpleEventResult}, view::View};

Expand All @@ -30,13 +30,21 @@ impl View for RecordListView {
Ok(records) => {
Table::default()
.rows(records.iter().collect::<Vec<Row<'_>>>()) // TODO There has to be a better way
.header(Row::new(vec!["Query Type", "Domain", "Host/IP", "Priority", "TTL", "Class"]).underlined().cyan())
.header(Row::new(vec!["Query Type", "Domain", "Host/IP", "TTL", "Priority", "Class"]).underlined().cyan())
.widths([
Constraint::Length(12),
Constraint::Fill(1),
Constraint::Fill(1),
Constraint::Length(12),
Constraint::Length(10),
Constraint::Length(7)
])
.row_highlight_style(Style::new().underlined())
.highlight_symbol("->")
.block(block)
.render(area, buf);
}
Err(error) => {
Err(_) => {
Paragraph::new("ERROR GETTING LIST OF RECORDS FROM DB")
.centered()
.red()
Expand All @@ -49,7 +57,7 @@ impl View for RecordListView {

}

fn handle_event(&mut self, event: SimpleEvent) -> SimpleEventResult {
fn handle_event(&mut self, _: SimpleEvent) -> SimpleEventResult {
SimpleEventResult::Bubble
}

Expand Down
2 changes: 1 addition & 1 deletion src/tui/view.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::time::Duration;

use ratatui::{buffer::Buffer, crossterm::event::KeyCode, layout::Rect, text::{Line, Text}, widgets::{Block, Widget}};
use ratatui::{buffer::Buffer, crossterm::event::KeyCode, layout::Rect, text::{Line, Text}, widgets::Block};

use super::event::{SimpleEvent, SimpleEventResult};

Expand Down
15 changes: 0 additions & 15 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::io::{Error, ErrorKind};
use std::net::Ipv4Addr;

pub fn domain_name_to_bytes(value: &str) -> Vec<u8> {
let splits = value.split('.');
Expand Down Expand Up @@ -97,17 +96,3 @@ pub fn get_u32(bytes: &[u8], index: usize) -> Result<u32, Error> {
))
}
}

pub fn print_hex(bytes: String) {
for i in bytes.as_bytes() {
print!("{:02X} ", i);
}
println!();
}

pub fn print_hex_bytes(bytes: &Vec<u8>) {
for i in bytes {
print!("{:02X} ", i);
}
println!();
}

0 comments on commit 5e7612d

Please sign in to comment.