From afa58b69b79aa26fa764b903bb9c77212e4fdb47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Milenkovi=C4=87?= Date: Thu, 21 Nov 2024 18:01:12 +0000 Subject: [PATCH] Move BallistaRegistry to appropriate location During refactoring `BallistaRegistry` ended up at very strange location. Proposal moves registry to to more appropriate location. --- ballista/core/src/lib.rs | 1 + ballista/core/src/registry.rs | 112 ++++++++++++++++++++++ ballista/core/src/serde/scheduler/mod.rs | 103 +------------------- ballista/executor/src/executor.rs | 2 +- ballista/executor/src/executor_process.rs | 2 +- ballista/executor/src/standalone.rs | 2 +- 6 files changed, 120 insertions(+), 102 deletions(-) create mode 100644 ballista/core/src/registry.rs diff --git a/ballista/core/src/lib.rs b/ballista/core/src/lib.rs index 4341f443a..f415af70e 100644 --- a/ballista/core/src/lib.rs +++ b/ballista/core/src/lib.rs @@ -32,6 +32,7 @@ pub mod consistent_hash; pub mod error; pub mod event_loop; pub mod execution_plans; +pub mod registry; pub mod utils; #[macro_use] diff --git a/ballista/core/src/registry.rs b/ballista/core/src/registry.rs new file mode 100644 index 000000000..2f55e2809 --- /dev/null +++ b/ballista/core/src/registry.rs @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::DataFusionError; +use datafusion::execution::{FunctionRegistry, SessionState}; +use datafusion::functions::all_default_functions; +use datafusion::functions_aggregate::all_default_aggregate_functions; +use datafusion::functions_window::all_default_window_functions; +use datafusion::logical_expr::planner::ExprPlanner; +use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +#[derive(Debug)] +pub struct BallistaFunctionRegistry { + pub scalar_functions: HashMap>, + pub aggregate_functions: HashMap>, + pub window_functions: HashMap>, +} + +impl Default for BallistaFunctionRegistry { + fn default() -> Self { + let scalar_functions = all_default_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(); + + let aggregate_functions = all_default_aggregate_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(); + + let window_functions = all_default_window_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(); + + Self { + scalar_functions, + aggregate_functions, + window_functions, + } + } +} + +impl FunctionRegistry for BallistaFunctionRegistry { + fn expr_planners(&self) -> Vec> { + vec![] + } + + fn udfs(&self) -> HashSet { + self.scalar_functions.keys().cloned().collect() + } + + fn udf(&self, name: &str) -> datafusion::common::Result> { + let result = self.scalar_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Internal(format!( + "There is no UDF named \"{name}\" in the TaskContext" + )) + }) + } + + fn udaf(&self, name: &str) -> datafusion::common::Result> { + let result = self.aggregate_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Internal(format!( + "There is no UDAF named \"{name}\" in the TaskContext" + )) + }) + } + + fn udwf(&self, name: &str) -> datafusion::common::Result> { + let result = self.window_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Internal(format!( + "There is no UDWF named \"{name}\" in the TaskContext" + )) + }) + } +} + +impl From<&SessionState> for BallistaFunctionRegistry { + fn from(state: &SessionState) -> Self { + let scalar_functions = state.scalar_functions().clone(); + let aggregate_functions = state.aggregate_functions().clone(); + let window_functions = state.window_functions().clone(); + + Self { + scalar_functions, + aggregate_functions, + window_functions, + } + } +} diff --git a/ballista/core/src/serde/scheduler/mod.rs b/ballista/core/src/serde/scheduler/mod.rs index 2905455eb..a2c92ff8a 100644 --- a/ballista/core/src/serde/scheduler/mod.rs +++ b/ballista/core/src/serde/scheduler/mod.rs @@ -15,27 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; -use std::fmt::Debug; -use std::{collections::HashMap, fmt, sync::Arc}; - +use crate::error::BallistaError; +use crate::registry::BallistaFunctionRegistry; use datafusion::arrow::array::{ ArrayBuilder, StructArray, StructBuilder, UInt64Array, UInt64Builder, }; use datafusion::arrow::datatypes::{DataType, Field}; -use datafusion::common::DataFusionError; -use datafusion::execution::{FunctionRegistry, SessionState}; -use datafusion::functions::all_default_functions; -use datafusion::functions_aggregate::all_default_aggregate_functions; -use datafusion::functions_window::all_default_window_functions; -use datafusion::logical_expr::planner::ExprPlanner; -use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::Partitioning; use datafusion::prelude::SessionConfig; use serde::Serialize; - -use crate::error::BallistaError; +use std::fmt::Debug; +use std::{collections::HashMap, fmt, sync::Arc}; pub mod from_proto; pub mod to_proto; @@ -295,89 +286,3 @@ pub struct TaskDefinition { pub session_config: SessionConfig, pub function_registry: Arc, } - -#[derive(Debug)] -pub struct BallistaFunctionRegistry { - pub scalar_functions: HashMap>, - pub aggregate_functions: HashMap>, - pub window_functions: HashMap>, -} - -impl Default for BallistaFunctionRegistry { - fn default() -> Self { - let scalar_functions = all_default_functions() - .into_iter() - .map(|f| (f.name().to_string(), f)) - .collect(); - - let aggregate_functions = all_default_aggregate_functions() - .into_iter() - .map(|f| (f.name().to_string(), f)) - .collect(); - - let window_functions = all_default_window_functions() - .into_iter() - .map(|f| (f.name().to_string(), f)) - .collect(); - - Self { - scalar_functions, - aggregate_functions, - window_functions, - } - } -} - -impl FunctionRegistry for BallistaFunctionRegistry { - fn expr_planners(&self) -> Vec> { - vec![] - } - - fn udfs(&self) -> HashSet { - self.scalar_functions.keys().cloned().collect() - } - - fn udf(&self, name: &str) -> datafusion::common::Result> { - let result = self.scalar_functions.get(name); - - result.cloned().ok_or_else(|| { - DataFusionError::Internal(format!( - "There is no UDF named \"{name}\" in the TaskContext" - )) - }) - } - - fn udaf(&self, name: &str) -> datafusion::common::Result> { - let result = self.aggregate_functions.get(name); - - result.cloned().ok_or_else(|| { - DataFusionError::Internal(format!( - "There is no UDAF named \"{name}\" in the TaskContext" - )) - }) - } - - fn udwf(&self, name: &str) -> datafusion::common::Result> { - let result = self.window_functions.get(name); - - result.cloned().ok_or_else(|| { - DataFusionError::Internal(format!( - "There is no UDWF named \"{name}\" in the TaskContext" - )) - }) - } -} - -impl From<&SessionState> for BallistaFunctionRegistry { - fn from(state: &SessionState) -> Self { - let scalar_functions = state.scalar_functions().clone(); - let aggregate_functions = state.aggregate_functions().clone(); - let window_functions = state.window_functions().clone(); - - Self { - scalar_functions, - aggregate_functions, - window_functions, - } - } -} diff --git a/ballista/executor/src/executor.rs b/ballista/executor/src/executor.rs index d9246bfe9..addccf7a8 100644 --- a/ballista/executor/src/executor.rs +++ b/ballista/executor/src/executor.rs @@ -23,9 +23,9 @@ use crate::execution_engine::QueryStageExecutor; use crate::metrics::ExecutorMetricsCollector; use crate::metrics::LoggingMetricsCollector; use ballista_core::error::BallistaError; +use ballista_core::registry::BallistaFunctionRegistry; use ballista_core::serde::protobuf; use ballista_core::serde::protobuf::ExecutorRegistration; -use ballista_core::serde::scheduler::BallistaFunctionRegistry; use ballista_core::serde::scheduler::PartitionId; use ballista_core::ConfigProducer; use ballista_core::RuntimeProducer; diff --git a/ballista/executor/src/executor_process.rs b/ballista/executor/src/executor_process.rs index 9a6187bda..f3070041e 100644 --- a/ballista/executor/src/executor_process.rs +++ b/ballista/executor/src/executor_process.rs @@ -25,7 +25,7 @@ use std::{env, io}; use anyhow::{Context, Result}; use arrow_flight::flight_service_server::FlightServiceServer; -use ballista_core::serde::scheduler::BallistaFunctionRegistry; +use ballista_core::registry::BallistaFunctionRegistry; use datafusion_proto::logical_plan::LogicalExtensionCodec; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use futures::stream::FuturesUnordered; diff --git a/ballista/executor/src/standalone.rs b/ballista/executor/src/standalone.rs index ac67a5a2b..03c8a3ce1 100644 --- a/ballista/executor/src/standalone.rs +++ b/ballista/executor/src/standalone.rs @@ -19,7 +19,7 @@ use crate::metrics::LoggingMetricsCollector; use crate::{execution_loop, executor::Executor, flight_service::BallistaFlightService}; use arrow_flight::flight_service_server::FlightServiceServer; use ballista_core::config::BallistaConfig; -use ballista_core::serde::scheduler::BallistaFunctionRegistry; +use ballista_core::registry::BallistaFunctionRegistry; use ballista_core::utils::{default_config_producer, SessionConfigExt}; use ballista_core::{ error::Result,