Skip to content

Commit

Permalink
chore: Add SessionState to MockContextProvider just like SessionConte…
Browse files Browse the repository at this point in the history
…xtProvider (#11940)

* refac: mock context provide to match public api

* lower udaf names

* cleanup

* typos

Co-authored-by: Jay Zhan <[email protected]>

* more typos

Co-authored-by: Jay Zhan <[email protected]>

* typos

* refactor func name

---------

Co-authored-by: Jay Zhan <[email protected]>
  • Loading branch information
dharanad and jayzhan211 authored Aug 12, 2024
1 parent e66636d commit 18193e6
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 61 deletions.
40 changes: 25 additions & 15 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion_functions::core::planner::CoreFunctionPlanner;
use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
use sqlparser::parser::Parser;

use crate::common::MockContextProvider;
use crate::common::{MockContextProvider, MockSessionState};

#[test]
fn roundtrip_expr() {
Expand All @@ -59,8 +59,8 @@ fn roundtrip_expr() {
let roundtrip = |table, sql: &str| -> Result<String> {
let dialect = GenericDialect {};
let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?;

let context = MockContextProvider::default().with_udaf(sum_udaf());
let state = MockSessionState::default().with_aggregate_function(sum_udaf());
let context = MockContextProvider { state };
let schema = context.get_table_source(table)?.schema();
let df_schema = DFSchema::try_from(schema.as_ref().clone())?;
let sql_to_rel = SqlToRel::new(&context);
Expand Down Expand Up @@ -156,11 +156,11 @@ fn roundtrip_statement() -> Result<()> {
let statement = Parser::new(&dialect)
.try_with_sql(query)?
.parse_statement()?;

let context = MockContextProvider::default()
.with_udaf(sum_udaf())
.with_udaf(count_udaf())
let state = MockSessionState::default()
.with_aggregate_function(sum_udaf())
.with_aggregate_function(count_udaf())
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
let context = MockContextProvider { state };
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();

Expand Down Expand Up @@ -189,8 +189,10 @@ fn roundtrip_crossjoin() -> Result<()> {
.try_with_sql(query)?
.parse_statement()?;

let context = MockContextProvider::default()
let state = MockSessionState::default()
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));

let context = MockContextProvider { state };
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();

Expand Down Expand Up @@ -412,10 +414,12 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
.try_with_sql(query.sql)?
.parse_statement()?;

let context = MockContextProvider::default()
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()))
.with_udaf(max_udaf())
.with_udaf(min_udaf());
let state = MockSessionState::default()
.with_aggregate_function(max_udaf())
.with_aggregate_function(min_udaf())
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));

let context = MockContextProvider { state };
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel
.sql_statement_to_plan(statement)
Expand Down Expand Up @@ -443,7 +447,9 @@ fn test_unnest_logical_plan() -> Result<()> {
.try_with_sql(query)?
.parse_statement()?;

let context = MockContextProvider::default();
let context = MockContextProvider {
state: MockSessionState::default(),
};
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();

Expand Down Expand Up @@ -516,7 +522,9 @@ fn test_pretty_roundtrip() -> Result<()> {

let df_schema = DFSchema::try_from(schema)?;

let context = MockContextProvider::default();
let context = MockContextProvider {
state: MockSessionState::default(),
};
let sql_to_rel = SqlToRel::new(&context);

let unparser = Unparser::default().with_pretty(true);
Expand Down Expand Up @@ -589,7 +597,9 @@ fn sql_round_trip(query: &str, expect: &str) {
.parse_statement()
.unwrap();

let context = MockContextProvider::default();
let context = MockContextProvider {
state: MockSessionState::default(),
};
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();

Expand Down
52 changes: 28 additions & 24 deletions datafusion/sql/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,36 +50,40 @@ impl Display for MockCsvType {
}

#[derive(Default)]
pub(crate) struct MockContextProvider {
options: ConfigOptions,
udfs: HashMap<String, Arc<ScalarUDF>>,
udafs: HashMap<String, Arc<AggregateUDF>>,
pub(crate) struct MockSessionState {
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
expr_planners: Vec<Arc<dyn ExprPlanner>>,
pub config_options: ConfigOptions,
}

impl MockContextProvider {
// Suppressing dead code warning, as this is used in integration test crates
#[allow(dead_code)]
pub(crate) fn options_mut(&mut self) -> &mut ConfigOptions {
&mut self.options
impl MockSessionState {
pub fn with_expr_planner(mut self, expr_planner: Arc<dyn ExprPlanner>) -> Self {
self.expr_planners.push(expr_planner);
self
}

#[allow(dead_code)]
pub(crate) fn with_udf(mut self, udf: ScalarUDF) -> Self {
self.udfs.insert(udf.name().to_string(), Arc::new(udf));
pub fn with_scalar_function(mut self, scalar_function: Arc<ScalarUDF>) -> Self {
self.scalar_functions
.insert(scalar_function.name().to_string(), scalar_function);
self
}

pub(crate) fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
pub fn with_aggregate_function(
mut self,
aggregate_function: Arc<AggregateUDF>,
) -> Self {
// TODO: change to to_string() if all the function name is converted to lowercase
self.udafs.insert(udaf.name().to_lowercase(), udaf);
self.aggregate_functions.insert(
aggregate_function.name().to_string().to_lowercase(),
aggregate_function,
);
self
}
}

pub(crate) fn with_expr_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
self.expr_planners.push(planner);
self
}
pub(crate) struct MockContextProvider {
pub(crate) state: MockSessionState,
}

impl ContextProvider for MockContextProvider {
Expand Down Expand Up @@ -202,11 +206,11 @@ impl ContextProvider for MockContextProvider {
}

fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.udfs.get(name).cloned()
self.state.scalar_functions.get(name).cloned()
}

fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
self.udafs.get(name).cloned()
self.state.aggregate_functions.get(name).cloned()
}

fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
Expand All @@ -218,7 +222,7 @@ impl ContextProvider for MockContextProvider {
}

fn options(&self) -> &ConfigOptions {
&self.options
&self.state.config_options
}

fn get_file_type(
Expand All @@ -237,19 +241,19 @@ impl ContextProvider for MockContextProvider {
}

fn udf_names(&self) -> Vec<String> {
self.udfs.keys().cloned().collect()
self.state.scalar_functions.keys().cloned().collect()
}

fn udaf_names(&self) -> Vec<String> {
self.udafs.keys().cloned().collect()
self.state.aggregate_functions.keys().cloned().collect()
}

fn udwf_names(&self) -> Vec<String> {
Vec::new()
}

fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
&self.expr_planners
&self.state.expr_planners
}
}

Expand Down
52 changes: 30 additions & 22 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use datafusion_sql::{
planner::{ParserOptions, SqlToRel},
};

use crate::common::MockSessionState;
use datafusion_functions::core::planner::CoreFunctionPlanner;
use datafusion_functions_aggregate::{
approx_median::approx_median_udaf, count::count_udaf, min_max::max_udaf,
Expand Down Expand Up @@ -1495,8 +1496,9 @@ fn recursive_ctes_disabled() {
select * from numbers;";

// manually setting up test here so that we can disable recursive ctes
let mut context = MockContextProvider::default();
context.options_mut().execution.enable_recursive_ctes = false;
let mut state = MockSessionState::default();
state.config_options.execution.enable_recursive_ctes = false;
let context = MockContextProvider { state };

let planner = SqlToRel::new_with_options(&context, ParserOptions::default());
let result = DFParser::parse_sql_with_dialect(sql, &GenericDialect {});
Expand Down Expand Up @@ -2727,7 +2729,8 @@ fn logical_plan_with_options(sql: &str, options: ParserOptions) -> Result<Logica
}

fn logical_plan_with_dialect(sql: &str, dialect: &dyn Dialect) -> Result<LogicalPlan> {
let context = MockContextProvider::default().with_udaf(sum_udaf());
let state = MockSessionState::default().with_aggregate_function(sum_udaf());
let context = MockContextProvider { state };
let planner = SqlToRel::new(&context);
let result = DFParser::parse_sql_with_dialect(sql, dialect);
let mut ast = result?;
Expand All @@ -2739,39 +2742,44 @@ fn logical_plan_with_dialect_and_options(
dialect: &dyn Dialect,
options: ParserOptions,
) -> Result<LogicalPlan> {
let context = MockContextProvider::default()
.with_udf(unicode::character_length().as_ref().clone())
.with_udf(string::concat().as_ref().clone())
.with_udf(make_udf(
let state = MockSessionState::default()
.with_scalar_function(Arc::new(unicode::character_length().as_ref().clone()))
.with_scalar_function(Arc::new(string::concat().as_ref().clone()))
.with_scalar_function(Arc::new(make_udf(
"nullif",
vec![DataType::Int32, DataType::Int32],
DataType::Int32,
))
.with_udf(make_udf(
)))
.with_scalar_function(Arc::new(make_udf(
"round",
vec![DataType::Float64, DataType::Int64],
DataType::Float32,
))
.with_udf(make_udf(
)))
.with_scalar_function(Arc::new(make_udf(
"arrow_cast",
vec![DataType::Int64, DataType::Utf8],
DataType::Float64,
))
.with_udf(make_udf(
)))
.with_scalar_function(Arc::new(make_udf(
"date_trunc",
vec![DataType::Utf8, DataType::Timestamp(Nanosecond, None)],
DataType::Int32,
))
.with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64))
.with_udaf(sum_udaf())
.with_udaf(approx_median_udaf())
.with_udaf(count_udaf())
.with_udaf(avg_udaf())
.with_udaf(min_udaf())
.with_udaf(max_udaf())
.with_udaf(grouping_udaf())
)))
.with_scalar_function(Arc::new(make_udf(
"sqrt",
vec![DataType::Int64],
DataType::Int64,
)))
.with_aggregate_function(sum_udaf())
.with_aggregate_function(approx_median_udaf())
.with_aggregate_function(count_udaf())
.with_aggregate_function(avg_udaf())
.with_aggregate_function(min_udaf())
.with_aggregate_function(max_udaf())
.with_aggregate_function(grouping_udaf())
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));

let context = MockContextProvider { state };
let planner = SqlToRel::new_with_options(&context, options);
let result = DFParser::parse_sql_with_dialect(sql, dialect);
let mut ast = result?;
Expand Down

0 comments on commit 18193e6

Please sign in to comment.