Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Function dot syntax #3369

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions crates/rayexec_execution/src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct SessionConfig {
pub partitions: u64,
pub batch_size: u64,
pub verify_optimized_plan: bool,
pub enable_function_chaining: bool,
}

impl SessionConfig {
Expand All @@ -30,6 +31,7 @@ impl SessionConfig {
partitions: executor.default_partitions() as u64,
batch_size: 4096,
verify_optimized_plan: false,
enable_function_chaining: true,
}
}

Expand Down Expand Up @@ -103,6 +105,7 @@ static GET_SET_FUNCTIONS: LazyLock<HashMap<&'static str, SettingFunctions>> = La
insert_setting::<AllowNestedLoopJoin>(&mut map);
insert_setting::<Partitions>(&mut map);
insert_setting::<BatchSize>(&mut map);
insert_setting::<EnableFunctionChaining>(&mut map);

map
});
Expand Down Expand Up @@ -218,6 +221,23 @@ impl SessionSetting for VerifyOptimizedPlan {
}
}

/// Session setting marker type for the `enable_function_chaining` option.
///
/// Backs the `enable_function_chaining` field on `SessionConfig` through its
/// `SessionSetting` implementation.
pub struct EnableFunctionChaining;

impl SessionSetting for EnableFunctionChaining {
    const NAME: &'static str = "enable_function_chaining";
    const DESCRIPTION: &'static str = "If function chaining syntax is enabled.";

    /// Stores the boolean carried by `scalar` into
    /// `conf.enable_function_chaining`.
    ///
    /// Any error returned by `try_as_bool` is propagated and the config is
    /// left untouched.
    fn set_from_scalar(scalar: ScalarValue, conf: &mut SessionConfig) -> Result<()> {
        conf.enable_function_chaining = scalar.try_as_bool()?;
        Ok(())
    }

    /// Returns the current value of `conf.enable_function_chaining` as a
    /// scalar.
    fn get_as_scalar(conf: &SessionConfig) -> OwnedScalarValue {
        conf.enable_function_chaining.into()
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -230,6 +250,7 @@ mod tests {
partitions: 8,
batch_size: 4096,
verify_optimized_plan: false,
enable_function_chaining: true,
}
}

Expand Down
5 changes: 4 additions & 1 deletion crates/rayexec_execution/src/engine/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::logical::logical_set::VariableOrAll;
use crate::logical::operator::{LogicalOperator, Node};
use crate::logical::planner::plan_statement::StatementPlanner;
use crate::logical::resolver::resolve_context::ResolveContext;
use crate::logical::resolver::{ResolveMode, ResolvedStatement, Resolver};
use crate::logical::resolver::{ResolveConfig, ResolveMode, ResolvedStatement, Resolver};
use crate::optimizer::Optimizer;
use crate::runtime::time::Timer;
use crate::runtime::{PipelineExecutor, Runtime};
Expand Down Expand Up @@ -243,6 +243,9 @@ where
&tx,
&self.context,
self.registry.get_file_handlers(),
ResolveConfig {
enable_function_chaining: self.config.enable_function_chaining,
},
)
.resolve_statement(stmt.statement.clone())
.await?;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use std::collections::HashMap;

use rayexec_error::Result;
use rayexec_error::{RayexecError, Result};
use rayexec_parser::ast;

use super::select_expr_expander::ExpandedSelectExpr;
use super::select_list::SelectList;
use crate::expr::column_expr::ColumnExpr;
use crate::expr::Expression;
use crate::logical::binder::bind_context::{BindContext, BindScopeRef};
use crate::logical::binder::column_binder::DefaultColumnBinder;
use crate::logical::binder::column_binder::{DefaultColumnBinder, ExpressionColumnBinder};
use crate::logical::binder::expr_binder::{BaseExpressionBinder, RecursionContext};
use crate::logical::binder::table_list::TableRef;
use crate::logical::resolver::resolve_context::ResolveContext;
use crate::logical::resolver::ResolvedMeta;

#[derive(Debug)]
pub struct SelectListBinder<'a> {
Expand Down Expand Up @@ -77,13 +78,19 @@ impl<'a> SelectListBinder<'a> {
// Bind the expressions.
let expr_binder = BaseExpressionBinder::new(self.current, self.resolve_context);
let mut exprs = Vec::with_capacity(projections.len());
for proj in projections {
for (idx, proj) in projections.into_iter().enumerate() {
match proj {
ExpandedSelectExpr::Expr { expr, .. } => {
let mut col_binder = SelectAliasColumnBinder {
current_idx: idx,
alias_map: &alias_map,
previous_exprs: &exprs,
};

let expr = expr_binder.bind_expression(
bind_context,
&expr,
&mut DefaultColumnBinder,
&mut col_binder,
RecursionContext {
allow_windows: true,
allow_aggregates: true,
Expand Down Expand Up @@ -242,3 +249,76 @@ impl<'a> SelectListBinder<'a> {
Ok(())
}
}

/// Column binder that allows binding to previously defined user aliases in
/// the SELECT list.
///
/// For a single identifier, normal column binding (via `DefaultColumnBinder`)
/// is attempted first. Aliases are only checked if normal column binding
/// cannot find a column. If the identifier is not a known alias either, no
/// expression is produced (`Ok(None)`).
///
/// An alias may only be referenced by expressions that come after it in the
/// select list; referencing an alias at or after the current position is an
/// error.
///
/// Root literals and compound identifiers are always delegated to
/// `DefaultColumnBinder`.
#[derive(Debug, Clone, Copy)]
struct SelectAliasColumnBinder<'a> {
/// Index of the expression we're currently planning in the select list.
///
/// Used to determine if an alias is valid to use: only aliases whose index
/// is strictly less than this value may be referenced.
current_idx: usize,
/// User provided aliases, keyed by normalized alias name. The value is used
/// as an index into `previous_exprs`.
alias_map: &'a HashMap<String, usize>,
/// Previously planned expressions, indexed by their position in the select
/// list.
previous_exprs: &'a [Expression],
}

impl ExpressionColumnBinder for SelectAliasColumnBinder<'_> {
    /// Root literals never refer to aliases; defer to the default binder.
    fn bind_from_root_literal(
        &mut self,
        bind_scope: BindScopeRef,
        bind_context: &mut BindContext,
        literal: &ast::Literal<ResolvedMeta>,
    ) -> Result<Option<Expression>> {
        DefaultColumnBinder.bind_from_root_literal(bind_scope, bind_context, literal)
    }

    /// Binds a single identifier, falling back to select list aliases.
    ///
    /// Resolution order:
    /// 1. Normal column binding through `DefaultColumnBinder::bind_column`.
    /// 2. A user alias defined earlier in the select list, in which case a
    ///    clone of the already planned expression is returned.
    ///
    /// Returns `Ok(None)` if the identifier is neither a column nor an alias,
    /// and an error if the alias is not defined before the current
    /// expression.
    fn bind_from_ident(
        &mut self,
        bind_scope: BindScopeRef,
        bind_context: &mut BindContext,
        ident: &ast::Ident,
        _recur: RecursionContext,
    ) -> Result<Option<Expression>> {
        let col = ident.as_normalized_string();

        // Real columns always win over aliases.
        let bound = DefaultColumnBinder.bind_column(bind_scope, bind_context, None, &col)?;
        if bound.is_some() {
            return Ok(bound);
        }

        // Not a column, see if the user aliased something with this name.
        let Some(&col_idx) = self.alias_map.get(&col) else {
            return Ok(None);
        };

        // Aliases can only be used by expressions that come after them.
        if col_idx >= self.current_idx {
            return Err(RayexecError::new(format!("'{col}' can only be referenced after it's been defined in the SELECT list")));
        }

        // Valid alias reference, reuse the expression we already planned.
        let Some(aliased_expr) = self.previous_exprs.get(col_idx) else {
            return Err(RayexecError::new("Missing select expression?").with_field("idx", col_idx));
        };

        Ok(Some(aliased_expr.clone()))
    }

    /// Compound identifiers never refer to aliases; defer to the default
    /// binder.
    fn bind_from_idents(
        &mut self,
        bind_scope: BindScopeRef,
        bind_context: &mut BindContext,
        idents: &[ast::Ident],
        recur: RecursionContext,
    ) -> Result<Option<Expression>> {
        DefaultColumnBinder.bind_from_idents(bind_scope, bind_context, idents, recur)
    }
}
6 changes: 5 additions & 1 deletion crates/rayexec_execution/src/logical/binder/column_binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ impl ExpressionColumnBinder for DefaultColumnBinder {
}

impl DefaultColumnBinder {
fn bind_column(
/// Binds a column with the given name and optional table alias.
///
/// This will handle appending correlated columns to the bind context as
/// necessary.
pub fn bind_column(
&self,
bind_scope: BindScopeRef,
bind_context: &mut BindContext,
Expand Down
102 changes: 83 additions & 19 deletions crates/rayexec_execution/src/logical/resolver/expr_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::database::catalog_entry::CatalogEntryType;
use crate::logical::binder::expr_binder::BaseExpressionBinder;
use crate::logical::operator::LocationRequirement;

#[derive(Debug)]
pub struct ExpressionResolver<'a> {
resolver: &'a Resolver<'a>,
}
Expand Down Expand Up @@ -237,7 +238,10 @@ impl<'a> ExpressionResolver<'a> {
op,
right: Box::new(Box::pin(self.resolve_expression(*right, resolve_context)).await?),
}),
ast::Expr::Function(func) => self.resolve_function(func, resolve_context).await,
ast::Expr::Function(func) => {
self.resolve_scalar_or_aggregate_function(func, resolve_context)
.await
}
ast::Expr::Subquery(subquery) => self.resolve_subquery(subquery, resolve_context).await,
ast::Expr::Exists {
subquery,
Expand Down Expand Up @@ -556,20 +560,82 @@ impl<'a> ExpressionResolver<'a> {
}
}

async fn resolve_function(
async fn resolve_scalar_or_aggregate_function(
&self,
func: Box<ast::Function<Raw>>,
mut func: Box<ast::Function<Raw>>,
resolve_context: &mut ResolveContext,
) -> Result<ast::Expr<ResolvedMeta>> {
// TODO: Search path (with system being the first to check)
if func.reference.0.len() != 1 {
return Err(RayexecError::new(
"Qualified function names not yet supported",
));
let (catalog, schema, func_name) = match func.reference.0.len() {
0 => return Err(RayexecError::new("Missing idents for function reference")), // Shouldn't happen.
1 => (
"system".to_string(),
"glare_catalog".to_string(),
func.reference.0[0].as_normalized_string(),
),
2 => (
"system".to_string(),
func.reference.0[0].as_normalized_string(),
func.reference.0[1].as_normalized_string(),
),
3 => (
func.reference.0[0].as_normalized_string(),
func.reference.0[1].as_normalized_string(),
func.reference.0[2].as_normalized_string(),
),
_ => {
// TODO: This could technically be from chained syntax on a
// fully qualified column.
return Err(RayexecError::new("Too many idents for function reference")
.with_field("idents", func.reference.to_string()));
}
};

let context = self.resolver.context;

// See if we can resolve the catalog & schema. If we can't, assume we're
// using chained function syntax.
//
// TODO: Make `get_database` return Option.
// TODO: We should be exhaustive about what's part of the qualified
// function call vs what's part of the column.
let is_qualified = func.reference.0.len() > 1;
if self.resolver.config.enable_function_chaining
&& is_qualified
&& (!context.database_exists(&catalog)
|| context
.get_database(&catalog)?
.catalog
.get_schema(self.resolver.tx, &schema)?
.is_none())
{
let unqualified_name = func.reference.0.pop().unwrap(); // Length checked above.
let unqualified_ref = ast::ObjectReference(vec![unqualified_name]);

let mut prefix_ref = std::mem::replace(&mut func.reference, unqualified_ref);

// Now add the prefix we took from the reference as the first
// argument to the function.

// TODO: Expr binder should probably take care of this for us.
let arg_expr = match prefix_ref.0.len() {
1 => ast::Expr::Ident(prefix_ref.0.pop().unwrap()),
_ => ast::Expr::CompoundIdent(prefix_ref.0),
};

func.args.insert(
0,
ast::FunctionArg::Unnamed {
arg: ast::FunctionArgExpr::Expr(arg_expr),
},
);

// Now try to resolve with just the unqualified reference.
let resolved =
Box::pin(self.resolve_scalar_or_aggregate_function(func, resolve_context)).await?;

return Ok(resolved);
}
let func_name = &func.reference.0[0].as_normalized_string();
let catalog = "system";
let schema = "glare_catalog";

let filter = self
.resolve_optional_expression(func.filter.map(|e| *e), resolve_context)
Expand All @@ -582,16 +648,14 @@ impl<'a> ExpressionResolver<'a> {
};
let args = Box::pin(self.resolve_function_args(func.args, resolve_context)).await?;

let schema_ent = self
.resolver
.context
.get_database(catalog)?
let schema_ent = context
.get_database(&catalog)?
.catalog
.get_schema(self.resolver.tx, schema)?
.get_schema(self.resolver.tx, &schema)?
.ok_or_else(|| RayexecError::new(format!("Missing schema: {schema}")))?;

// Check if this is a special function.
if let Some(special) = SpecialBuiltinFunction::try_from_name(func_name) {
if let Some(special) = SpecialBuiltinFunction::try_from_name(&func_name) {
let resolve_idx = resolve_context
.functions
.push_resolved(ResolvedFunction::Special(special), LocationRequirement::Any);
Expand All @@ -606,7 +670,7 @@ impl<'a> ExpressionResolver<'a> {
}

// Now check scalars.
if let Some(scalar) = schema_ent.get_scalar_function(self.resolver.tx, func_name)? {
if let Some(scalar) = schema_ent.get_scalar_function(self.resolver.tx, &func_name)? {
// TODO: Allow unresolved scalars?
// TODO: This also assumes scalars (and aggs) are the same everywhere, which
// they probably should be for now.
Expand All @@ -624,7 +688,7 @@ impl<'a> ExpressionResolver<'a> {
}

// Now check aggregates.
if let Some(aggregate) = schema_ent.get_aggregate_function(self.resolver.tx, func_name)? {
if let Some(aggregate) = schema_ent.get_aggregate_function(self.resolver.tx, &func_name)? {
// TODO: Allow unresolved aggregates?
let resolve_idx = resolve_context.functions.push_resolved(
ResolvedFunction::Aggregate(
Expand All @@ -651,7 +715,7 @@ impl<'a> ExpressionResolver<'a> {
CatalogEntryType::ScalarFunction,
CatalogEntryType::AggregateFunction,
],
func_name,
&func_name,
))
}

Expand Down
Loading
Loading