Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support User Defined Table Function #8306

Merged
merged 10 commits into from
Nov 30, 2023
177 changes: 177 additions & 0 deletions datafusion-examples/examples/simple_udtf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::csv::reader::Format;
use arrow::csv::ReaderBuilder;
use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::function::TableFunctionImpl;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use datafusion::execution::context::{ExecutionProps, SessionState};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_common::{plan_err, DataFusionError, ScalarValue};
use datafusion_expr::{Expr, TableType};
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
use std::fs::File;
use std::io::Seek;
use std::path::Path;
use std::sync::Arc;

// To define your own table function, you only need to do the following 3 things:
// 1. Implement your own [`TableProvider`]
// 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`]
// 3. Register the function using [`SessionContext::register_udtf`]

/// This example demonstrates how to register a TableFunction
#[tokio::main]
async fn main() -> Result<()> {
// create local execution context
let ctx = SessionContext::new();

// register the table function that will be called in SQL statements by `read_csv`
ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {}));

let testdata = datafusion::test_util::arrow_test_data();
let csv_file = format!("{testdata}/csv/aggregate_test_100.csv");

// Pass 2 arguments, read csv with at most 2 rows (simplify logic makes 1+1 --> 2)
let df = ctx
.sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str())
.await?;
df.show().await?;

// just run, return all rows
let df = ctx
.sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str())
.await?;
df.show().await?;

Ok(())
}

/// Table Function that mimics the [`read_csv`] function in DuckDB.
///
/// Usage: `read_csv(filename, [limit])`
///
/// [`read_csv`]: https://duckdb.org/docs/data/csv/overview.html
struct LocalCsvTable {
schema: SchemaRef,
limit: Option<usize>,
batches: Vec<RecordBatch>,
}

#[async_trait]
impl TableProvider for LocalCsvTable {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn schema(&self) -> SchemaRef {
self.schema.clone()
}

fn table_type(&self) -> TableType {
TableType::Base
}

async fn scan(
&self,
_state: &SessionState,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let batches = if let Some(max_return_lines) = self.limit {
// get max return rows from self.batches
let mut batches = vec![];
let mut lines = 0;
for batch in &self.batches {
let batch_lines = batch.num_rows();
if lines + batch_lines > max_return_lines {
let batch_lines = max_return_lines - lines;
batches.push(batch.slice(0, batch_lines));
break;
} else {
batches.push(batch.clone());
lines += batch_lines;
}
}
batches
} else {
self.batches.clone()
};
Ok(Arc::new(MemoryExec::try_new(
&[batches],
TableProvider::schema(self),
projection.cloned(),
)?))
}
}
struct LocalCsvTableFunc {}

impl TableFunctionImpl for LocalCsvTableFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.get(0) else {
return plan_err!("read_csv requires at least one string argument");
};

let limit = exprs
.get(1)
.map(|expr| {
// try to simpify the expression, so 1+2 becomes 3, for example
let execution_props = ExecutionProps::new();
let info = SimplifyContext::new(&execution_props);
let expr = ExprSimplifier::new(info).simplify(expr.clone())?;

if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr {
Ok(limit as usize)
} else {
plan_err!("Limit must be an integer")
}
})
.transpose()?;

let (schema, batches) = read_csv_batches(path)?;

let table = LocalCsvTable {
schema,
limit,
batches,
};
Ok(Arc::new(table))
}
}

fn read_csv_batches(csv_path: impl AsRef<Path>) -> Result<(SchemaRef, Vec<RecordBatch>)> {
let mut file = File::open(csv_path)?;
let (schema, _) = Format::default().infer_schema(&mut file, None)?;
file.rewind()?;

let reader = ReaderBuilder::new(Arc::new(schema.clone()))
.with_header(true)
.build(file)?;
let mut batches = vec![];
for bacth in reader {
batches.push(bacth?);
}
let schema = Arc::new(schema);
Ok((schema, batches))
}
56 changes: 56 additions & 0 deletions datafusion/core/src/datasource/function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! A table that uses a function to generate data

use super::TableProvider;

use datafusion_common::Result;
use datafusion_expr::Expr;

use std::sync::Arc;

/// A trait for table function implementations
pub trait TableFunctionImpl: Sync + Send {
/// Create a table provider
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This API is nice and is as specified in #7926. I think it will work for using a table function as a relation in the query (aka like a table with parameters)

The one thing I don't think this API supports is TableFunctions that take other arguments (aka that are fed the result of a table / can use the value of correlated subqueries as mentioned by @yukkit and @Jesse-Bakker #7926 (comment).

I can think of two options:

  1. Leave this API as is , and add a follow on / new API somehow to support that usecase
  2. Try to extend this API somehow to support table inputs

I personally prefer 1 as I think it offers several additional use cases, even though it doesn't cover "take a table input".

Any other thoughts?

Copy link
Contributor

@Jesse-Bakker Jesse-Bakker Nov 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One specific use case for table-valued arguments to table-valued functions is, for example windowing tvf's like in apache flink.

Example which cannot be expressed by taking Expr arguments (maybe if Expr::Row() is added?):

SELECT window_start, window_end, SUM(price)
  FROM TABLE(
    TUMBLE(TABLE Bid, DESCRIPTOR(bidtime), INTERVAL '10' MINUTES))
  GROUP BY window_start, window_end;

That can also be emulated, however, using something like:

SELECT window_start, window_end, SUM(price)
  FROM Bid,
    TUMBLE(Bid.bidtime, INTERVAL '10' MINUTES))
    GROUP BY window_start, window_end;

which doesn't need table-valued arguments (but does need to resolve Expr::Column(name=bidtime). I'm not sure if the current API can do that?).

Anyway, the current API is nice, and definitely very useful 👍

}

/// A table that uses a function to generate data
pub struct TableFunction {
/// Name of the table function
name: String,
/// Function implementation
fun: Arc<dyn TableFunctionImpl>,
}

impl TableFunction {
/// Create a new table function
pub fn new(name: String, fun: Arc<dyn TableFunctionImpl>) -> Self {
Self { name, fun }
}

/// Get the name of the table function
pub fn name(&self) -> &str {
&self.name
}

/// Get the function implementation and generate a table
pub fn create_table_provider(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

self.fun.call(args)
}
}
1 change: 1 addition & 0 deletions datafusion/core/src/datasource/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub mod avro_to_arrow;
pub mod default_table_source;
pub mod empty;
pub mod file_format;
pub mod function;
pub mod listing;
pub mod listing_table_factory;
pub mod memory;
Expand Down
30 changes: 29 additions & 1 deletion datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ mod parquet;
use crate::{
catalog::{CatalogList, MemoryCatalogList},
datasource::{
function::{TableFunction, TableFunctionImpl},
listing::{ListingOptions, ListingTable},
provider::TableProviderFactory,
},
Expand All @@ -42,7 +43,7 @@ use datafusion_common::{
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
Expand Down Expand Up @@ -803,6 +804,14 @@ impl SessionContext {
.add_var_provider(variable_type, provider);
}

/// Register a table UDF with this context
pub fn register_udtf(&self, name: &str, fun: Arc<dyn TableFunctionImpl>) {
self.state.write().table_functions.insert(
name.to_owned(),
Arc::new(TableFunction::new(name.to_owned(), fun)),
);
}

/// Registers a scalar UDF within this context.
///
/// Note in SQL queries, function names are looked up using
Expand Down Expand Up @@ -1234,6 +1243,8 @@ pub struct SessionState {
query_planner: Arc<dyn QueryPlanner + Send + Sync>,
/// Collection of catalogs containing schemas and ultimately TableProviders
catalog_list: Arc<dyn CatalogList>,
/// Table Functions
table_functions: HashMap<String, Arc<TableFunction>>,
/// Scalar functions that are registered with the context
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate functions registered in the context
Expand Down Expand Up @@ -1332,6 +1343,7 @@ impl SessionState {
physical_optimizers: PhysicalOptimizer::new(),
query_planner: Arc::new(DefaultQueryPlanner {}),
catalog_list,
table_functions: HashMap::new(),
scalar_functions: HashMap::new(),
aggregate_functions: HashMap::new(),
window_functions: HashMap::new(),
Expand Down Expand Up @@ -1870,6 +1882,22 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
.ok_or_else(|| plan_datafusion_err!("table '{name}' not found"))
}

fn get_table_function_source(
&self,
name: &str,
args: Vec<Expr>,
) -> Result<Arc<dyn TableSource>> {
let tbl_func = self
.state
.table_functions
.get(name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
let provider = tbl_func.create_table_provider(&args)?;

Ok(provider_as_source(provider))
}

fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.state.scalar_functions().get(name).cloned()
}
Expand Down
3 changes: 3 additions & 0 deletions datafusion/core/tests/user_defined/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ mod user_defined_plan;

/// Tests for User Defined Window Functions
mod user_defined_window_functions;

/// Tests for User Defined Table Functions
mod user_defined_table_functions;
Loading