From 04cae0131b2db522f10d792584766eb5259c3f85 Mon Sep 17 00:00:00 2001
From: JustinRush80 <69156844+JustinRush80@users.noreply.github.com>
Date: Wed, 22 Jan 2025 21:24:38 -0500
Subject: [PATCH] fix merge conflict

Signed-off-by: JustinRush80 <69156844+JustinRush80@users.noreply.github.com>
---
 crates/core/src/operations/merge/mod.rs | 115 +++++++++++++++++++++++-
 python/src/merge.rs                     |   2 -
 python/tests/test_generated_columns.py  |  16 ++++
 python/tests/test_merge.py              |   3 +
 4 files changed, 130 insertions(+), 6 deletions(-)

diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs
index e8f45745af..a440c8966a 100644
--- a/crates/core/src/operations/merge/mod.rs
+++ b/crates/core/src/operations/merge/mod.rs
@@ -1623,6 +1623,7 @@ mod tests {
     use datafusion_expr::expr::Placeholder;
     use datafusion_expr::lit;
     use datafusion_expr::Expr;
+    use delta_kernel::schema::StructType;
     use itertools::Itertools;
     use regex::Regex;
     use serde_json::json;
@@ -1793,7 +1794,7 @@ mod tests {
         assert_merge(table, metrics).await;
     }
     #[tokio::test]
-    async fn test_merge_with_schema_mode_no_change_of_schema() {
+    async fn test_merge_with_schema_merge_no_change_of_schema() {
         let (table, _) = setup().await;
 
         let schema = Arc::new(ArrowSchema::new(vec![
@@ -1883,6 +1884,97 @@ mod tests {
         assert_merge(after_table, metrics).await;
     }
 
+    #[tokio::test]
+    async fn test_merge_with_schema_merge_and_struct() {
+        let (table, _) = setup().await;
+
+        let nested_schema = Arc::new(ArrowSchema::new(vec![Field::new(
+            "count",
+            ArrowDataType::Int64,
+            true,
+        )]));
+
+        let schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", ArrowDataType::Utf8, true),
+            Field::new("value", ArrowDataType::Int32, true),
+            Field::new("modified", ArrowDataType::Utf8, true),
+            Field::new(
+                "nested",
+                ArrowDataType::Struct(nested_schema.fields().clone()),
+                true,
+            ),
+        ]));
+        let count_array = arrow::array::Int64Array::from(vec![Some(1)]);
+        let id_array = arrow::array::StringArray::from(vec![Some("X")]);
+        let value_array = arrow::array::Int32Array::from(vec![Some(1)]);
+        let modified_array = arrow::array::StringArray::from(vec![Some("2021-02-02")]);
+
+        let batch = RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(id_array),
+                Arc::new(value_array),
+                Arc::new(modified_array),
+                Arc::new(arrow::array::StructArray::from(
+                    RecordBatch::try_new(nested_schema, vec![Arc::new(count_array)]).unwrap(),
+                )),
+            ],
+        )
+        .unwrap();
+
+        let ctx = SessionContext::new();
+
+        let source = ctx.read_batch(batch).unwrap();
+
+        let (table, _) = DeltaOps(table.clone())
+            .merge(source, col("target.id").eq(col("source.id")))
+            .with_source_alias("source")
+            .with_target_alias("target")
+            .with_merge_schema(true)
+            .when_not_matched_insert(|insert| {
+                insert
+                    .set("id", col("source.id"))
+                    .set("value", col("source.value"))
+                    .set("modified", col("source.modified"))
+                    .set("nested", col("source.nested"))
+            })
+            .unwrap()
+            .await
+            .unwrap();
+
+        let snapshot_bytes = table
+            .log_store
+            .read_commit_entry(2)
+            .await
+            .unwrap()
+            .expect("failed to get snapshot bytes");
+        let actions = crate::logstore::get_actions(2, snapshot_bytes)
+            .await
+            .unwrap();
+
+        let schema_actions = actions
+            .iter()
+            .any(|action| matches!(action, Action::Metadata(_)));
+
+        dbg!(&schema_actions);
+
+        assert!(schema_actions);
+        let expected = vec![
+            "+----+-------+------------+------------+",
+            "| id | value | modified   | nested     |",
+            "+----+-------+------------+------------+",
+            "| A  | 1     | 2021-02-01 |            |",
+            "| B  | 10    | 2021-02-01 |            |",
+            "| C  | 10    | 2021-02-02 |            |",
+            "| D  | 100   | 2021-02-02 |            |",
+            "| X  | 1     | 2021-02-02 | {count: 1} |",
+            "+----+-------+------------+------------+",
+        ];
+        let actual = get_data(&table).await;
+
+        assert_batches_sorted_eq!(&expected, &actual);
+    }
+
     #[tokio::test]
     async fn test_merge_schema_evolution_simple_update() {
         let (table, _) = setup().await;
@@ -1926,7 +2018,7 @@ mod tests {
             .unwrap();
 
         let commit_info = table.history(None).await.unwrap();
-        dbg!(&commit_info);
+
         let last_commit = &commit_info[0];
         let parameters = last_commit.operation_parameters.clone().unwrap();
         assert_eq!(parameters["mergePredicate"], json!("target.id = source.id"));
@@ -1941,6 +2033,8 @@ mod tests {
             "+----+-------+------------+-------------+",
         ];
         let actual = get_data(&table).await;
+        let expected_schema_struct: StructType = Arc::clone(&schema).try_into().unwrap();
+        assert_eq!(&expected_schema_struct, table.schema().unwrap());
 
         assert_batches_sorted_eq!(&expected, &actual);
     }
@@ -2007,6 +2101,8 @@ mod tests {
             "+----+-------+------------+-------------+",
         ];
         let actual = get_data(&table).await;
+        let expected_schema_struct: StructType = Arc::clone(&schema).try_into().unwrap();
+        assert_eq!(&expected_schema_struct, table.schema().unwrap());
 
         assert_batches_sorted_eq!(&expected, &actual);
     }
@@ -3012,7 +3108,7 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn test_empty_table_schema_evo_merge() {
+    async fn test_empty_table_with_schema_merge() {
         let schema = Arc::new(ArrowSchema::new(vec![
             Field::new("id", ArrowDataType::Utf8, true),
             Field::new("value", ArrowDataType::Int32, true),
@@ -3092,6 +3188,8 @@ mod tests {
             "+----+-------+-------------+------------+",
         ];
         let actual = get_data(&table).await;
+        let expected_schema_struct: StructType = Arc::clone(&schema).try_into().unwrap();
+        assert_eq!(&expected_schema_struct, table.schema().unwrap());
 
         assert_batches_sorted_eq!(&expected, &actual);
     }
@@ -3871,6 +3969,13 @@ mod tests {
         assert_eq!(table.version(), 0);
 
         let schema = get_arrow_schema(&None);
+
+        let source_schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", ArrowDataType::Utf8, true),
+            Field::new("value", ArrowDataType::Int32, true),
+            Field::new("modified", ArrowDataType::Utf8, true),
+            Field::new("inserted_by", ArrowDataType::Utf8, true),
+        ]));
         let table = write_data(table, &schema).await;
         assert_eq!(table.version(), 1);
 
@@ -3879,7 +3984,7 @@ mod tests {
         let source = source.with_column("inserted_by", lit("new_value")).unwrap();
 
         let (table, _) = DeltaOps(table)
-            .merge(source, col("target.id").eq(col("source.id")))
+            .merge(source.clone(), col("target.id").eq(col("source.id")))
             .with_source_alias("source")
             .with_target_alias("target")
             .with_merge_schema(true)
@@ -3918,6 +4023,8 @@ mod tests {
             "+----+-------+------------+-------------+",
         ];
         let actual = get_data(&table).await;
+        let expected_schema_struct: StructType = source_schema.try_into().unwrap();
+        assert_eq!(&expected_schema_struct, table.schema().unwrap());
 
         assert_batches_sorted_eq!(&expected, &actual);
         let ctx = SessionContext::new();
diff --git a/python/src/merge.rs b/python/src/merge.rs
index f4fa3baaaa..8da1aceb6f 100644
--- a/python/src/merge.rs
+++ b/python/src/merge.rs
@@ -9,7 +9,6 @@ use deltalake::datafusion::prelude::SessionContext;
 use deltalake::delta_datafusion::LazyTableProvider;
 use deltalake::logstore::LogStoreRef;
 use deltalake::operations::merge::MergeBuilder;
-use deltalake::operations::write::SchemaMode;
 use deltalake::operations::CustomExecuteHandler;
 use deltalake::table::state::DeltaTableState;
 use deltalake::{DeltaResult, DeltaTable};
@@ -18,7 +17,6 @@ use pyo3::prelude::*;
 use std::collections::HashMap;
 use std::fmt::{self};
 use std::future::IntoFuture;
-use std::str::FromStr;
 use std::sync::{Arc, Mutex};
 
 use crate::error::PythonError;
diff --git a/python/tests/test_generated_columns.py b/python/tests/test_generated_columns.py
index 239eead2c9..f5f2780ff6 100644
--- a/python/tests/test_generated_columns.py
+++ b/python/tests/test_generated_columns.py
@@ -201,6 +201,22 @@ def test_merge_with_gc(table_with_gc: DeltaTable, data_without_gc):
     )
 
 
+def test_merge_with_gc_during_schema_evolution(table_with_gc: DeltaTable, data_without_gc):
+    (
+        table_with_gc.merge(
+            data_without_gc, predicate="s.id = t.id", source_alias="s", target_alias="t", merge_schema=True
+        )
+        .when_not_matched_insert_all()
+        .execute()
+    )
+    id_col = pa.field("id", pa.int32())
+    gc = pa.field("gc", pa.int32())
+    expected_data = pa.Table.from_pydict(
+        {"id": [1, 2], "gc": [5, 5]}, schema=pa.schema([id_col, gc])
+    )
+    assert table_with_gc.to_pyarrow_table() == expected_data
+
+
 def test_merge_with_gc_invalid(table_with_gc: DeltaTable, invalid_gc_data):
     import re
 
diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py
index 11fab7f3bb..37df6e8d4f 100644
--- a/python/tests/test_merge.py
+++ b/python/tests/test_merge.py
@@ -188,6 +188,7 @@ def test_merge_when_matched_update_wo_predicate_with_schema_evolution(
 
     last_action = dt.history(1)[0]
     assert last_action["operation"] == "MERGE"
+    assert result.schema == expected.schema
     assert result == expected
 
 @pytest.mark.parametrize("streaming", (True, False))
@@ -478,6 +479,7 @@ def test_merge_when_not_matched_insert_with_predicate_schema_evolution(
 
     last_action = dt.history(1)[0]
     assert last_action["operation"] == "MERGE"
+    assert result.schema == expected.schema
     assert result == expected
 
 @pytest.mark.parametrize("streaming", (True, False))
@@ -609,6 +611,7 @@ def test_merge_when_not_matched_insert_all_with_exclude_and_with_schema_evo(
 
     last_action = dt.history(1)[0]
     assert last_action["operation"] == "MERGE"
+    assert result.schema == expected.schema
     assert result == expected
 
 @pytest.mark.parametrize("streaming", (True, False))