diff --git a/src/app/app_execution.rs b/src/app/app_execution.rs index a94951b..b021024 100644 --- a/src/app/app_execution.rs +++ b/src/app/app_execution.rs @@ -18,25 +18,69 @@ //! [`AppExecution`]: Handles executing queries for the TUI application. use crate::app::state::tabs::sql::Query; -use crate::app::AppEvent; +use crate::app::{AppEvent, ExecutionError, ExecutionResultsBatch}; use crate::execution::ExecutionContext; +use arrow_flight::decode::FlightRecordBatchStream; use color_eyre::eyre::Result; +use datafusion::arrow::array::RecordBatch; +use datafusion::execution::context::SessionContext; +use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream}; +use datafusion::physical_plan::{ + execute_stream, visit_execution_plan, ExecutionPlan, ExecutionPlanVisitor, +}; use futures::StreamExt; use log::{error, info}; +use std::fmt::Debug; +use std::pin::Pin; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::Mutex; +use tonic::IntoRequest; + +#[cfg(feature = "flightsql")] +use {arrow_flight::sql::client::FlightSqlServiceClient, tonic::transport::Channel}; /// Handles executing queries for the TUI application, formatting results /// and sending them to the UI. pub(crate) struct AppExecution { inner: Arc, + // TODO: Store the SQL with the stream + result_stream: Arc>>, + // TODO: Store the SQL with the stream + flight_result_stream: Arc>>, + flight_results_current_row_start: Option, } impl AppExecution { /// Create a new instance of [`AppExecution`]. pub fn new(inner: Arc) -> Self { - Self { inner } + Self { + inner, + result_stream: Arc::new(Mutex::new(None)), + flight_result_stream: Arc::new(Mutex::new(None)), + flight_results_current_row_start: None, + } + } + + pub fn session_ctx(&self) -> &SessionContext { + self.inner.session_ctx() + } + + #[cfg(feature = "flightsql")] + pub fn flightsql_client(&self) -> &Mutex>> { + self.inner.flightsql_client() + } + + pub async fn set_result_stream(&self, stream: SendableRecordBatchStream) { + let mut s = self.result_stream.lock().await; + *s = Some(stream) + } + + #[cfg(feature = "flightsql")] + pub async fn set_flight_result_stream(&self, stream: FlightRecordBatchStream) { + let mut s = self.flight_result_stream.lock().await; + *s = Some(stream) } /// Run the sequence of SQL queries, sending the results as [`AppEvent::QueryResult`] via the sender. @@ -60,33 +104,53 @@ impl AppExecution { let start = std::time::Instant::now(); if i == statement_count - 1 { info!("Executing last query and display results"); - match self.inner.execute_sql(sql).await { - Ok(mut stream) => { - let mut batches = Vec::new(); - while let Some(maybe_batch) = stream.next().await { - match maybe_batch { - Ok(batch) => { - batches.push(batch); - } - Err(e) => { - let elapsed = start.elapsed(); - query.set_error(Some(e.to_string())); - query.set_execution_time(elapsed); - break; + sender.send(AppEvent::NewExecution)?; + match self.inner.create_physical_plan(sql).await { + Ok(plan) => match execute_stream(plan, self.inner.session_ctx().task_ctx()) { + Ok(stream) => { + self.set_result_stream(stream).await; + let mut stream = self.result_stream.lock().await; + if let Some(s) = stream.as_mut() { + if let Some(b) = s.next().await { + match b { + Ok(b) => { + let duration = start.elapsed(); + let results = ExecutionResultsBatch { + query: sql.to_string(), + batch: b, + duration, + }; + sender.send(AppEvent::ExecutionResultsNextPage( + results, + ))?; + } + Err(e) => { + error!("Error getting RecordBatch: {:?}", e); + } + } } } } + Err(stream_err) => { + error!("Error creating physical plan: {:?}", stream_err); + let elapsed = start.elapsed(); + let e = ExecutionError { + query: sql.to_string(), + error: stream_err.to_string(), + duration: elapsed, + }; + sender.send(AppEvent::ExecutionResultsError(e))?; + } + }, + Err(plan_err) => { + error!("Error creating physical plan: {:?}", plan_err); let elapsed = start.elapsed(); - let rows: usize = batches.iter().map(|r| r.num_rows()).sum(); - query.set_results(Some(batches)); - query.set_num_rows(Some(rows)); - query.set_execution_time(elapsed); - } - Err(e) => { - error!("Error creating dataframe: {:?}", e); - let elapsed = start.elapsed(); - query.set_error(Some(e.to_string())); - query.set_execution_time(elapsed); + let e = ExecutionError { + query: sql.to_string(), + error: plan_err.to_string(), + duration: elapsed, + }; + sender.send(AppEvent::ExecutionResultsError(e))?; } } } else { @@ -107,4 +171,153 @@ impl AppExecution { } Ok(()) } + + pub async fn run_flightsqls(&self, sqls: Vec<&str>, sender: UnboundedSender) { + info!("Running sqls: {:?}", sqls); + let non_empty_sqls: Vec<&str> = sqls.into_iter().filter(|s| !s.is_empty()).collect(); + info!("Non empty SQLs: {:?}", non_empty_sqls); + let statement_count = non_empty_sqls.len(); + for (i, sql) in non_empty_sqls.into_iter().enumerate() { + let client = self.flightsql_client(); + // let mut query = + // FlightSQLQuery::new(sql.clone(), None, None, None, Duration::default(), None); + let start = Instant::now(); + if let Some(ref mut c) = *client.lock().await { + match c.execute(sql.to_string(), None).await { + Ok(flight_info) => { + for endpoint in flight_info.endpoint { + if let Some(ticket) = endpoint.ticket { + match c.do_get(ticket.into_request()).await { + Ok(stream) => { + self.set_flight_result_stream(stream).await; + let mut stream = self.flight_result_stream.lock().await; + if let Some(s) = stream.as_mut() { + if let Some(b) = s.next().await { + match b { + Ok(b) => { + let results = ExecutionResultsBatch { + query: sql.to_string(), + batch: b, + duration: start.elapsed(), + }; + let _ = sender.send(AppEvent::FlightSQLExecutionResultsNextPage(results)); + } + Err(e) => { + error!( + "Error getting RecordBatch: {:?}", + e + ); + let e = ExecutionError { + query: sql.to_string(), + error: e.to_string(), + duration: start.elapsed(), + }; + let _ = sender.send(AppEvent::FlightSQLExecutionResultsError(e)); + } + } + } + } + } + Err(e) => { + error!("Error getting RecordBatch: {:?}", e); + } + } + } else { + error!("No ticket in endpoint"); + } + } + } + Err(e) => { + error!("Error executing FlightSQL query: {:?}", e); + } + } + } + } + } + + pub async fn next_batch(&self, sql: String, sender: UnboundedSender) { + let mut stream = self.result_stream.lock().await; + if let Some(s) = stream.as_mut() { + let start = std::time::Instant::now(); + if let Some(b) = s.next().await { + match b { + Ok(b) => { + let duration = start.elapsed(); + let results = ExecutionResultsBatch { + query: sql, + batch: b, + duration, + }; + let _ = sender.send(AppEvent::ExecutionResultsNextPage(results)); + } + Err(e) => { + error!("Error getting RecordBatch: {:?}", e); + } + } + } + } + } +} + +// #[derive(Debug, Clone)] +// pub struct ExecMetrics { +// name: String, +// bytes_scanned: usize, +// } + +#[derive(Clone, Debug)] +pub struct ExecutionStats { + // bytes_scanned: usize, + // exec_metrics: Vec, +} + +// impl ExecutionStats { +// pub fn bytes_scanned(&self) -> usize { +// self.bytes_scanned +// } +// } + +#[derive(Default)] +struct PlanVisitor { + total_bytes_scanned: usize, + // exec_metrics: Vec, +} + +impl From for ExecutionStats { + fn from(value: PlanVisitor) -> Self { + Self { + // bytes_scanned: value.total_bytes_scanned, + } + } +} + +impl ExecutionPlanVisitor for PlanVisitor { + type Error = datafusion_common::DataFusionError; + + fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { + match plan.metrics() { + Some(metrics) => match metrics.sum_by_name("bytes_scanned") { + Some(bytes_scanned) => { + info!("Adding {} to total_bytes_scanned", bytes_scanned.as_usize()); + self.total_bytes_scanned += bytes_scanned.as_usize(); + } + None => { + info!("No bytes_scanned for {}", plan.name()) + } + }, + None => { + info!("No MetricsSet for {}", plan.name()) + } + } + Ok(true) + } +} + +pub fn collect_plan_stats(plan: Arc) -> Option { + let mut visitor = PlanVisitor::default(); + if visit_execution_plan(plan.as_ref(), &mut visitor).is_ok() { + Some(visitor.into()) + } else { + None + } } diff --git a/src/app/handlers/flightsql.rs b/src/app/handlers/flightsql.rs index a759e5d..a6b4733 100644 --- a/src/app/handlers/flightsql.rs +++ b/src/app/handlers/flightsql.rs @@ -25,6 +25,7 @@ use tokio_stream::StreamExt; use tonic::IntoRequest; use crate::app::state::tabs::flightsql::FlightSQLQuery; +use crate::app::state::tabs::history::Context; use crate::app::{handlers::tab_navigation_handler, AppEvent}; use super::App; @@ -65,75 +66,125 @@ pub fn normal_mode_handler(app: &mut App, key: KeyEvent) { s.select_previous(); } } - KeyCode::Enter => { info!("Run FS query"); - let sql = app.state.flightsql_tab.editor().lines().join(""); - info!("SQL: {}", sql); + let full_text = app.state.flightsql_tab.editor().lines().join(""); let execution = Arc::clone(&app.execution); let _event_tx = app.event_tx(); tokio::spawn(async move { - let client = execution.flightsql_client(); - let mut query = - FlightSQLQuery::new(sql.clone(), None, None, None, Duration::default(), None); - let start = Instant::now(); - if let Some(ref mut c) = *client.lock().await { - info!("Sending query"); - match c.execute(sql, None).await { - Ok(flight_info) => { - for endpoint in flight_info.endpoint { - if let Some(ticket) = endpoint.ticket { - match c.do_get(ticket.into_request()).await { - Ok(mut stream) => { - let mut batches: Vec = Vec::new(); - // temporarily only show the first batch to avoid - // buffering massive result sets. Eventually there should - // be some sort of paging logic - // see https://github.com/datafusion-contrib/datafusion-tui/pull/133#discussion_r1756680874 - // while let Some(maybe_batch) = stream.next().await { - if let Some(maybe_batch) = stream.next().await { - match maybe_batch { - Ok(batch) => { - info!("Batch rows: {}", batch.num_rows()); - batches.push(batch); - } - Err(e) => { - error!("Error getting batch: {:?}", e); - let elapsed = start.elapsed(); - query.set_error(Some(e.to_string())); - query.set_execution_time(elapsed); - } - } - } - let elapsed = start.elapsed(); - let rows: usize = - batches.iter().map(|r| r.num_rows()).sum(); - query.set_results(Some(batches)); - query.set_num_rows(Some(rows)); - query.set_execution_time(elapsed); - } - Err(e) => { - error!("Error getting response: {:?}", e); - let elapsed = start.elapsed(); - query.set_error(Some(e.to_string())); - query.set_execution_time(elapsed); - } - } - } - } - } - Err(e) => { - error!("Error getting response: {:?}", e); - let elapsed = start.elapsed(); - query.set_error(Some(e.to_string())); - query.set_execution_time(elapsed); - } - } - } - - let _ = _event_tx.send(AppEvent::FlightSQLQueryResult(query)); + let sqls = full_text.split(';').collect(); + execution.run_flightsqls(sqls, _event_tx).await; + // let client = execution.flightsql_client(); + // let mut query = + // FlightSQLQuery::new(sql.clone(), None, None, None, Duration::default(), None); + // let start = Instant::now(); + // if let Some(ref mut c) = *client.lock().await { + // match c.execute(sql, None).await { + // Ok(flight_info) => { + // for endpoint in flight_info.endpoint { + // if let Some(ticket) = endpoint.ticket { + // match c.do_get(ticket.into_request()).await { + // Ok(mut stream) => { + // execution.set_flight_result_stream(stream).await; + // exe + // } + // Err(e) => {} + // } + // } + // } + // } + // Err(e) => {} + // } + // } }); } + KeyCode::Right => { + if let Some(p) = app + .state + .history_tab + .history() + .iter() + .filter(|q| *q.context() == Context::FlightSQL) + .last() + { + let execution = Arc::clone(&app.execution); + let sql = p.sql().clone(); + let _event_tx = app.event_tx().clone(); + app.state.flightsql_tab.next_results_page(); + // tokio::spawn(async move { + // execution.flightsql_next_page().await; + // }); + } + } + + // KeyCode::Enter => { + // info!("Run FS query"); + // let sql = app.state.flightsql_tab.editor().lines().join(""); + // info!("SQL: {}", sql); + // let execution = Arc::clone(&app.execution); + // let _event_tx = app.event_tx(); + // tokio::spawn(async move { + // let client = execution.flightsql_client(); + // let mut query = + // FlightSQLQuery::new(sql.clone(), None, None, None, Duration::default(), None); + // let start = Instant::now(); + // if let Some(ref mut c) = *client.lock().await { + // info!("Sending query"); + // match c.execute(sql, None).await { + // Ok(flight_info) => { + // for endpoint in flight_info.endpoint { + // if let Some(ticket) = endpoint.ticket { + // match c.do_get(ticket.into_request()).await { + // Ok(mut stream) => { + // let mut batches: Vec = Vec::new(); + // // temporarily only show the first batch to avoid + // // buffering massive result sets. Eventually there should + // // be some sort of paging logic + // // see https://github.com/datafusion-contrib/datafusion-tui/pull/133#discussion_r1756680874 + // // while let Some(maybe_batch) = stream.next().await { + // if let Some(maybe_batch) = stream.next().await { + // match maybe_batch { + // Ok(batch) => { + // info!("Batch rows: {}", batch.num_rows()); + // batches.push(batch); + // } + // Err(e) => { + // error!("Error getting batch: {:?}", e); + // let elapsed = start.elapsed(); + // query.set_error(Some(e.to_string())); + // query.set_execution_time(elapsed); + // } + // } + // } + // let elapsed = start.elapsed(); + // let rows: usize = + // batches.iter().map(|r| r.num_rows()).sum(); + // query.set_results(Some(batches)); + // query.set_num_rows(Some(rows)); + // query.set_execution_time(elapsed); + // } + // Err(e) => { + // error!("Error getting response: {:?}", e); + // let elapsed = start.elapsed(); + // query.set_error(Some(e.to_string())); + // query.set_execution_time(elapsed); + // } + // } + // } + // } + // } + // Err(e) => { + // error!("Error getting response: {:?}", e); + // let elapsed = start.elapsed(); + // query.set_error(Some(e.to_string())); + // query.set_execution_time(elapsed); + // } + // } + // } + // + // let _ = _event_tx.send(AppEvent::FlightSQLQueryResult(query)); + // }); + // } _ => {} } } @@ -154,11 +205,6 @@ pub fn app_event_handler(app: &mut App, event: AppEvent) { true => editable_handler(app, key), false => normal_mode_handler(app, key), }, - AppEvent::FlightSQLQueryResult(r) => { - info!("Query results: {:?}", r); - app.state.flightsql_tab.set_query(r); - app.state.flightsql_tab.refresh_query_results_state(); - } AppEvent::Error => {} _ => {} }; diff --git a/src/app/handlers/mod.rs b/src/app/handlers/mod.rs index 62e825b..5036d6a 100644 --- a/src/app/handlers/mod.rs +++ b/src/app/handlers/mod.rs @@ -25,6 +25,7 @@ use ratatui::crossterm::event::{self, KeyCode, KeyEvent}; use tui_logger::TuiWidgetEvent; use crate::app::state::tabs::history::Context; +use crate::app::ExecutionResultsBatch; #[cfg(feature = "flightsql")] use arrow_flight::sql::client::FlightSqlServiceClient; @@ -148,8 +149,6 @@ fn context_tab_app_event_handler(app: &mut App, event: AppEvent) { } pub fn app_event_handler(app: &mut App, event: AppEvent) -> Result<()> { - // TODO: AppEvent::QueryResult can probably be handled here rather than duplicating in - // each tab trace!("Tui::Event: {:?}", event); let now = std::time::Instant::now(); match event { @@ -180,17 +179,35 @@ pub fn app_event_handler(app: &mut App, event: AppEvent) -> Result<()> { } }); } - AppEvent::QueryResult(r) => { - app.state.sql_tab.set_query(r.clone()); - app.state.sql_tab.refresh_query_results_state(); + AppEvent::NewExecution => { + app.state.sql_tab.reset_execution_results(); + } + AppEvent::ExecutionResultsError(e) => { + app.state.sql_tab.set_execution_error(e.clone()); let history_query = HistoryQuery::new( Context::Local, - r.sql().clone(), - *r.execution_time(), - r.execution_stats().clone(), + e.query().to_string(), + *e.duration(), + None, + Some(e.error().to_string()), ); + info!("Adding to history: {:?}", history_query); app.state.history_tab.add_to_history(history_query); - app.state.history_tab.refresh_history_table_state() + app.state.history_tab.refresh_history_table_state(); + } + AppEvent::ExecutionResultsNextPage(r) => { + let ExecutionResultsBatch { + query, + duration, + batch, + } = r; + app.state.sql_tab.add_batch(batch); + app.state.sql_tab.next_page(); + app.state.sql_tab.refresh_query_results_state(); + let history_query = + HistoryQuery::new(Context::Local, query.to_string(), duration, None, None); + app.state.history_tab.add_to_history(history_query); + app.state.history_tab.refresh_history_table_state(); } #[cfg(feature = "flightsql")] AppEvent::FlightSQLQueryResult(r) => { @@ -201,6 +218,7 @@ pub fn app_event_handler(app: &mut App, event: AppEvent) -> Result<()> { r.sql().clone(), *r.execution_time(), r.execution_stats().clone(), + None, ); app.state.history_tab.add_to_history(history_query); app.state.history_tab.refresh_history_table_state() diff --git a/src/app/handlers/sql.rs b/src/app/handlers/sql.rs index 7481a05..b667add 100644 --- a/src/app/handlers/sql.rs +++ b/src/app/handlers/sql.rs @@ -15,17 +15,13 @@ // specific language governing permissions and limitations // under the License. -use std::{sync::Arc, time::Instant}; +use std::sync::Arc; -use datafusion::{arrow::array::RecordBatch, physical_plan::execute_stream}; -use log::{error, info}; +use log::info; use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; -use tokio_stream::StreamExt; use super::App; -use crate::app::app_execution::AppExecution; -use crate::app::{handlers::tab_navigation_handler, state::tabs::sql::Query, AppEvent}; -use crate::execution::collect_plan_stats; +use crate::app::{handlers::tab_navigation_handler, AppEvent}; pub fn normal_mode_handler(app: &mut App, key: KeyEvent) { match key.code { @@ -62,19 +58,37 @@ pub fn normal_mode_handler(app: &mut App, key: KeyEvent) { } KeyCode::Enter => { - info!("Run query"); let sql = app.state.sql_tab.editor().lines().join(""); - info!("SQL: {}", sql); - let app_execution = AppExecution::new(Arc::clone(&app.execution)); + info!("Running query: {}", sql); let _event_tx = app.event_tx().clone(); - // TODO: Maybe this should be on a separate runtime to prevent blocking main thread / - // runtime - // TODO: Extract this into function to be used in both normal and editable handler + let execution = Arc::clone(&app.execution); + // TODO: Extract this into function to be used in both normal and editable handler. + // Only useful if we get Ctrl / Cmd + Enter to work in editable mode though. tokio::spawn(async move { let sqls: Vec<&str> = sql.split(';').collect(); - let _ = app_execution.run_sqls(sqls, _event_tx).await; + let _ = execution.run_sqls(sqls, _event_tx).await; }); } + KeyCode::Right => { + let _event_tx = app.event_tx().clone(); + // This won't work if you paginate the results, switch to FlightSQL tab, and then + // switch back to SQL tab, and paginate again. + // + // Need to decide if switching tabs should reset pagination. + if let Some(p) = app.state.history_tab.history().last() { + let execution = Arc::clone(&app.execution); + let sql = p.sql().clone(); + tokio::spawn(async move { + // TODO: Should be a call to `next_page` and `next_batch` is implementation + // detail. + execution.next_batch(sql, _event_tx).await; + }); + } + } + KeyCode::Left => { + app.state.sql_tab.previous_page(); + app.state.sql_tab.refresh_query_results_state(); + } _ => {} } } @@ -85,54 +99,6 @@ pub fn editable_handler(app: &mut App, key: KeyEvent) { (KeyCode::Right, KeyModifiers::ALT) => app.state.sql_tab.next_word(), (KeyCode::Backspace, KeyModifiers::ALT) => app.state.sql_tab.delete_word(), (KeyCode::Esc, _) => app.state.sql_tab.exit_edit(), - (KeyCode::Enter, KeyModifiers::CONTROL) => { - let query = app.state.sql_tab.editor().lines().join(""); - let ctx = app.execution.session_ctx().clone(); - let _event_tx = app.event_tx(); - // TODO: Maybe this should be on a separate runtime to prevent blocking main thread / - // runtime - tokio::spawn(async move { - // TODO: Turn this into a match and return the error somehow - let start = Instant::now(); - if let Ok(df) = ctx.sql(&query).await { - let plan = df.create_physical_plan().await; - match plan { - Ok(p) => { - let task_ctx = ctx.task_ctx(); - let stream = execute_stream(Arc::clone(&p), task_ctx); - let mut batches: Vec = Vec::new(); - match stream { - Ok(mut s) => { - while let Some(b) = s.next().await { - match b { - Ok(b) => batches.push(b), - Err(e) => { - error!("Error getting RecordBatch: {:?}", e) - } - } - } - - let elapsed = start.elapsed(); - let stats = collect_plan_stats(p); - info!("Got stats: {:?}", stats); - let query = - Query::new(query, Some(batches), None, None, elapsed, None); - let _ = _event_tx.send(AppEvent::QueryResult(query)); - } - Err(e) => { - error!("Error creating RecordBatchStream: {:?}", e) - } - } - } - Err(e) => { - error!("Error creating physical plan: {:?}", e) - } - } - } else { - error!("Error creating dataframe") - } - }); - } _ => app.state.sql_tab.update_editor_content(key), } } diff --git a/src/app/mod.rs b/src/app/mod.rs index 4030a67..a7fb7af 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -23,6 +23,7 @@ pub mod ui; use color_eyre::eyre::eyre; use color_eyre::Result; use crossterm::event as ct; +use datafusion::arrow::array::RecordBatch; use futures::FutureExt; use log::{debug, error, info, trace}; use ratatui::backend::CrosstermBackend; @@ -32,12 +33,14 @@ use ratatui::crossterm::{ }; use ratatui::{prelude::*, style::palette::tailwind, widgets::*}; use std::sync::Arc; +use std::time::Duration; use strum::IntoEnumIterator; use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; use tokio::task::JoinHandle; use tokio_stream::StreamExt; use tokio_util::sync::CancellationToken; +use self::app_execution::AppExecution; use self::handlers::{app_event_handler, crossterm_event_handler}; use self::state::tabs::sql::Query; use crate::execution::ExecutionContext; @@ -46,6 +49,34 @@ use crate::execution::ExecutionContext; use self::state::tabs::flightsql::FlightSQLQuery; #[derive(Clone, Debug)] +pub struct ExecutionError { + query: String, + error: String, + duration: Duration, +} + +#[derive(Clone, Debug)] +pub struct ExecutionResultsBatch { + query: String, + batch: RecordBatch, + duration: Duration, +} + +impl ExecutionError { + pub fn query(&self) -> &str { + &self.query + } + + pub fn error(&self) -> &str { + &self.error + } + + pub fn duration(&self) -> &Duration { + &self.duration + } +} + +#[derive(Debug)] pub enum AppEvent { Key(event::KeyEvent), Error, @@ -59,7 +90,14 @@ pub enum AppEvent { Mouse(event::MouseEvent), Resize(u16, u16), ExecuteDDL(String), + NewExecution, QueryResult(Query), + ExecutionResultsNextPage(ExecutionResultsBatch), + ExecutionResultsPreviousPage, + ExecutionResultsError(ExecutionError), + FlightSQLExecutionResultsNextPage(ExecutionResultsBatch), + FlightSQLExecutionResultsPreviousPage, + FlightSQLExecutionResultsError(ExecutionError), #[cfg(feature = "flightsql")] EstablishFlightSQLConnection, #[cfg(feature = "flightsql")] @@ -68,7 +106,7 @@ pub enum AppEvent { pub struct App<'app> { state: state::AppState<'app>, - execution: Arc, + execution: Arc, event_tx: UnboundedSender, event_rx: UnboundedReceiver, cancellation_token: CancellationToken, @@ -80,6 +118,7 @@ impl<'app> App<'app> { let (event_tx, event_rx) = mpsc::unbounded_channel(); let cancellation_token = CancellationToken::new(); let task = tokio::spawn(async {}); + let app_execution = Arc::new(AppExecution::new(Arc::new(execution))); Self { state, @@ -87,7 +126,7 @@ impl<'app> App<'app> { event_rx, event_tx, cancellation_token, - execution: Arc::new(execution), + execution: app_execution, } } @@ -99,7 +138,7 @@ impl<'app> App<'app> { &mut self.event_rx } - pub fn execution(&self) -> Arc { + pub fn execution(&self) -> Arc { Arc::clone(&self.execution) } @@ -318,7 +357,7 @@ impl App<'_> { loop { let event = app.next().await?; - if let AppEvent::Render = event.clone() { + if let AppEvent::Render = &event { terminal.draw(|f| f.render_widget(&app, f.area()))?; }; diff --git a/src/app/state/tabs/flightsql.rs b/src/app/state/tabs/flightsql.rs index 74aee20..d748bc9 100644 --- a/src/app/state/tabs/flightsql.rs +++ b/src/app/state/tabs/flightsql.rs @@ -19,6 +19,7 @@ use core::cell::RefCell; use std::time::Duration; use datafusion::arrow::array::RecordBatch; +use datafusion::arrow::array::UInt32Array; use ratatui::crossterm::event::KeyEvent; use ratatui::style::palette::tailwind; use ratatui::style::Style; @@ -27,6 +28,7 @@ use tui_textarea::TextArea; use crate::app::state::tabs::sql; use crate::execution::ExecutionStats; +use crate::app::ExecutionError; #[derive(Clone, Debug)] pub struct FlightSQLQuery { @@ -108,6 +110,9 @@ pub struct FlightSQLTabState<'app> { editor_editable: bool, query: Option, query_results_state: Option>, + result_batches: Option>, + results_page: Option, + execution_error: Option, } impl<'app> FlightSQLTabState<'app> { @@ -123,6 +128,9 @@ impl<'app> FlightSQLTabState<'app> { editor_editable: false, query: None, query_results_state: None, + result_batches: None, + results_page: None, + execution_error: None, } } @@ -197,4 +205,22 @@ impl<'app> FlightSQLTabState<'app> { pub fn delete_word(&mut self) { self.editor.delete_word(); } + + pub fn current_page_results(&self) -> Option { + match (self.results_page, &self.result_batches) { + (Some(p), Some(b)) => Some(get_records(p, b)), + _ => None, + } + } + + pub fn next_results_page(&mut self) {} +} + +fn get_records(page: usize, batches: &[RecordBatch]) -> RecordBatch { + let start = page * 100; + let end = start + 100; + let indices = ((start as u32)..(end as u32)).collect::>(); + let indices_array = UInt32Array::from(indices); + let taken = datafusion::arrow::compute::take_record_batch(&batches[0], &indices_array).unwrap(); + taken } diff --git a/src/app/state/tabs/history.rs b/src/app/state/tabs/history.rs index 7646120..fa4c3c0 100644 --- a/src/app/state/tabs/history.rs +++ b/src/app/state/tabs/history.rs @@ -22,7 +22,7 @@ use ratatui::widgets::TableState; use crate::execution::ExecutionStats; -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum Context { Local, FlightSQL, @@ -43,6 +43,7 @@ pub struct HistoryQuery { sql: String, execution_time: Duration, execution_stats: Option, + error: Option, } impl HistoryQuery { @@ -51,12 +52,14 @@ impl HistoryQuery { sql: String, execution_time: Duration, execution_stats: Option, + error: Option, ) -> Self { Self { context, sql, execution_time, execution_stats, + error, } } pub fn sql(&self) -> &String { @@ -71,13 +74,13 @@ impl HistoryQuery { &self.execution_stats } - pub fn scanned_bytes(&self) -> usize { - if let Some(stats) = &self.execution_stats { - stats.bytes_scanned() - } else { - 0 - } - } + // pub fn scanned_bytes(&self) -> usize { + // if let Some(stats) = &self.execution_stats { + // stats.bytes_scanned() + // } else { + // 0 + // } + // } pub fn context(&self) -> &Context { &self.context diff --git a/src/app/state/tabs/sql.rs b/src/app/state/tabs/sql.rs index d4cafc1..903cc36 100644 --- a/src/app/state/tabs/sql.rs +++ b/src/app/state/tabs/sql.rs @@ -26,7 +26,8 @@ use ratatui::style::{Modifier, Style}; use ratatui::widgets::TableState; use tui_textarea::TextArea; -use crate::execution::ExecutionStats; +use crate::app::app_execution::ExecutionStats; +use crate::app::ExecutionError; #[derive(Clone, Debug)] pub struct Query { @@ -129,6 +130,9 @@ pub struct SQLTabState<'app> { editor_editable: bool, query: Option, query_results_state: Option>, + result_batches: Option>, + results_page: Option, + execution_error: Option, } impl<'app> SQLTabState<'app> { @@ -144,6 +148,9 @@ impl<'app> SQLTabState<'app> { editor_editable: false, query: None, query_results_state: None, + result_batches: None, + results_page: None, + execution_error: None, } } @@ -155,6 +162,13 @@ impl<'app> SQLTabState<'app> { self.query_results_state = Some(RefCell::new(TableState::default())); } + pub fn reset_execution_results(&mut self) { + self.result_batches = None; + self.results_page = None; + self.execution_error = None; + self.refresh_query_results_state(); + } + pub fn editor(&self) -> TextArea { // TODO: Figure out how to do this without clone. Probably need logic in handler to make // updates to the Widget and then pass a ref @@ -218,4 +232,47 @@ impl<'app> SQLTabState<'app> { pub fn delete_word(&mut self) { self.editor.delete_word(); } + + pub fn add_batch(&mut self, batch: RecordBatch) { + if let Some(batches) = self.result_batches.as_mut() { + batches.push(batch); + } else { + self.result_batches = Some(vec![batch]); + } + } + + pub fn current_batch(&self) -> Option<&RecordBatch> { + match (self.results_page, self.result_batches.as_ref()) { + (Some(page), Some(batches)) => batches.get(page), + _ => None, + } + } + + pub fn execution_error(&self) -> &Option { + &self.execution_error + } + + pub fn set_execution_error(&mut self, error: ExecutionError) { + self.execution_error = Some(error); + } + + pub fn results_page(&self) -> Option { + self.results_page + } + + pub fn next_page(&mut self) { + if let Some(page) = self.results_page { + self.results_page = Some(page + 1); + } else { + self.results_page = Some(0); + } + } + + pub fn previous_page(&mut self) { + if let Some(page) = self.results_page { + if page > 0 { + self.results_page = Some(page - 1); + } + } + } } diff --git a/src/app/ui/convert.rs b/src/app/ui/convert.rs index 1d129df..d3a01a2 100644 --- a/src/app/ui/convert.rs +++ b/src/app/ui/convert.rs @@ -166,7 +166,7 @@ pub fn empty_results_table<'frame>() -> Table<'frame> { } pub fn record_batches_to_table<'frame, 'results>( - record_batches: &'results [RecordBatch], + record_batches: &'results [&RecordBatch], ) -> Result> where // The results come from sql_tab state which persists until the next query is run which is diff --git a/src/app/ui/tabs/flightsql.rs b/src/app/ui/tabs/flightsql.rs index daf9ba5..10968a4 100644 --- a/src/app/ui/tabs/flightsql.rs +++ b/src/app/ui/tabs/flightsql.rs @@ -43,7 +43,43 @@ pub fn render_sql_editor(area: Rect, buf: &mut Buffer, app: &App) { } pub fn render_sql_results(area: Rect, buf: &mut Buffer, app: &App) { - let block = Block::default().title(" Results ").borders(Borders::ALL); + let flightsql_tab = &app.state.flightsql_tab; + match ( + flightsql_tab.query(), + flightsql_tab.current_page_results(), + flightsql_tab.query_results_state(), + ) { + (None, _, _) => { + let block = Block::default().title(" Results ").borders(Borders::ALL); + let row = Row::new(vec!["Run a query to generate results"]); + let widths = vec![Constraint::Percentage(100)]; + let table = Table::new(vec![row], widths).block(block); + Widget::render(table, area, buf); + } + (Some(_), Some(b), Some(s)) => { + let block = Block::default().title(" Results ").borders(Borders::ALL); + let batches = vec![&b]; + let maybe_table = record_batches_to_table(&batches); + let block = block.title_bottom(" Stats "); + match maybe_table { + Ok(table) => { + let table = table + .highlight_style(Style::default().bg(tailwind::WHITE).fg(tailwind::BLACK)) + .block(block); + + let mut s = s.borrow_mut(); + StatefulWidget::render(table, area, buf, &mut s); + } + Err(e) => { + let row = Row::new(vec![e.to_string()]); + let widths = vec![Constraint::Percentage(100)]; + let table = Table::new(vec![row], widths).block(block); + Widget::render(table, area, buf); + } + } + } + _ => {} + } if let Some(q) = app.state.flightsql_tab.query() { if let Some(r) = q.results() { if let Some(s) = app.state.flightsql_tab.query_results_state() { @@ -54,25 +90,25 @@ pub fn render_sql_results(area: Rect, buf: &mut Buffer, app: &App) { )) .fg(tailwind::WHITE); let block = block.title_bottom(stats).fg(tailwind::ORANGE.c500); - let maybe_table = record_batches_to_table(r); - match maybe_table { - Ok(table) => { - let table = table - .highlight_style( - Style::default().bg(tailwind::WHITE).fg(tailwind::BLACK), - ) - .block(block); - - let mut s = s.borrow_mut(); - StatefulWidget::render(table, area, buf, &mut s); - } - Err(e) => { - let row = Row::new(vec![e.to_string()]); - let widths = vec![Constraint::Percentage(100)]; - let table = Table::new(vec![row], widths).block(block); - Widget::render(table, area, buf); - } - } + // let maybe_table = record_batches_to_table(r); + // match maybe_table { + // Ok(table) => { + // let table = table + // .highlight_style( + // Style::default().bg(tailwind::WHITE).fg(tailwind::BLACK), + // ) + // .block(block); + // + // let mut s = s.borrow_mut(); + // StatefulWidget::render(table, area, buf, &mut s); + // } + // Err(e) => { + // let row = Row::new(vec![e.to_string()]); + // let widths = vec![Constraint::Percentage(100)]; + // let table = Table::new(vec![row], widths).block(block); + // Widget::render(table, area, buf); + // } + // } } } else if let Some(e) = q.error() { let row = Row::new(vec![e.to_string()]); diff --git a/src/app/ui/tabs/history.rs b/src/app/ui/tabs/history.rs index 3f0988b..a9ee3aa 100644 --- a/src/app/ui/tabs/history.rs +++ b/src/app/ui/tabs/history.rs @@ -62,7 +62,9 @@ pub fn render_query_history(area: Rect, buf: &mut Buffer, app: &App) { .title(" Query History ") .borders(Borders::ALL); let history = app.state.history_tab.history(); + info!("History: {:?}", history); let history_table_state = app.state.history_tab.history_table_state(); + info!("History Table State: {:?}", history_table_state); match (history.is_empty(), history_table_state) { (true, _) | (_, None) => { let row = Row::new(vec!["Your query history will show here"]); @@ -85,7 +87,14 @@ pub fn render_query_history(area: Rect, buf: &mut Buffer, app: &App) { Cell::from(q.context().as_str()), Cell::from(q.sql().as_str()), Cell::from(q.execution_time().as_millis().to_string()), - Cell::from(q.scanned_bytes().to_string()), + // Not sure showing scanned_bytes is useful anymore in the context of + // paginated queries. Hard coding to zero for now but this will need to be + // revisted. One option I have is removing these type of stats from the + // query history table (so we only show execution time) and then + // _anything_ ExecutionPlan statistics related is shown in the lower pane + // and their is a `analyze` mode that runs the query to completion and + // collects all stats to show in a table next to the query. + Cell::from(0.to_string()), ]) }) .collect(); @@ -96,8 +105,9 @@ pub fn render_query_history(area: Rect, buf: &mut Buffer, app: &App) { Cell::from("Execution Time(ms)"), Cell::from("Scanned Bytes"), ]) - .bg(tailwind::WHITE) - .fg(tailwind::BLACK); + .bg(tailwind::ORANGE.c300) + .fg(tailwind::BLACK) + .bold(); let table = Table::new(rows, widths).header(header).block(block.clone()); let table = table diff --git a/src/app/ui/tabs/sql.rs b/src/app/ui/tabs/sql.rs index 16ee0cb..10f656d 100644 --- a/src/app/ui/tabs/sql.rs +++ b/src/app/ui/tabs/sql.rs @@ -20,7 +20,7 @@ use ratatui::{ layout::{Alignment, Constraint, Direction, Layout, Rect}, style::{palette::tailwind, Style, Stylize}, text::Span, - widgets::{Block, Borders, Paragraph, Row, StatefulWidget, Table, Widget}, + widgets::{block::Title, Block, Borders, Paragraph, Row, StatefulWidget, Table, Widget}, }; use crate::app::ui::convert::record_batches_to_table; @@ -44,48 +44,62 @@ pub fn render_sql_editor(area: Rect, buf: &mut Buffer, app: &App) { } pub fn render_sql_results(area: Rect, buf: &mut Buffer, app: &App) { - let block = Block::default().title(" Results ").borders(Borders::ALL); - if let Some(q) = app.state.sql_tab.query() { - if let Some(r) = q.results() { - if let Some(s) = app.state.sql_tab.query_results_state() { - let stats = Span::from(format!( - " {} rows in {}ms ", - q.num_rows().unwrap_or(0), - q.execution_time().as_millis() - )) - .fg(tailwind::WHITE); - let block = block.title_bottom(stats).fg(tailwind::ORANGE.c500); - let maybe_table = record_batches_to_table(r); - match maybe_table { - Ok(table) => { - let table = table - .highlight_style( - Style::default().bg(tailwind::WHITE).fg(tailwind::BLACK), - ) - .block(block); + // TODO: Change this to a match on state and batch + let sql_tab = &app.state.sql_tab; + match ( + sql_tab.current_batch(), + sql_tab.results_page(), + sql_tab.query_results_state(), + sql_tab.execution_error(), + ) { + (Some(batch), Some(p), Some(s), None) => { + let block = Block::default() + .title(" Results ") + .borders(Borders::ALL) + .title(Title::from(format!(" Page {p} ")).alignment(Alignment::Right)); + let batches = vec![batch]; + let maybe_table = record_batches_to_table(&batches); - let mut s = s.borrow_mut(); - StatefulWidget::render(table, area, buf, &mut s); - } - Err(e) => { - let row = Row::new(vec![e.to_string()]); - let widths = vec![Constraint::Percentage(100)]; - let table = Table::new(vec![row], widths).block(block); - Widget::render(table, area, buf); - } + let block = block.title_bottom("Stats").fg(tailwind::ORANGE.c500); + match maybe_table { + Ok(table) => { + let table = table + .highlight_style(Style::default().bg(tailwind::WHITE).fg(tailwind::BLACK)) + .block(block); + + let mut s = s.borrow_mut(); + StatefulWidget::render(table, area, buf, &mut s); + } + Err(e) => { + let row = Row::new(vec![e.to_string()]); + let widths = vec![Constraint::Percentage(100)]; + let table = Table::new(vec![row], widths).block(block); + Widget::render(table, area, buf); } } - } else if let Some(e) = q.error() { - let row = Row::new(vec![e.to_string()]); + } + (_, _, _, Some(e)) => { + let dur = e.duration().as_millis(); + let block = Block::default() + .title(" Results ") + .borders(Borders::ALL) + .title(Title::from(" Page ").alignment(Alignment::Right)) + .title_bottom(format!(" {}ms ", dur)); + let row = Row::new(vec![e.error().to_string()]); + let widths = vec![Constraint::Percentage(100)]; + let table = Table::new(vec![row], widths).block(block); + Widget::render(table, area, buf); + } + _ => { + let block = Block::default() + .title(" Results ") + .borders(Borders::ALL) + .title(Title::from(" Page ").alignment(Alignment::Right)); + let row = Row::new(vec!["Run a query to generate results"]); let widths = vec![Constraint::Percentage(100)]; let table = Table::new(vec![row], widths).block(block); Widget::render(table, area, buf); } - } else { - let row = Row::new(vec!["Run a query to generate results"]); - let widths = vec![Constraint::Percentage(100)]; - let table = Table::new(vec![row], widths).block(block); - Widget::render(table, area, buf); } } diff --git a/src/execution/mod.rs b/src/execution/mod.rs index d0c6b5e..9db6477 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -19,10 +19,13 @@ //! mod stats; +use std::sync::Arc; + pub use stats::{collect_plan_stats, ExecutionStats}; use color_eyre::eyre::Result; use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::*; use datafusion::sql::parser::Statement; use tokio_stream::StreamExt; @@ -54,6 +57,12 @@ pub struct ExecutionContext { flightsql_client: Mutex>>, } +impl std::fmt::Debug for ExecutionContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExecutionContext").finish() + } +} + impl ExecutionContext { /// Construct a new `ExecutionContext` with the specified configuration pub fn try_new(config: &ExecutionConfig) -> Result { @@ -107,6 +116,16 @@ impl ExecutionContext { Ok(()) } + /// Create a physical plan from the specified SQL string. This is useful if you want to store + /// the plan and collect metrics from it. + pub async fn create_physical_plan( + &self, + sql: &str, + ) -> datafusion::error::Result> { + let df = self.session_ctx.sql(sql).await?; + df.create_physical_plan().await + } + /// Executes the specified sql string, returning the resulting /// [`SendableRecordBatchStream`] of results pub async fn execute_sql( diff --git a/src/extensions/builder.rs b/src/extensions/builder.rs index ccad3c6..69953c3 100644 --- a/src/extensions/builder.rs +++ b/src/extensions/builder.rs @@ -71,7 +71,7 @@ impl DftSessionStateBuilder { pub fn new() -> Self { let session_config = SessionConfig::default() // TODO why is batch size 1? - .with_batch_size(1) + .with_batch_size(100) .with_information_schema(true); Self {