Support User Defined Table Function (apache#8306)
* Support User Defined Table Function

Signed-off-by: veeupup <[email protected]>

* fix comments

Signed-off-by: veeupup <[email protected]>

* add udtf test

Signed-off-by: veeupup <[email protected]>

* add file header

* Simplify table function example, add some comments

* Simplify exprs

* make clippy happy

* Update datafusion/core/tests/user_defined/user_defined_table_functions.rs

---------

Signed-off-by: veeupup <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
2 people authored and appletreeisyellow committed Dec 14, 2023
1 parent 9ed28ae commit 265d2da
Showing 8 changed files with 550 additions and 21 deletions.
177 changes: 177 additions & 0 deletions datafusion-examples/examples/simple_udtf.rs
@@ -0,0 +1,177 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::csv::reader::Format;
use arrow::csv::ReaderBuilder;
use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::function::TableFunctionImpl;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::context::{ExecutionProps, SessionState};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_common::{plan_err, DataFusionError, ScalarValue};
use datafusion_expr::{Expr, TableType};
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
use std::fs::File;
use std::io::Seek;
use std::path::Path;
use std::sync::Arc;

// To define your own table function, you only need to do the following 3 things:
// 1. Implement your own [`TableProvider`]
// 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`]
// 3. Register the function using [`SessionContext::register_udtf`]

/// This example demonstrates how to register a TableFunction
#[tokio::main]
async fn main() -> Result<()> {
// create local execution context
let ctx = SessionContext::new();

// register the table function so it can be invoked as `read_csv` in SQL statements
ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {}));

let testdata = datafusion::test_util::arrow_test_data();
let csv_file = format!("{testdata}/csv/aggregate_test_100.csv");

// Pass 2 arguments: read the CSV file, returning at most 2 rows (expression simplification folds `1 + 1` into `2`)
let df = ctx
.sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str())
.await?;
df.show().await?;

// Pass 1 argument: no limit, so all rows are returned
let df = ctx
.sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str())
.await?;
df.show().await?;

Ok(())
}

/// Table Function that mimics the [`read_csv`] function in DuckDB.
///
/// Usage: `read_csv(filename, [limit])`
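/// e.g. `SELECT * FROM read_csv('data.csv', 10)` returns at most the first 10 rows ('data.csv' being a placeholder path)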
///
/// [`read_csv`]: https://duckdb.org/docs/data/csv/overview.html
struct LocalCsvTable {
schema: SchemaRef,
limit: Option<usize>,
batches: Vec<RecordBatch>,
}

#[async_trait]
impl TableProvider for LocalCsvTable {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn schema(&self) -> SchemaRef {
self.schema.clone()
}

fn table_type(&self) -> TableType {
TableType::Base
}

async fn scan(
&self,
_state: &SessionState,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let batches = if let Some(max_return_lines) = self.limit {
// take at most `max_return_lines` rows from self.batches
let mut batches = vec![];
let mut lines = 0;
for batch in &self.batches {
let batch_lines = batch.num_rows();
if lines + batch_lines > max_return_lines {
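// this batch would cross the limit: keep only the rows that still fit, then stop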
let batch_lines = max_return_lines - lines;
batches.push(batch.slice(0, batch_lines));
break;
} else {
batches.push(batch.clone());
lines += batch_lines;
}
}
batches
} else {
self.batches.clone()
};
Ok(Arc::new(MemoryExec::try_new(
&[batches],
TableProvider::schema(self),
projection.cloned(),
)?))
}
}

struct LocalCsvTableFunc {}

impl TableFunctionImpl for LocalCsvTableFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.get(0) else {
return plan_err!("read_csv requires at least one string argument");
};

let limit = exprs
.get(1)
.map(|expr| {
// try to simplify the expression, so that `1 + 2` becomes `3`, for example
let execution_props = ExecutionProps::new();
let info = SimplifyContext::new(&execution_props);
let expr = ExprSimplifier::new(info).simplify(expr.clone())?;

if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr {
Ok(limit as usize)
} else {
plan_err!("Limit must be an integer")
}
})
.transpose()?;

let (schema, batches) = read_csv_batches(path)?;

let table = LocalCsvTable {
schema,
limit,
batches,
};
Ok(Arc::new(table))
}
}

fn read_csv_batches(csv_path: impl AsRef<Path>) -> Result<(SchemaRef, Vec<RecordBatch>)> {
let mut file = File::open(csv_path)?;
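// infer the schema from the file contents, then rewind so the CSV reader re-reads from the start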
let (schema, _) = Format::default().infer_schema(&mut file, None)?;
file.rewind()?;

let reader = ReaderBuilder::new(Arc::new(schema.clone()))
.with_header(true)
.build(file)?;
let mut batches = vec![];
for batch in reader {
batches.push(batch?);
}
let schema = Arc::new(schema);
Ok((schema, batches))
}
56 changes: 56 additions & 0 deletions datafusion/core/src/datasource/function.rs
@@ -0,0 +1,56 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! A table that uses a function to generate data
use super::TableProvider;

use datafusion_common::Result;
use datafusion_expr::Expr;

use std::sync::Arc;

/// A trait for table function implementations
pub trait TableFunctionImpl: Sync + Send {
/// Create a table provider
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>>;
}

/// A table that uses a function to generate data
pub struct TableFunction {
/// Name of the table function
name: String,
/// Function implementation
fun: Arc<dyn TableFunctionImpl>,
}

impl TableFunction {
/// Create a new table function
pub fn new(name: String, fun: Arc<dyn TableFunctionImpl>) -> Self {
Self { name, fun }
}

/// Get the name of the table function
pub fn name(&self) -> &str {
&self.name
}

/// Get the function implementation and generate a table
pub fn create_table_provider(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
self.fun.call(args)
}
}
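
For illustration, a minimal sketch of implementing this trait (hypothetical, not part of this commit): a `constant` table function that expects a single Int64 literal and returns a one-row, one-column table, reusing DataFusion's built-in `MemTable` provider so no custom `scan` is needed.

use std::sync::Arc;

use datafusion::arrow::array::Int64Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::{function::TableFunctionImpl, MemTable, TableProvider};
use datafusion::error::Result;
use datafusion_common::{plan_err, DataFusionError, ScalarValue};
use datafusion_expr::Expr;

struct ConstantTableFunc;

impl TableFunctionImpl for ConstantTableFunc {
    fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        // expect exactly one Int64 literal, e.g. `SELECT * FROM constant(42)`
        let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = args.get(0) else {
            return plan_err!("constant requires a single Int64 argument");
        };
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int64,
            false,
        )]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int64Array::from(vec![*value]))],
        )?;
        // MemTable is DataFusion's in-memory TableProvider, so no custom scan() is required
        Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
    }
}

Registered via `ctx.register_udtf("constant", Arc::new(ConstantTableFunc))`, it could then be queried as `SELECT * FROM constant(42)`.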
1 change: 1 addition & 0 deletions datafusion/core/src/datasource/mod.rs
@@ -23,6 +23,7 @@ pub mod avro_to_arrow;
pub mod default_table_source;
pub mod empty;
pub mod file_format;
pub mod function;
pub mod listing;
pub mod listing_table_factory;
pub mod memory;
30 changes: 29 additions & 1 deletion datafusion/core/src/execution/context/mod.rs
@@ -26,6 +26,7 @@ mod parquet;
use crate::{
catalog::{CatalogList, MemoryCatalogList},
datasource::{
function::{TableFunction, TableFunctionImpl},
listing::{ListingOptions, ListingTable},
provider::TableProviderFactory,
},
@@ -42,7 +43,7 @@ use datafusion_common::{
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
@@ -803,6 +804,14 @@ impl SessionContext {
.add_var_provider(variable_type, provider);
}

/// Registers a user defined table function (UDTF) with this context
pub fn register_udtf(&self, name: &str, fun: Arc<dyn TableFunctionImpl>) {
self.state.write().table_functions.insert(
name.to_owned(),
Arc::new(TableFunction::new(name.to_owned(), fun)),
);
}

/// Registers a scalar UDF within this context.
///
/// Note in SQL queries, function names are looked up using
@@ -1241,6 +1250,8 @@ pub struct SessionState {
query_planner: Arc<dyn QueryPlanner + Send + Sync>,
/// Collection of catalogs containing schemas and ultimately TableProviders
catalog_list: Arc<dyn CatalogList>,
/// Table Functions
table_functions: HashMap<String, Arc<TableFunction>>,
/// Scalar functions that are registered with the context
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate functions registered in the context
@@ -1339,6 +1350,7 @@ impl SessionState {
physical_optimizers: PhysicalOptimizer::new(),
query_planner: Arc::new(DefaultQueryPlanner {}),
catalog_list,
table_functions: HashMap::new(),
scalar_functions: HashMap::new(),
aggregate_functions: HashMap::new(),
window_functions: HashMap::new(),
@@ -1877,6 +1889,22 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
.ok_or_else(|| plan_datafusion_err!("table '{name}' not found"))
}

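// Hook used by the SQL planner when a function call appears in a FROM clause
// (e.g. `SELECT * FROM read_csv('...')`): it looks up the registered
// TableFunction by name and exposes the provider it builds as a TableSource.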
fn get_table_function_source(
&self,
name: &str,
args: Vec<Expr>,
) -> Result<Arc<dyn TableSource>> {
let tbl_func = self
.state
.table_functions
.get(name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
let provider = tbl_func.create_table_provider(&args)?;

Ok(provider_as_source(provider))
}

fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.state.scalar_functions().get(name).cloned()
}
3 changes: 3 additions & 0 deletions datafusion/core/tests/user_defined/mod.rs
@@ -26,3 +26,6 @@ mod user_defined_plan;

/// Tests for User Defined Window Functions
mod user_defined_window_functions;

/// Tests for User Defined Table Functions
mod user_defined_table_functions;