diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs
index c4c74886bd..dbee635910 100644
--- a/crates/core/src/delta_datafusion/mod.rs
+++ b/crates/core/src/delta_datafusion/mod.rs
@@ -48,18 +48,23 @@ use datafusion::execution::context::{SessionConfig, SessionContext, SessionState
 use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::execution::FunctionRegistry;
 use datafusion::physical_optimizer::pruning::PruningPredicate;
+use datafusion_common::project_schema;
 use datafusion_common::scalar::ScalarValue;
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
 use datafusion_common::{
     config::ConfigOptions, Column, DFSchema, DataFusionError, Result as DataFusionResult,
     TableReference, ToDFSchema,
 };
+use datafusion_expr::execution_props::ExecutionProps;
 use datafusion_expr::logical_plan::CreateExternalTable;
 use datafusion_expr::utils::conjunction;
 use datafusion_expr::{col, Expr, Extension, LogicalPlan, TableProviderFilterPushDown, Volatility};
+use datafusion_physical_expr::{create_physical_expr, create_physical_exprs, PhysicalExpr};
 use datafusion_physical_plan::filter::FilterExec;
-use datafusion_physical_plan::limit::LocalLimitExec;
+use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
+use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
 use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
+use datafusion_physical_plan::projection::ProjectionExec;
 use datafusion_physical_plan::{
     DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
     Statistics,
@@ -71,7 +76,9 @@ use either::Either;
 use futures::TryStreamExt;
 use itertools::Itertools;
 use object_store::ObjectMeta;
+use parking_lot::RwLock;
 use serde::{Deserialize, Serialize};
+
 use url::Url;
 
 use crate::delta_datafusion::expr::parse_predicate_expression;
@@ -839,6 +846,107 @@ impl TableProvider for DeltaTableProvider {
     }
 }
 
+#[derive(Debug)]
+pub struct LazyTableProvider {
+    schema: Arc<ArrowSchema>,
+    batches: Vec<Arc<RwLock<dyn LazyBatchGenerator>>>,
+}
+
+impl LazyTableProvider {
+    /// Build a LazyTableProvider
+    pub fn try_new(
+        schema: Arc<ArrowSchema>,
+        batches: Vec<Arc<RwLock<dyn LazyBatchGenerator>>>,
+    ) -> DeltaResult<Self> {
+        Ok(LazyTableProvider { schema, batches })
+    }
+}
+
+#[async_trait]
+impl TableProvider for LazyTableProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> Arc<ArrowSchema> {
+        self.schema.clone()
+    }
+
+    fn table_type(&self) -> TableType {
+        TableType::Base
+    }
+
+    fn get_table_definition(&self) -> Option<&str> {
+        None
+    }
+
+    fn get_logical_plan(&self) -> Option<Cow<'_, LogicalPlan>> {
+        None
+    }
+
+    async fn scan(
+        &self,
+        session: &dyn Session,
+        projection: Option<&Vec<usize>>,
+        filters: &[Expr],
+        limit: Option<usize>,
+    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
+        let mut plan: Arc<dyn ExecutionPlan> = Arc::new(LazyMemoryExec::try_new(
+            self.schema(),
+            self.batches.clone(),
+        )?);
+
+        let df_schema: DFSchema = plan.schema().try_into()?;
+
+        if let Some(filter_expr) = conjunction(filters.iter().cloned()) {
+            let physical_expr =
+                create_physical_expr(&filter_expr, &df_schema, &ExecutionProps::new())?;
+            plan = Arc::new(FilterExec::try_new(physical_expr, plan)?);
+        }
+
+        if let Some(projection) = projection {
+            let current_projection = (0..plan.schema().fields().len()).collect::<Vec<usize>>();
+            if projection != &current_projection {
+                let execution_props = &ExecutionProps::new();
+                let fields: DeltaResult<Vec<(Arc<dyn PhysicalExpr>, String)>> = projection
+                    .into_iter()
+                    .map(|i| {
+                        let (table_ref, field) = df_schema.qualified_field(*i);
+                        create_physical_expr(
&Expr::Column(Column::from((table_ref, field))), + &df_schema, + execution_props, + ) + .map(|expr| (expr, field.name().clone())) + .map_err(DeltaTableError::from) + }) + .collect(); + plan = Arc::new(ProjectionExec::try_new(fields?, plan)?); + } + } + + if let Some(limit) = limit { + plan = Arc::new(GlobalLimitExec::new(plan, 0, Some(limit))) + }; + + Ok(plan) + } + + fn supports_filters_pushdown( + &self, + filter: &[&Expr], + ) -> DataFusionResult> { + Ok(filter + .iter() + .map(|_| TableProviderFilterPushDown::Inexact) + .collect()) + } + + fn statistics(&self) -> Option { + None + } +} + // TODO: this will likely also need to perform column mapping later when we support reader protocol v2 /// A wrapper for parquet scans #[derive(Debug)] diff --git a/crates/core/src/operations/merge/filter.rs b/crates/core/src/operations/merge/filter.rs index 602df519a1..39d8e1586d 100644 --- a/crates/core/src/operations/merge/filter.rs +++ b/crates/core/src/operations/merge/filter.rs @@ -182,36 +182,39 @@ pub(crate) fn generalize_filter( source_name: &TableReference, target_name: &TableReference, placeholders: &mut Vec, + streaming_source: bool, ) -> Option { match predicate { Expr::BinaryExpr(binary) => { - if references_table(&binary.right, source_name).has_reference() { - if let ReferenceTableCheck::HasReference(left_target) = - references_table(&binary.left, target_name) - { - return construct_placeholder( - binary, - false, - partition_columns.contains(&left_target), - left_target, - placeholders, - ); + if !streaming_source { + if references_table(&binary.right, source_name).has_reference() { + if let ReferenceTableCheck::HasReference(left_target) = + references_table(&binary.left, target_name) + { + return construct_placeholder( + binary, + false, + partition_columns.contains(&left_target), + left_target, + placeholders, + ); + } + return None; } - return None; - } - if references_table(&binary.left, source_name).has_reference() { - if let ReferenceTableCheck::HasReference(right_target) = - references_table(&binary.right, target_name) - { - return construct_placeholder( - binary, - true, - partition_columns.contains(&right_target), - right_target, - placeholders, - ); + if references_table(&binary.left, source_name).has_reference() { + if let ReferenceTableCheck::HasReference(right_target) = + references_table(&binary.right, target_name) + { + return construct_placeholder( + binary, + true, + partition_columns.contains(&right_target), + right_target, + placeholders, + ); + } + return None; } - return None; } let left = generalize_filter( @@ -220,6 +223,7 @@ pub(crate) fn generalize_filter( source_name, target_name, placeholders, + streaming_source, ); let right = generalize_filter( *binary.right, @@ -227,6 +231,7 @@ pub(crate) fn generalize_filter( source_name, target_name, placeholders, + streaming_source, ); match (left, right) { @@ -258,6 +263,7 @@ pub(crate) fn generalize_filter( source_name, target_name, placeholders, + streaming_source, )?; let mut list_expr = Vec::new(); @@ -272,6 +278,7 @@ pub(crate) fn generalize_filter( source_name, target_name, placeholders, + streaming_source, ) { list_expr.push(item) } @@ -291,19 +298,23 @@ pub(crate) fn generalize_filter( } other => match references_table(&other, source_name) { ReferenceTableCheck::HasReference(col) => { - let placeholder_name = format!("{col}_{}", placeholders.len()); - - let placeholder = Expr::Placeholder(Placeholder { - id: placeholder_name.clone(), - data_type: None, - }); - - placeholders.push(PredicatePlaceholder { - expr: 
other, - alias: placeholder_name, - is_aggregate: true, - }); - Some(placeholder) + if !streaming_source { + let placeholder_name = format!("{col}_{}", placeholders.len()); + + let placeholder = Expr::Placeholder(Placeholder { + id: placeholder_name.clone(), + data_type: None, + }); + + placeholders.push(PredicatePlaceholder { + expr: other, + alias: placeholder_name, + is_aggregate: true, + }); + Some(placeholder) + } else { + None + } } ReferenceTableCheck::NoReference => Some(other), ReferenceTableCheck::Unknown => None, @@ -318,6 +329,7 @@ pub(crate) async fn try_construct_early_filter( source: &LogicalPlan, source_name: &TableReference, target_name: &TableReference, + streaming_source: bool, ) -> DeltaResult> { let table_metadata = table_snapshot.metadata(); let partition_columns = &table_metadata.partition_columns; @@ -330,10 +342,11 @@ pub(crate) async fn try_construct_early_filter( source_name, target_name, &mut placeholders, + streaming_source, ) { None => Ok(None), Some(filter) => { - if placeholders.is_empty() { + if placeholders.is_empty() || streaming_source { // if we haven't recognised any source predicates in the join predicate, return our filter with static only predicates Ok(Some(filter)) } else { @@ -382,7 +395,6 @@ pub(crate) async fn try_construct_early_filter( } } } - #[cfg(test)] mod tests { use crate::operations::merge::tests::setup_table; @@ -457,6 +469,7 @@ mod tests { &source, &source_name, &target_name, + false, ) .await .unwrap(); @@ -554,6 +567,7 @@ mod tests { &source, &source_name, &target_name, + false, ) .await .unwrap(); @@ -632,6 +646,7 @@ mod tests { &source, &source_name, &target_name, + false, ) .await .unwrap(); @@ -711,6 +726,7 @@ mod tests { &source_plan, &source_name, &target_name, + false, ) .await .unwrap(); @@ -807,6 +823,7 @@ mod tests { &source_plan, &source_name, &target_name, + false, ) .await .unwrap(); @@ -908,6 +925,7 @@ mod tests { &source_plan, &source_name, &target_name, + false, ) .await .unwrap(); diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 70087fb3b3..77a294db78 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -130,6 +130,8 @@ pub struct MergeBuilder { snapshot: DeltaTableState, /// The source data source: DataFrame, + /// Whether the source is a streaming source (if true, stats deducing to prune target is disabled) + streaming: bool, /// Delta object store for handling data files log_store: LogStoreRef, /// Datafusion session state relevant for executing the input plan @@ -176,6 +178,7 @@ impl MergeBuilder { not_match_operations: Vec::new(), not_match_source_operations: Vec::new(), safe_cast: false, + streaming: false, custom_execute_handler: None, } } @@ -397,6 +400,12 @@ impl MergeBuilder { self } + /// Set streaming mode execution + pub fn with_streaming(mut self, streaming: bool) -> Self { + self.streaming = streaming; + self + } + /// Set a custom execute handler, for pre and post execution pub fn with_custom_execute_handler(mut self, handler: Arc) -> Self { self.custom_execute_handler = Some(handler); @@ -705,6 +714,7 @@ async fn execute( writer_properties: Option, mut commit_properties: CommitProperties, _safe_cast: bool, + streaming: bool, source_alias: Option, target_alias: Option, match_operations: Vec, @@ -870,7 +880,8 @@ async fn execute( // Attempt to construct an early filter that we can apply to the Add action list and the delta scan. 
// In the case where there are partition columns in the join predicate, we can scan the source table // to get the distinct list of partitions affected and constrain the search to those. - let target_subset_filter = if !not_match_source_operations.is_empty() { + + let target_subset_filter: Option = if !not_match_source_operations.is_empty() { // It's only worth trying to create an early filter where there are no `when_not_matched_source` operators, since // that implies a full scan None @@ -882,6 +893,7 @@ async fn execute( &source, &source_name, &target_name, + streaming, ) .await? }; @@ -1450,6 +1462,7 @@ impl std::future::IntoFuture for MergeBuilder { this.writer_properties, this.commit_properties, this.safe_cast, + this.streaming, this.source_alias, this.target_alias, this.match_operations, @@ -2632,6 +2645,7 @@ mod tests { &source, &target, &mut placeholders, + false, ) .unwrap(); @@ -2664,6 +2678,7 @@ mod tests { &source, &target, &mut placeholders, + false, ) .unwrap(); @@ -2708,6 +2723,7 @@ mod tests { &source, &target, &mut placeholders, + false, ) .unwrap(); @@ -2747,6 +2763,7 @@ mod tests { &source, &target, &mut placeholders, + false, ) .unwrap(); @@ -2777,6 +2794,7 @@ mod tests { &source, &target, &mut placeholders, + false, ) .unwrap(); let expected_filter_l = Expr::Placeholder(Placeholder { @@ -2810,6 +2828,7 @@ mod tests { &source, &target, &mut placeholders, + false, ) .unwrap(); diff --git a/python/Cargo.toml b/python/Cargo.toml index 977466250c..ec416c32fc 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -35,6 +35,7 @@ regex = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } uuid = { workspace = true, features = ["serde", "v4"] } +parking_lot = "0.12" # runtime futures = { workspace = true } diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index f19c685118..5045f00020 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -196,6 +196,7 @@ class RawDeltaTable: commit_properties: Optional[CommitProperties], post_commithook_properties: Optional[PostCommitHookProperties], safe_cast: bool, + streaming: bool, ) -> PyMergeBuilder: ... def merge_execute(self, merge_builder: PyMergeBuilder) -> str: ... 
def get_active_partitions( diff --git a/python/deltalake/table.py b/python/deltalake/table.py index f8357c3700..4a699f7c85 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -1029,14 +1029,17 @@ def merge( convert_pyarrow_table, ) + streaming = False if isinstance(source, pyarrow.RecordBatchReader): source = convert_pyarrow_recordbatchreader(source, conversion_mode) + streaming = True elif isinstance(source, pyarrow.RecordBatch): source = convert_pyarrow_recordbatch(source, conversion_mode) elif isinstance(source, pyarrow.Table): source = convert_pyarrow_table(source, conversion_mode) elif isinstance(source, ds.Dataset): source = convert_pyarrow_dataset(source, conversion_mode) + streaming = True elif _has_pandas and isinstance(source, pd.DataFrame): source = convert_pyarrow_table( pyarrow.Table.from_pandas(source), conversion_mode @@ -1056,6 +1059,7 @@ def merge( source_alias=source_alias, target_alias=target_alias, safe_cast=not error_on_type_mismatch, + streaming=streaming, writer_properties=writer_properties, commit_properties=commit_properties, post_commithook_properties=post_commithook_properties, diff --git a/python/src/lib.rs b/python/src/lib.rs index 7905087a7b..4ea992b782 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -895,6 +895,7 @@ impl RawDeltaTable { source_alias = None, target_alias = None, safe_cast = false, + streaming = false, writer_properties = None, post_commithook_properties = None, commit_properties = None, @@ -907,6 +908,7 @@ impl RawDeltaTable { source_alias: Option, target_alias: Option, safe_cast: bool, + streaming: bool, writer_properties: Option, post_commithook_properties: Option, commit_properties: Option, @@ -927,6 +929,7 @@ impl RawDeltaTable { source_alias, target_alias, safe_cast, + streaming, writer_properties, post_commithook_properties, commit_properties, diff --git a/python/src/merge.rs b/python/src/merge.rs index a2ff75a6d1..7238b7c345 100644 --- a/python/src/merge.rs +++ b/python/src/merge.rs @@ -4,16 +4,20 @@ use deltalake::arrow::ffi_stream::ArrowArrayStreamReader; use deltalake::arrow::pyarrow::IntoPyArrow; use deltalake::datafusion::catalog::TableProvider; use deltalake::datafusion::datasource::MemTable; +use deltalake::datafusion::physical_plan::memory::LazyBatchGenerator; use deltalake::datafusion::prelude::SessionContext; +use deltalake::delta_datafusion::LazyTableProvider; use deltalake::logstore::LogStoreRef; use deltalake::operations::merge::MergeBuilder; use deltalake::operations::CustomExecuteHandler; use deltalake::table::state::DeltaTableState; use deltalake::{DeltaResult, DeltaTable}; +use parking_lot::RwLock; use pyo3::prelude::*; use std::collections::HashMap; +use std::fmt::{self}; use std::future::IntoFuture; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use crate::error::PythonError; use crate::utils::rt; @@ -31,6 +35,46 @@ pub(crate) struct PyMergeBuilder { target_alias: Option, arrow_schema: Arc, } +#[derive(Debug)] +struct ArrowStreamBatchGenerator { + pub array_stream: Arc>, +} + +impl fmt::Display for ArrowStreamBatchGenerator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "ArrowStreamBatchGenerator {{ array_stream: {:?} }}", + self.array_stream + ) + } +} + +impl ArrowStreamBatchGenerator { + fn new(array_stream: Arc>) -> Self { + Self { array_stream } + } +} + +impl LazyBatchGenerator for ArrowStreamBatchGenerator { + fn generate_next_batch( + &mut self, + ) -> deltalake::datafusion::error::Result> { + let mut stream_reader = 
self.array_stream.lock().map_err(|_| { + deltalake::datafusion::error::DataFusionError::Execution( + "Failed to lock the ArrowArrayStreamReader".to_string(), + ) + })?; + + match stream_reader.next() { + Some(Ok(record_batch)) => Ok(Some(record_batch)), + Some(Err(err)) => Err(deltalake::datafusion::error::DataFusionError::ArrowError( + err, None, + )), + None => Ok(None), // End of stream + } + } +} impl PyMergeBuilder { #[allow(clippy::too_many_arguments)] @@ -42,6 +86,7 @@ impl PyMergeBuilder { source_alias: Option, target_alias: Option, safe_cast: bool, + streaming: bool, writer_properties: Option, post_commithook_properties: Option, commit_properties: Option, @@ -49,13 +94,27 @@ impl PyMergeBuilder { ) -> DeltaResult { let ctx = SessionContext::new(); let schema = source.schema(); - let batches = vec![source.map(|batch| batch.unwrap()).collect::>()]; - let table_provider: Arc = - Arc::new(MemTable::try_new(schema.clone(), batches).unwrap()); - let source_df = ctx.read_table(table_provider).unwrap(); - let mut cmd = - MergeBuilder::new(log_store, snapshot, predicate, source_df).with_safe_cast(safe_cast); + let source_df = if streaming { + let arrow_stream: Arc> = Arc::new(Mutex::new(source)); + let arrow_stream_batch_generator: Arc> = + Arc::new(RwLock::new(ArrowStreamBatchGenerator::new(arrow_stream))); + + let table_provider: Arc = Arc::new(LazyTableProvider::try_new( + schema.clone(), + vec![arrow_stream_batch_generator], + )?); + ctx.read_table(table_provider).unwrap() + } else { + let batches = vec![source.map(|batch| batch.unwrap()).collect::>()]; + let table_provider: Arc = + Arc::new(MemTable::try_new(schema.clone(), batches).unwrap()); + ctx.read_table(table_provider).unwrap() + }; + + let mut cmd = MergeBuilder::new(log_store, snapshot, predicate, source_df) + .with_safe_cast(safe_cast) + .with_streaming(streaming); if let Some(src_alias) = &source_alias { cmd = cmd.with_source_alias(src_alias); diff --git a/python/tests/test_generated_columns.py b/python/tests/test_generated_columns.py index b329c948be..239eead2c9 100644 --- a/python/tests/test_generated_columns.py +++ b/python/tests/test_generated_columns.py @@ -196,7 +196,9 @@ def test_merge_with_gc(table_with_gc: DeltaTable, data_without_gc): expected_data = pa.Table.from_pydict( {"id": [1, 2], "gc": [5, 5]}, schema=pa.schema([id_col, gc]) ) - assert table_with_gc.to_pyarrow_table() == expected_data + assert ( + table_with_gc.to_pyarrow_table().sort_by([("id", "ascending")]) == expected_data + ) def test_merge_with_gc_invalid(table_with_gc: DeltaTable, invalid_gc_data): diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index 69eb73ebc6..a09fdecc5f 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -9,11 +9,13 @@ from deltalake import DeltaTable, write_deltalake from deltalake.exceptions import DeltaProtocolError +from deltalake.schema import ArrowSchemaConversionMode, convert_pyarrow_table from deltalake.table import CommitProperties +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_matched_delete_wo_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -26,6 +28,11 @@ def test_merge_when_matched_delete_wo_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + commit_properties = CommitProperties(custom_metadata={"userName": "John Doe"}) dt.merge( 
source=source_table, @@ -52,8 +59,9 @@ def test_merge_when_matched_delete_wo_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_matched_delete_with_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -69,6 +77,11 @@ def test_merge_when_matched_delete_with_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, predicate="t.id = s.id", @@ -92,8 +105,9 @@ def test_merge_when_matched_delete_with_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_matched_update_wo_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -107,6 +121,11 @@ def test_merge_when_matched_update_wo_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, predicate="t.id = s.id", @@ -129,8 +148,9 @@ def test_merge_when_matched_update_wo_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_matched_update_all_wo_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -146,6 +166,11 @@ def test_merge_when_matched_update_all_wo_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, predicate="t.id = s.id", @@ -168,8 +193,9 @@ def test_merge_when_matched_update_all_wo_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_matched_update_all_with_exclude( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -185,6 +211,11 @@ def test_merge_when_matched_update_all_with_exclude( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, predicate="t.id = s.id", @@ -207,8 +238,9 @@ def test_merge_when_matched_update_all_with_exclude( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_matched_update_with_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -223,6 +255,11 @@ def test_merge_when_matched_update_with_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", @@ -248,8 +285,9 @@ def test_merge_when_matched_update_with_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_not_matched_insert_wo_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -264,6 +302,11 @@ def 
test_merge_when_not_matched_insert_wo_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", @@ -293,8 +336,9 @@ def test_merge_when_not_matched_insert_wo_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_not_matched_insert_with_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -309,6 +353,11 @@ def test_merge_when_not_matched_insert_with_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", @@ -339,8 +388,9 @@ def test_merge_when_not_matched_insert_with_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_not_matched_insert_all_with_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -355,6 +405,11 @@ def test_merge_when_not_matched_insert_all_with_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", @@ -379,8 +434,9 @@ def test_merge_when_not_matched_insert_all_with_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_not_matched_insert_all_with_exclude( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -395,6 +451,11 @@ def test_merge_when_not_matched_insert_all_with_exclude( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", @@ -417,8 +478,9 @@ def test_merge_when_not_matched_insert_all_with_exclude( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_not_matched_insert_all_with_predicate_special_column_names( - tmp_path: pathlib.Path, sample_table_with_spaces_numbers: pa.Table + tmp_path: pathlib.Path, sample_table_with_spaces_numbers: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table_with_spaces_numbers, mode="append") @@ -433,6 +495,11 @@ def test_merge_when_not_matched_insert_all_with_predicate_special_column_names( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", @@ -457,8 +524,9 @@ def test_merge_when_not_matched_insert_all_with_predicate_special_column_names( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_not_matched_by_source_update_wo_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -473,6 +541,11 @@ def test_merge_when_not_matched_by_source_update_wo_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", 
@@ -499,8 +572,9 @@ def test_merge_when_not_matched_by_source_update_wo_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_not_matched_by_source_update_with_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -515,6 +589,11 @@ def test_merge_when_not_matched_by_source_update_with_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", @@ -542,8 +621,9 @@ def test_merge_when_not_matched_by_source_update_with_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_not_matched_by_source_delete_with_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -558,6 +638,11 @@ def test_merge_when_not_matched_by_source_delete_with_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", @@ -583,8 +668,9 @@ def test_merge_when_not_matched_by_source_delete_with_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_when_not_matched_by_source_delete_wo_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -594,6 +680,11 @@ def test_merge_when_not_matched_by_source_delete_wo_predicate( {"id": pa.array(["4", "5"]), "weight": pa.array([1.5, 1.6], pa.float64())} ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", @@ -619,8 +710,9 @@ def test_merge_when_not_matched_by_source_delete_wo_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_multiple_when_matched_update_with_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -635,6 +727,11 @@ def test_merge_multiple_when_matched_update_with_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", @@ -663,8 +760,9 @@ def test_merge_multiple_when_matched_update_with_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_multiple_when_matched_update_all_with_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -679,6 +777,11 @@ def test_merge_multiple_when_matched_update_all_with_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", @@ -705,8 +808,9 @@ def test_merge_multiple_when_matched_update_all_with_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def 
test_merge_multiple_when_not_matched_insert_with_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -721,6 +825,11 @@ def test_merge_multiple_when_not_matched_insert_with_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, source_alias="source", @@ -759,8 +868,9 @@ def test_merge_multiple_when_not_matched_insert_with_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_multiple_when_matched_delete_with_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): write_deltalake(tmp_path, sample_table, mode="append") @@ -776,6 +886,11 @@ def test_merge_multiple_when_matched_delete_with_predicate( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + dt.merge( source=source_table, predicate="t.id = s.id", @@ -801,8 +916,9 @@ def test_merge_multiple_when_matched_delete_with_predicate( assert result == expected +@pytest.mark.parametrize("streaming", (True, False)) def test_merge_multiple_when_not_matched_by_source_update_wo_predicate( - tmp_path: pathlib.Path, sample_table: pa.Table + tmp_path: pathlib.Path, sample_table: pa.Table, streaming: bool ): """The first match clause that meets the predicate will be executed, so if you do an update without an other predicate the first clause will be matched and therefore the other ones are skipped. @@ -819,6 +935,10 @@ def test_merge_multiple_when_not_matched_by_source_update_wo_predicate( "deleted": pa.array([False, False]), } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) dt.merge( source=source_table, @@ -850,7 +970,8 @@ def test_merge_multiple_when_not_matched_by_source_update_wo_predicate( assert result == expected -def test_merge_date_partitioned_2344(tmp_path: pathlib.Path): +@pytest.mark.parametrize("streaming", (True, False)) +def test_merge_date_partitioned_2344(tmp_path: pathlib.Path, streaming: bool): from datetime import date schema = pa.schema( @@ -873,6 +994,11 @@ def test_merge_date_partitioned_2344(tmp_path: pathlib.Path): } ) + expected = data + + if streaming: + data = convert_pyarrow_table(data, ArrowSchemaConversionMode.PASSTHROUGH) + dt.merge( data, predicate="s.date = t.date", @@ -884,11 +1010,15 @@ def test_merge_date_partitioned_2344(tmp_path: pathlib.Path): last_action = dt.history(1)[0] assert last_action["operation"] == "MERGE" - assert result == data - assert ( - last_action["operationParameters"].get("predicate") - == "'2022-02-01'::date = date" - ) + assert result == expected + if not streaming: + assert ( + last_action["operationParameters"].get("predicate") + == "'2022-02-01'::date = date" + ) + else: + # In streaming mode we don't use aggregated stats of the source in the predicate + assert last_action["operationParameters"].get("predicate") is None @pytest.mark.parametrize( @@ -944,8 +1074,11 @@ def test_merge_timestamps_partitioned_2344(tmp_path: pathlib.Path, timezone, pre assert last_action["operationParameters"].get("predicate") == predicate +@pytest.mark.parametrize("streaming", (True, False)) @pytest.mark.parametrize("engine", ["pyarrow", "rust"]) -def test_merge_stats_columns_stats_provided(tmp_path: pathlib.Path, 
engine): +def test_merge_stats_columns_stats_provided( + tmp_path: pathlib.Path, engine, streaming: bool +): data = pa.table( { "foo": pa.array(["a", "b", None, None]), @@ -982,6 +1115,9 @@ def test_merge_stats_columns_stats_provided(tmp_path: pathlib.Path, engine): } ) + if streaming: + data = convert_pyarrow_table(data, ArrowSchemaConversionMode.PASSTHROUGH) + dt.merge( data, predicate="source.foo = target.foo", @@ -1074,9 +1210,8 @@ def test_struct_casting(tmp_path: pathlib.Path): assert result is not None -def test_merge_isin_partition_pruning( - tmp_path: pathlib.Path, -): +@pytest.mark.parametrize("streaming", (True, False)) +def test_merge_isin_partition_pruning(tmp_path: pathlib.Path, streaming: bool): nrows = 5 data = pa.table( { @@ -1098,6 +1233,11 @@ def test_merge_isin_partition_pruning( } ) + if streaming: + source_table = convert_pyarrow_table( + source_table, ArrowSchemaConversionMode.PASSTHROUGH + ) + metrics = ( dt.merge( source=source_table, @@ -1125,7 +1265,8 @@ def test_merge_isin_partition_pruning( assert metrics["num_target_files_skipped_during_scan"] == 3 -def test_cdc_merge_planning_union_2908(tmp_path): +@pytest.mark.parametrize("streaming", (True, False)) +def test_cdc_merge_planning_union_2908(tmp_path, streaming: bool): """https://github.com/delta-io/delta-rs/issues/2908""" cdc_path = f"{tmp_path}/_change_data" @@ -1148,6 +1289,9 @@ def test_cdc_merge_planning_union_2908(tmp_path): }, ) + if streaming: + table = convert_pyarrow_table(table, ArrowSchemaConversionMode.PASSTHROUGH) + dt.merge( source=table, predicate="s.id = t.id",
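
For reference, a minimal usage sketch of the streaming path this diff enables, assuming an illustrative table path, schema, and column names (none of which come from the diff itself): passing a `pyarrow.RecordBatchReader` (or a `pyarrow.dataset.Dataset`) as the merge source now sets `streaming=True`, so the source is wrapped in the new `LazyTableProvider` and pulled batch by batch instead of being materialized into a `MemTable`, and no source-derived pruning predicate is recorded in the commit's `operationParameters`.

```python
import pyarrow as pa
from deltalake import DeltaTable

# Illustrative only: table URI, schema, and data are assumptions, not part of this diff.
dt = DeltaTable("path/to/table")
schema = pa.schema([("id", pa.string()), ("price", pa.int64())])

def batches():
    # Batches are consumed lazily by LazyTableProvider while the merge runs.
    yield pa.record_batch([pa.array(["1", "2"]), pa.array([10, 20])], schema=schema)

reader = pa.RecordBatchReader.from_batches(schema, batches())

(
    dt.merge(
        source=reader,            # RecordBatchReader source -> streaming=True
        predicate="t.id = s.id",
        source_alias="s",
        target_alias="t",
    )
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute()
)
```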