Skip to content

Commit

Permalink
partitioned execution
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangpengHao committed Nov 4, 2024
1 parent 973bb25 commit 2544622
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 152 deletions.
59 changes: 56 additions & 3 deletions benchmarks/src/bin/cache_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,72 @@
// under the License.

use arrow::util::pretty;
use datafusion::error::Result;
use datafusion::physical_plan::collect;
use datafusion::physical_plan::display::DisplayableExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_common::exec_datafusion_err;
use datafusion_flight_table::sql::{FlightSqlDriver, USERNAME};
use datafusion_flight_table::FlightTableFactory;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use structopt::StructOpt;

// Command-line options for the cache client benchmark binary.
#[derive(Debug, StructOpt)]
struct Options {
// Path to the file containing the benchmark queries, one query per line.
#[structopt(long)]
queries_path: PathBuf,

// Zero-based index of the query to run; `main` falls back to 0 when omitted.
#[structopt(long)]
query: Option<usize>,
}

/// The benchmark queries loaded from a queries file, one entry per line.
struct AllQueries {
/// Query texts in file order; a query's id is its index in this vector.
queries: Vec<String>,
}

impl AllQueries {
    /// Loads all queries from the file at `path`, treating every line as one query.
    ///
    /// # Errors
    ///
    /// Returns an execution error when the file cannot be read.
    fn try_new(path: &Path) -> Result<Self> {
        // ClickBench has all queries in a single file identified by line number
        let all_queries = std::fs::read_to_string(path)
            .map_err(|e| exec_datafusion_err!("Could not open {path:?}: {e}"))?;
        Ok(Self {
            queries: all_queries.lines().map(str::to_owned).collect(),
        })
    }

    /// Returns the text of query `query_id` (a zero-based line index).
    ///
    /// # Errors
    ///
    /// Returns an execution error when `query_id` is out of range, including
    /// the case where the queries file contained no lines at all.
    fn get_query(&self, query_id: usize) -> Result<&str> {
        self.queries
            .get(query_id)
            .map(String::as_str)
            .ok_or_else(|| {
                // An empty file has no valid id range to report, so it gets
                // its own message instead of a bogus "between 0 and N".
                if self.queries.is_empty() {
                    return exec_datafusion_err!(
                        "Invalid query id {query_id}. No queries were loaded"
                    );
                }
                let min_id = self.min_query_id();
                let max_id = self.max_query_id();
                exec_datafusion_err!(
                    "Invalid query id {query_id}. Must be between {min_id} and {max_id}"
                )
            })
    }

    /// Smallest valid query id.
    fn min_query_id(&self) -> usize {
        0
    }

    /// Largest valid query id, or 0 when no queries are loaded.
    fn max_query_id(&self) -> usize {
        // `len() - 1` would underflow on an empty query list.
        self.queries.len().saturating_sub(1)
    }
}

#[tokio::main]
async fn main() -> datafusion::common::Result<()> {
let options = Options::from_args();
let all_queries = AllQueries::try_new(options.queries_path.as_path())?;

let query_id = options.query.unwrap_or(0);
let sql = all_queries.get_query(query_id)?;

let ctx = SessionContext::new();
let mut state = ctx.state();
state
Expand All @@ -45,9 +101,6 @@ async fn main() -> datafusion::common::Result<()> {
.await?;
ctx.register_table("hits", Arc::new(table))?;

// let sql = r#"SELECT COUNT(*) FROM hits WHERE "AdvEngineID" <> 0"#;
let sql = r#"SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC"#;

let plan = ctx.sql(sql).await?;
let (state, plan) = plan.into_parts();
let plan = state.optimize(&plan)?;
Expand Down
197 changes: 48 additions & 149 deletions benchmarks/src/bin/cache_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{ArrayRef, StringArray};
use arrow::ipc::writer::IpcWriteOptions;
use arrow::record_batch::RecordBatch;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::flight_descriptor::DescriptorType;
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
Expand All @@ -32,12 +30,12 @@ use arrow_flight::{
Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest,
HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket,
};
use arrow_schema::{DataType, Field, Schema};
use dashmap::DashMap;
use datafusion::logical_expr::LogicalPlan;
use datafusion::physical_plan::collect;
use datafusion::prelude::{DataFrame, ParquetReadOptions, SessionConfig, SessionContext};
use futures::{Stream, StreamExt, TryStreamExt};
use datafusion::physical_plan::display::DisplayableExecutionPlan;
use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext};
use futures::{Stream, TryStreamExt};
use log::info;
use mimalloc::MiMalloc;
use prost::Message;
Expand Down Expand Up @@ -109,7 +107,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
pub struct FlightSqlServiceImpl {
contexts: Arc<DashMap<String, Arc<SessionContext>>>,
statements: Arc<DashMap<String, LogicalPlan>>,
results: Arc<DashMap<String, Vec<RecordBatch>>>,
results: Arc<DashMap<String, Arc<dyn ExecutionPlan>>>,
default_ctx: Arc<SessionContext>,
table_name: String,
table_path: String,
Expand Down Expand Up @@ -176,15 +174,7 @@ impl FlightSqlServiceImpl {
}
}

/// Looks up the logical plan stored under `handle` and returns a clone of it.
///
/// Fails with an internal `Status` when no plan is stored for that handle.
fn get_plan(&self, handle: &str) -> Result<LogicalPlan, Status> {
    match self.statements.get(handle) {
        Some(stored) => Ok(stored.clone()),
        None => Err(Status::internal(format!("Plan handle not found: {handle}"))),
    }
}

fn get_result(&self, handle: &str) -> Result<Vec<RecordBatch>, Status> {
fn get_result(&self, handle: &str) -> Result<Arc<dyn ExecutionPlan>, Status> {
if let Some(result) = self.results.get(handle) {
Ok(result.clone())
} else {
Expand All @@ -194,43 +184,6 @@ impl FlightSqlServiceImpl {
}
}

/// Builds a single `RecordBatch` listing every table reachable from `ctx`.
///
/// The batch has four Utf8 columns: `catalog_name`, `db_schema_name`,
/// `table_name` and `table_type`, with one row per table.
///
/// Panics if a listed catalog, schema or table cannot be resolved, or if the
/// batch cannot be assembled.
async fn tables(&self, ctx: Arc<SessionContext>) -> RecordBatch {
    let mut catalog_col = vec![];
    let mut schema_col = vec![];
    let mut name_col = vec![];
    let mut type_col = vec![];

    // Walk catalog -> schema -> table and record one row per table.
    for catalog_name in ctx.catalog_names() {
        let catalog_provider = ctx.catalog(&catalog_name).unwrap();
        for schema_name in catalog_provider.schema_names() {
            let schema_provider = catalog_provider.schema(&schema_name).unwrap();
            for table_name in schema_provider.table_names() {
                let provider = schema_provider
                    .table(&table_name)
                    .await
                    .unwrap()
                    .unwrap();
                catalog_col.push(catalog_name.clone());
                schema_col.push(schema_name.clone());
                name_col.push(table_name.clone());
                type_col.push(provider.table_type().to_string());
            }
        }
    }

    let fields = vec![
        Field::new("catalog_name", DataType::Utf8, true),
        Field::new("db_schema_name", DataType::Utf8, true),
        Field::new("table_name", DataType::Utf8, false),
        Field::new("table_type", DataType::Utf8, false),
    ];

    // Column order must match the field order above.
    let columns: Vec<ArrayRef> = vec![
        Arc::new(StringArray::from(catalog_col)) as ArrayRef,
        Arc::new(StringArray::from(schema_col)) as ArrayRef,
        Arc::new(StringArray::from(name_col)) as ArrayRef,
        Arc::new(StringArray::from(type_col)) as ArrayRef,
    ];

    RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).unwrap()
}

fn remove_plan(&self, handle: &str) -> Result<(), Status> {
self.statements.remove(&handle.to_string());
Ok(())
Expand Down Expand Up @@ -295,7 +248,7 @@ impl FlightSqlService for FlightSqlServiceImpl {

async fn do_get_fallback(
&self,
_request: Request<Ticket>,
request: Request<Ticket>,
message: Any,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
if !message.is::<FetchResults>() {
Expand All @@ -305,26 +258,33 @@ impl FlightSqlService for FlightSqlServiceImpl {
)))?
}

let fr: FetchResults = message
let fetch_results: FetchResults = message
.unpack()
.map_err(|e| Status::internal(format!("{e:?}")))?
.ok_or_else(|| Status::internal("Expected FetchResults but got None!"))?;

let handle = fr.handle;
let handle = fetch_results.handle;

info!("getting results for {handle}");
let result = self.get_result(&handle)?;
// if we get an empty result, create an empty schema
let (schema, batches) = match result.first() {
None => (Arc::new(Schema::empty()), vec![]),
Some(batch) => (batch.schema(), result.clone()),
};
let execution_plan = self.get_result(&handle)?;

let displayable = DisplayableExecutionPlan::new(execution_plan.as_ref());
info!("physical plan:\n{}", displayable.indent(true));

let ctx = self.get_ctx(&request)?;

let schema = execution_plan.schema();

let batch_stream = futures::stream::iter(batches).map(Ok);
let stream = execution_plan
.execute(fetch_results.partition as usize, ctx.task_ctx())
.map_err(|e| status!("Error executing plan", e))?
.map_err(|e| {
arrow_flight::error::FlightError::from_external_error(Box::new(e))
});

let stream = FlightDataEncoderBuilder::new()
.with_schema(schema)
.build(batch_stream)
.build(stream)
.map_err(Status::from);

Ok(Response::new(Box::pin(stream)))
Expand All @@ -347,120 +307,56 @@ impl FlightSqlService for FlightSqlServiceImpl {
.create_physical_plan(&plan)
.await
.expect("Error creating physical plan");
let result = collect(physical_plan.clone(), state.task_ctx())
.await
.expect("Error executing plan");

let schema = match result.first() {
None => Schema::empty(),
Some(batch) => (*batch.schema()).clone(),
};
let partition_count = physical_plan.output_partitioning().partition_count();

let handle = Uuid::new_v4().hyphenated().to_string();
self.results.insert(handle.clone(), result);
let schema = physical_plan.schema();

let fetch = FetchResults { handle };
let buf = fetch.as_any().encode_to_vec().into();
let ticket = Ticket { ticket: buf };
let endpoint = FlightEndpoint::new().with_ticket(ticket);
let handle = Uuid::new_v4().hyphenated().to_string();
self.results.insert(handle.clone(), physical_plan);

let flight_desc = FlightDescriptor {
r#type: DescriptorType::Cmd.into(),
cmd: Default::default(),
path: vec![],
};

let info = FlightInfo::new()
let mut info = FlightInfo::new()
.try_with_schema(&schema)
.expect("encoding failed")
.with_endpoint(endpoint)
.with_descriptor(flight_desc);

for partition in 0..partition_count {
let fetch = FetchResults {
handle: handle.clone(),
partition: partition as u32,
};
let buf = fetch.as_any().encode_to_vec().into();
let ticket = Ticket { ticket: buf };
let endpoint = FlightEndpoint::new().with_ticket(ticket.clone());
info = info.with_endpoint(endpoint);
}

let resp = Response::new(info);
Ok(resp)
}

async fn get_flight_info_prepared_statement(
&self,
cmd: CommandPreparedStatementQuery,
request: Request<FlightDescriptor>,
_cmd: CommandPreparedStatementQuery,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
info!("get_flight_info_prepared_statement");
let handle = std::str::from_utf8(&cmd.prepared_statement_handle)
.map_err(|e| status!("Unable to parse uuid", e))?;

let ctx = self.get_ctx(&request)?;
let plan = self.get_plan(handle)?;

let state = ctx.state();
let df = DataFrame::new(state, plan);
let result = df
.collect()
.await
.map_err(|e| status!("Error executing query", e))?;

// if we get an empty result, create an empty schema
let schema = match result.first() {
None => Schema::empty(),
Some(batch) => (*batch.schema()).clone(),
};

self.results.insert(handle.to_string(), result);

// if we had multiple endpoints to connect to, we could use this Location
// but in the case of standalone DataFusion, we don't
// let loc = Location {
// uri: "grpc+tcp://127.0.0.1:50051".to_string(),
// };
let fetch = FetchResults {
handle: handle.to_string(),
};
let buf = fetch.as_any().encode_to_vec().into();
let ticket = Ticket { ticket: buf };

let info = FlightInfo::new()
// Encode the Arrow schema
.try_with_schema(&schema)
.expect("encoding failed")
.with_endpoint(FlightEndpoint::new().with_ticket(ticket))
.with_descriptor(FlightDescriptor {
r#type: DescriptorType::Cmd.into(),
cmd: Default::default(),
path: vec![],
});
let resp = Response::new(info);
Ok(resp)
panic!("not implemented");
}

async fn get_flight_info_tables(
&self,
_query: CommandGetTables,
request: Request<FlightDescriptor>,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
info!("get_flight_info_tables");
let ctx = self.get_ctx(&request)?;
let data = self.tables(ctx).await;
let schema = data.schema();

let uuid = Uuid::new_v4().hyphenated().to_string();
self.results.insert(uuid.clone(), vec![data]);

let fetch = FetchResults { handle: uuid };
let buf = fetch.as_any().encode_to_vec().into();
let ticket = Ticket { ticket: buf };

let info = FlightInfo::new()
// Encode the Arrow schema
.try_with_schema(&schema)
.expect("encoding failed")
.with_endpoint(FlightEndpoint::new().with_ticket(ticket))
.with_descriptor(FlightDescriptor {
r#type: DescriptorType::Cmd.into(),
cmd: Default::default(),
path: vec![],
});
let resp = Response::new(info);
Ok(resp)
panic!("not implemented");
}

async fn do_put_prepared_statement_update(
Expand Down Expand Up @@ -531,6 +427,9 @@ impl FlightSqlService for FlightSqlServiceImpl {
// Message carried inside a Flight ticket: it names a stored result and the
// partition of it that the client wants to fetch.
pub struct FetchResults {
// Key under which the server stored the result in its `results` map.
#[prost(string, tag = "1")]
pub handle: ::prost::alloc::string::String,

// Index of the partition to execute for this ticket.
#[prost(uint32, tag = "2")]
pub partition: u32,
}

impl ProstMessageExt for FetchResults {
Expand Down
Loading

0 comments on commit 2544622

Please sign in to comment.