Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to Datafusion 43 #905

Merged
merged 19 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
373 changes: 199 additions & 174 deletions Cargo.lock

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ substrait = ["dep:datafusion-substrait"]
tokio = { version = "1.39", features = ["macros", "rt", "rt-multi-thread", "sync"] }
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
arrow = { version = "53", features = ["pyarrow"] }
datafusion = { version = "42.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
datafusion-substrait = { version = "42.0.0", optional = true }
datafusion-proto = { version = "42.0.0" }
datafusion = { version = "43.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
datafusion-substrait = { version = "43.0.0", optional = true }
datafusion-proto = { version = "43.0.0" }
datafusion-functions-window-common = { version = "43.0.0" }
prost = "0.13" # keep in line with `datafusion-substrait`
uuid = { version = "1.11", features = ["v4"] }
mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] }
Expand All @@ -58,4 +59,4 @@ crate-type = ["cdylib", "rlib"]

[profile.release]
lto = true
codegen-units = 1
codegen-units = 1
4 changes: 2 additions & 2 deletions examples/tpch/_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
def df_selection(col_name, col_type):
    """Build a projection expression that normalizes a column for comparison.

    Floating-point and decimal columns are rounded to 2 places so small
    numeric differences don't fail the TPC-H answer comparison; string
    columns (including Arrow string_view, which DataFusion 43 produces by
    default) are trimmed; all other columns pass through unchanged.

    Args:
        col_name: Name of the column to select.
        col_type: PyArrow data type of the column.

    Returns:
        A DataFusion expression aliased back to ``col_name``.
    """
    if col_type == pa.float64() or isinstance(col_type, pa.Decimal128Type):
        return F.round(col(col_name), lit(2)).alias(col_name)
    elif col_type == pa.string() or col_type == pa.string_view():
        return F.trim(col(col_name)).alias(col_name)
    else:
        return col(col_name)
Expand All @@ -43,7 +43,7 @@ def load_schema(col_name, col_type):
def expected_selection(col_name, col_type):
    """Build a projection expression that normalizes an expected-answer column.

    The TPC-H expected-answer files are read as text, so integer columns are
    trimmed and cast to their target type; string columns (including Arrow
    string_view, the default string type since DataFusion 43) are trimmed;
    other columns pass through unchanged.

    Args:
        col_name: Name of the column to select.
        col_type: PyArrow data type the column should be compared as.

    Returns:
        A DataFusion expression aliased back to ``col_name``.
    """
    if col_type == pa.int64() or col_type == pa.int32():
        return F.trim(col(col_name)).cast(col_type).alias(col_name)
    elif col_type == pa.string() or col_type == pa.string_view():
        return F.trim(col(col_name)).alias(col_name)
    else:
        return col(col_name)
Expand Down
4 changes: 2 additions & 2 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
Column = expr_internal.Column
CreateMemoryTable = expr_internal.CreateMemoryTable
CreateView = expr_internal.CreateView
CrossJoin = expr_internal.CrossJoin
Distinct = expr_internal.Distinct
DropTable = expr_internal.DropTable
EmptyRelation = expr_internal.EmptyRelation
Expand Down Expand Up @@ -140,7 +139,6 @@
"Join",
"JoinType",
"JoinConstraint",
"CrossJoin",
"Union",
"Unnest",
"UnnestExpr",
Expand Down Expand Up @@ -376,6 +374,8 @@ def literal(value: Any) -> Expr:

``value`` must be a valid PyArrow scalar value or easily castable to one.
"""
if isinstance(value, str):
value = pa.scalar(value, type=pa.string_view())
if not isinstance(value, pa.Scalar):
value = pa.scalar(value)
return Expr(expr_internal.Expr.literal(value))
Expand Down
11 changes: 8 additions & 3 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def decode(input: Expr, encoding: Expr) -> Expr:

def array_to_string(expr: Expr, delimiter: Expr) -> Expr:
    """Converts each element to its text representation.

    Args:
        expr: Array expression whose elements are joined.
        delimiter: Separator placed between elements. It is cast to
            ``pa.string()`` because Python string literals default to
            string_view under DataFusion 43, which ``array_to_string``
            does not accept as a delimiter — TODO confirm against the
            DataFusion 43 function signature.

    Returns:
        Expression producing the joined string.
    """
    return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string())))


def array_join(expr: Expr, delimiter: Expr) -> Expr:
Expand Down Expand Up @@ -1065,7 +1065,10 @@ def struct(*args: Expr) -> Expr:

def named_struct(name_pairs: list[tuple[str, Expr]]) -> Expr:
"""Returns a struct with the given names and arguments pairs."""
name_pair_exprs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs]
name_pair_exprs = [
[Expr.literal(pa.scalar(pair[0], type=pa.string())), pair[1]]
for pair in name_pairs
]

# flatten
name_pairs = [x.expr for xs in name_pair_exprs for x in xs]
Expand Down Expand Up @@ -1422,7 +1425,9 @@ def array_sort(array: Expr, descending: bool = False, null_first: bool = False)
nulls_first = "NULLS FIRST" if null_first else "NULLS LAST"
return Expr(
f.array_sort(
array.expr, Expr.literal(desc).expr, Expr.literal(nulls_first).expr
array.expr,
Expr.literal(pa.scalar(desc, type=pa.string())).expr,
Expr.literal(pa.scalar(nulls_first, type=pa.string())).expr,
)
)

Expand Down
1 change: 1 addition & 0 deletions python/datafusion/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def udaf(
which this UDAF is used. The following examples are all valid.

.. code-block:: python

import pyarrow as pa
import pyarrow.compute as pc

Expand Down
16 changes: 12 additions & 4 deletions python/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,18 @@ def test_limit(test_ctx):

plan = plan.to_variant()
assert isinstance(plan, Limit)
assert plan.skip() == 0
# TODO: Upstream now has expressions for skip and fetch
# REF: https://github.com/apache/datafusion/pull/12836
# assert plan.skip() == 0

df = test_ctx.sql("select c1 from test LIMIT 10 OFFSET 5")
plan = df.logical_plan()

plan = plan.to_variant()
assert isinstance(plan, Limit)
assert plan.skip() == 5
# TODO: Upstream now has expressions for skip and fetch
# REF: https://github.com/apache/datafusion/pull/12836
# assert plan.skip() == 5


def test_aggregate_query(test_ctx):
Expand Down Expand Up @@ -126,7 +130,10 @@ def test_relational_expr(test_ctx):
ctx = SessionContext()

batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])],
[
pa.array([1, 2, 3]),
pa.array(["alpha", "beta", "gamma"], type=pa.string_view()),
],
names=["a", "b"],
)
df = ctx.create_dataframe([[batch]], name="batch_array")
Expand All @@ -141,7 +148,8 @@ def test_relational_expr(test_ctx):
assert df.filter(col("b") == "beta").count() == 1
assert df.filter(col("b") != "beta").count() == 2

assert df.filter(col("a") == "beta").count() == 0
with pytest.raises(Exception):
df.filter(col("a") == "beta").count()


def test_expr_to_variant():
Expand Down
67 changes: 47 additions & 20 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def df():
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[
pa.array(["Hello", "World", "!"]),
pa.array(["Hello", "World", "!"], type=pa.string_view()),
pa.array([4, 5, 6]),
pa.array(["hello ", " world ", " !"]),
pa.array(["hello ", " world ", " !"], type=pa.string_view()),
pa.array(
[
datetime(2022, 12, 31),
Expand Down Expand Up @@ -88,16 +88,18 @@ def test_literal(df):
assert len(result) == 1
result = result[0]
assert result.column(0) == pa.array([1] * 3)
assert result.column(1) == pa.array(["1"] * 3)
assert result.column(2) == pa.array(["OK"] * 3)
assert result.column(1) == pa.array(["1"] * 3, type=pa.string_view())
assert result.column(2) == pa.array(["OK"] * 3, type=pa.string_view())
assert result.column(3) == pa.array([3.14] * 3)
assert result.column(4) == pa.array([True] * 3)
assert result.column(5) == pa.array([b"hello world"] * 3)


def test_lit_arith(df):
"""Test literals with arithmetic operations"""
df = df.select(literal(1) + column("b"), f.concat(column("a"), literal("!")))
df = df.select(
literal(1) + column("b"), f.concat(column("a").cast(pa.string()), literal("!"))
)
result = df.collect()
assert len(result) == 1
result = result[0]
Expand Down Expand Up @@ -578,21 +580,33 @@ def test_array_function_obj_tests(stmt, py_expr):
f.ascii(column("a")),
pa.array([72, 87, 33], type=pa.int32()),
), # H = 72; W = 87; ! = 33
(f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())),
(f.btrim(literal(" World ")), pa.array(["World", "World", "World"])),
(
f.bit_length(column("a").cast(pa.string())),
pa.array([40, 40, 8], type=pa.int32()),
),
(
f.btrim(literal(" World ")),
pa.array(["World", "World", "World"], type=pa.string_view()),
),
(f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
(f.chr(literal(68)), pa.array(["D", "D", "D"])),
(
f.concat_ws("-", column("a"), literal("test")),
pa.array(["Hello-test", "World-test", "!-test"]),
),
(f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])),
(
f.concat(column("a").cast(pa.string()), literal("?")),
pa.array(["Hello?", "World?", "!?"]),
),
(f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])),
(f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
(f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
(f.lower(column("a")), pa.array(["hello", "world", "!"])),
(f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])),
(f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])),
(
f.ltrim(column("c")),
pa.array(["hello ", "world ", "!"], type=pa.string_view()),
),
(
f.md5(column("a")),
pa.array(
Expand All @@ -618,19 +632,25 @@ def test_array_function_obj_tests(stmt, py_expr):
f.rpad(column("a"), literal(8)),
pa.array(["Hello ", "World ", "! "]),
),
(f.rtrim(column("c")), pa.array(["hello", " world", " !"])),
(
f.rtrim(column("c")),
pa.array(["hello", " world", " !"], type=pa.string_view()),
),
(
f.split_part(column("a"), literal("l"), literal(1)),
pa.array(["He", "Wor", "!"]),
),
(f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
(f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
(f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])),
(
f.substr(column("a"), literal(3)),
pa.array(["llo", "rld", ""], type=pa.string_view()),
),
(
f.translate(column("a"), literal("or"), literal("ld")),
pa.array(["Helll", "Wldld", "!"]),
),
(f.trim(column("c")), pa.array(["hello", "world", "!"])),
(f.trim(column("c")), pa.array(["hello", "world", "!"], type=pa.string_view())),
(f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
(f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
(
Expand Down Expand Up @@ -772,9 +792,9 @@ def test_temporal_functions(df):
f.date_trunc(literal("month"), column("d")),
f.datetrunc(literal("day"), column("d")),
f.date_bin(
literal("15 minutes"),
literal("15 minutes").cast(pa.string()),
column("d"),
literal("2001-01-01 00:02:30"),
literal("2001-01-01 00:02:30").cast(pa.string()),
),
f.from_unixtime(literal(1673383974)),
f.to_timestamp(literal("2023-09-07 05:06:14.523952")),
Expand Down Expand Up @@ -836,8 +856,8 @@ def test_case(df):
result = df.collect()
result = result[0]
assert result.column(0) == pa.array([10, 8, 8])
assert result.column(1) == pa.array(["Hola", "Mundo", "!!"])
assert result.column(2) == pa.array(["Hola", "Mundo", None])
assert result.column(1) == pa.array(["Hola", "Mundo", "!!"], type=pa.string_view())
assert result.column(2) == pa.array(["Hola", "Mundo", None], type=pa.string_view())


def test_when_with_no_base(df):
Expand All @@ -855,8 +875,10 @@ def test_when_with_no_base(df):
result = df.collect()
result = result[0]
assert result.column(0) == pa.array([4, 5, 6])
assert result.column(1) == pa.array(["too small", "just right", "too big"])
assert result.column(2) == pa.array(["Hello", None, None])
assert result.column(1) == pa.array(
["too small", "just right", "too big"], type=pa.string_view()
)
assert result.column(2) == pa.array(["Hello", None, None], type=pa.string_view())


def test_regr_funcs_sql(df):
Expand Down Expand Up @@ -999,8 +1021,13 @@ def test_regr_funcs_df(func, expected):

def test_binary_string_functions(df):
df = df.select(
f.encode(column("a"), literal("base64")),
f.decode(f.encode(column("a"), literal("base64")), literal("base64")),
f.encode(column("a").cast(pa.string()), literal("base64").cast(pa.string())),
f.decode(
f.encode(
column("a").cast(pa.string()), literal("base64").cast(pa.string())
),
literal("base64").cast(pa.string()),
),
)
result = df.collect()
assert len(result) == 1
Expand Down
2 changes: 0 additions & 2 deletions python/tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
Join,
JoinType,
JoinConstraint,
CrossJoin,
Union,
Like,
ILike,
Expand Down Expand Up @@ -129,7 +128,6 @@ def test_class_module_is_datafusion():
Join,
JoinType,
JoinConstraint,
CrossJoin,
Union,
Like,
ILike,
Expand Down
7 changes: 7 additions & 0 deletions python/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,13 @@ def test_simple_select(ctx, tmp_path, arr):
batches = ctx.sql("SELECT a AS tt FROM t").collect()
result = batches[0].column(0)

# In DF 43.0.0 we now default to having BinaryView and StringView
# so the array that is saved to the parquet is slightly different
# than the array read. Convert to values for comparison.
if isinstance(result, pa.BinaryViewArray) or isinstance(result, pa.StringViewArray):
arr = arr.tolist()
result = result.tolist()

np.testing.assert_equal(result, arr)


Expand Down
2 changes: 1 addition & 1 deletion src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ impl PySessionContext {
} else {
RuntimeConfig::default()
};
let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
let runtime = Arc::new(RuntimeEnv::try_new(runtime_config)?);
let session_state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
Expand Down
8 changes: 6 additions & 2 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,9 @@ impl PyDataFrame {

#[pyo3(signature = (column, preserve_nulls=true))]
fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyResult<Self> {
let unnest_options = UnnestOptions { preserve_nulls };
// TODO: expose RecursionUnnestOptions
// REF: https://github.com/apache/datafusion/pull/11577
let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
let df = self
.df
.as_ref()
Expand All @@ -420,7 +422,9 @@ impl PyDataFrame {

#[pyo3(signature = (columns, preserve_nulls=true))]
fn unnest_columns(&self, columns: Vec<String>, preserve_nulls: bool) -> PyResult<Self> {
let unnest_options = UnnestOptions { preserve_nulls };
// TODO: expose RecursionUnnestOptions
// REF: https://github.com/apache/datafusion/pull/11577
let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
let df = self
.df
Expand Down
2 changes: 0 additions & 2 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ pub mod column;
pub mod conditional_expr;
pub mod create_memory_table;
pub mod create_view;
pub mod cross_join;
pub mod distinct;
pub mod drop_table;
pub mod empty_relation;
Expand Down Expand Up @@ -775,7 +774,6 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<join::PyJoin>()?;
m.add_class::<join::PyJoinType>()?;
m.add_class::<join::PyJoinConstraint>()?;
m.add_class::<cross_join::PyCrossJoin>()?;
m.add_class::<union::PyUnion>()?;
m.add_class::<unnest::PyUnnest>()?;
m.add_class::<unnest_expr::PyUnnestExpr>()?;
Expand Down
Loading
Loading