Skip to content

Commit

Permalink
refactor: SessionBuilder to return Result<_> (#1138)
Browse files Browse the repository at this point in the history
* refactor: SessionBuilder to return Result<_>

* Update ballista/core/src/utils.rs

Co-authored-by: Andy Grove <[email protected]>

---------

Co-authored-by: Andy Grove <[email protected]>
  • Loading branch information
milenkovicm and andygrove authored Nov 27, 2024
1 parent 683dede commit 020d29d
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 27 deletions.
10 changes: 6 additions & 4 deletions ballista/core/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ use tonic::codegen::StdError;
use tonic::transport::{Channel, Error, Server};

/// Default session builder using the provided configuration
pub fn default_session_builder(config: SessionConfig) -> SessionState {
SessionStateBuilder::new()
pub fn default_session_builder(
config: SessionConfig,
) -> datafusion::common::Result<SessionState> {
Ok(SessionStateBuilder::new()
.with_default_features()
.with_config(config)
.with_runtime_env(Arc::new(RuntimeEnv::new(RuntimeConfig::default()).unwrap()))
.build()
.with_runtime_env(Arc::new(RuntimeEnv::new(RuntimeConfig::default())?))
.build())
}

pub fn default_config_producer() -> SessionConfig {
Expand Down
4 changes: 2 additions & 2 deletions ballista/scheduler/src/cluster/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ impl JobState for InMemoryJobState {
&self,
config: &SessionConfig,
) -> Result<Arc<SessionContext>> {
let session = create_datafusion_context(config, self.session_builder.clone());
let session = create_datafusion_context(config, self.session_builder.clone())?;
self.sessions.insert(session.session_id(), session.clone());

Ok(session)
Expand All @@ -419,7 +419,7 @@ impl JobState for InMemoryJobState {
session_id: &str,
config: &SessionConfig,
) -> Result<Arc<SessionContext>> {
let session = create_datafusion_context(config, self.session_builder.clone());
let session = create_datafusion_context(config, self.session_builder.clone())?;
self.sessions
.insert(session_id.to_string(), session.clone());

Expand Down
3 changes: 2 additions & 1 deletion ballista/scheduler/src/scheduler_server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ mod external_scaler;
mod grpc;
pub(crate) mod query_stage_scheduler;

pub type SessionBuilder = Arc<dyn Fn(SessionConfig) -> SessionState + Send + Sync>;
pub type SessionBuilder =
Arc<dyn Fn(SessionConfig) -> datafusion::common::Result<SessionState> + Send + Sync>;

#[derive(Clone)]
pub struct SchedulerServer<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
Expand Down
8 changes: 5 additions & 3 deletions ballista/scheduler/src/standalone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ pub async fn new_standalone_scheduler_from_state(
let session_config = session_state.config().clone();
let session_state = session_state.clone();
let session_builder = Arc::new(move |c: SessionConfig| {
SessionStateBuilder::new_from_existing(session_state.clone())
.with_config(c)
.build()
Ok(
SessionStateBuilder::new_from_existing(session_state.clone())
.with_config(c)
.build(),
)
});

let config_producer = Arc::new(move || session_config.clone());
Expand Down
8 changes: 4 additions & 4 deletions ballista/scheduler/src/state/session_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,18 @@ impl SessionManager {
pub fn create_datafusion_context(
session_config: &SessionConfig,
session_builder: SessionBuilder,
) -> Arc<SessionContext> {
) -> datafusion::common::Result<Arc<SessionContext>> {
let session_state = if session_config.round_robin_repartition() {
let session_config = session_config
.clone()
// should we disable catalog on the scheduler side
.with_round_robin_repartition(false);

log::warn!("session manager will override `datafusion.optimizer.enable_round_robin_repartition` to `false` ");
session_builder(session_config)
session_builder(session_config)?
} else {
session_builder(session_config.clone())
session_builder(session_config.clone())?
};

Arc::new(SessionContext::new_with_state(session_state))
Ok(Arc::new(SessionContext::new_with_state(session_state)))
}
2 changes: 1 addition & 1 deletion examples/examples/custom-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async fn main() -> Result<()> {

// new sessions state with required custom session configuration and runtime environment
let state =
custom_session_state_with_s3_support(custom_session_config_with_s3_options());
custom_session_state_with_s3_support(custom_session_config_with_s3_options())?;

let ctx: SessionContext =
SessionContext::remote_with_state("df://localhost:50050", state).await?;
Expand Down
8 changes: 4 additions & 4 deletions examples/src/object_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ pub fn custom_runtime_env_with_s3_support(
/// and [RuntimeEnv].
pub fn custom_session_state_with_s3_support(
session_config: SessionConfig,
) -> SessionState {
let runtime_env = custom_runtime_env_with_s3_support(&session_config).unwrap();
) -> datafusion::common::Result<SessionState> {
let runtime_env = custom_runtime_env_with_s3_support(&session_config)?;

SessionStateBuilder::new()
Ok(SessionStateBuilder::new()
.with_runtime_env(runtime_env)
.with_config(session_config)
.build()
.build())
}

/// Custom [ObjectStoreRegistry] which will create
Expand Down
17 changes: 9 additions & 8 deletions examples/tests/object_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ mod custom_s3_config {
// object store registry.

let session_builder = Arc::new(produce_state);
let state = session_builder(config_producer());
let state = session_builder(config_producer())?;

// setting up ballista cluster with new runtime, configuration, and session state producers
let (host, port) = crate::common::setup_test_cluster_with_builders(
Expand Down Expand Up @@ -416,7 +416,7 @@ mod custom_s3_config {
// object store registry.

let session_builder = Arc::new(produce_state);
let state = session_builder(config_producer());
let state = session_builder(config_producer())?;

// // establishing cluster connection,
let ctx: SessionContext = SessionContext::standalone_with_state(state).await?;
Expand Down Expand Up @@ -480,24 +480,25 @@ mod custom_s3_config {
Ok(())
}

fn produce_state(session_config: SessionConfig) -> SessionState {
fn produce_state(
session_config: SessionConfig,
) -> datafusion::common::Result<SessionState> {
let s3options = session_config
.options()
.extensions
.get::<S3Options>()
.ok_or(DataFusionError::Configuration(
"S3 Options not set".to_string(),
))
.unwrap();
))?;

let config = RuntimeConfig::new().with_object_store_registry(Arc::new(
CustomObjectStoreRegistry::new(s3options.clone()),
));
let runtime_env = RuntimeEnv::new(config).unwrap();
let runtime_env = RuntimeEnv::new(config)?;

SessionStateBuilder::new()
Ok(SessionStateBuilder::new()
.with_runtime_env(runtime_env.into())
.with_config(session_config)
.build()
.build())
}
}

0 comments on commit 020d29d

Please sign in to comment.