Skip to content

Commit

Permalink
feat: add cardinality function to calculate total elements in an array
Browse files Browse the repository at this point in the history
  • Loading branch information
kosiew committed Oct 29, 2024
1 parent 0bc2f31 commit 0ce6cec
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 0 deletions.
14 changes: 14 additions & 0 deletions docs/source/user-guide/common-operations/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,20 @@ This function returns a boolean indicating whether the array is empty.
In this example, the `is_empty` column will contain `True` for the first row and `False` for the second row.

To get the total number of elements in an array, you can use the function :py:func:`datafusion.functions.cardinality`.
This function returns an integer indicating the total number of elements in the array.

.. ipython:: python
from datafusion import SessionContext, col
from datafusion.functions import cardinality
ctx = SessionContext()
df = ctx.from_pydict({"a": [[1, 2, 3], [4, 5, 6]]})
df.select(cardinality(col("a")).alias("num_elements"))
In this example, the `num_elements` column will contain `3` for both rows.

Structs
-------

Expand Down
6 changes: 6 additions & 0 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@
"find_in_set",
"first_value",
"flatten",
"cardinality",
"floor",
"from_unixtime",
"gcd",
Expand Down Expand Up @@ -1516,6 +1517,11 @@ def flatten(array: Expr) -> Expr:
return Expr(f.flatten(array.expr))


def cardinality(array: Expr) -> Expr:
"""Returns the total number of elements in the array."""
return Expr(f.cardinality(array.expr))


# aggregate functions
def approx_distinct(
expression: Expr,
Expand Down
18 changes: 18 additions & 0 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,24 @@ def test_array_function_flatten():
)


def test_array_function_cardinality():
data = [[1, 2, 3], [4, 4, 5, 6]]
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
df = ctx.create_dataframe([[batch]])

stmt = f.cardinality(column("arr"))
py_expr = [len(arr) for arr in data] # Expected lengths: [3, 3]
# assert py_expr lengths

query_result = df.select(stmt).collect()[0].column(0)

for a, b in zip(query_result, py_expr):
np.testing.assert_array_equal(
np.array([a.as_py()], dtype=int), np.array([b], dtype=int)
)


@pytest.mark.parametrize(
("stmt", "py_expr"),
[
Expand Down
2 changes: 2 additions & 0 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ array_fn!(array_intersect, first_array second_array);
array_fn!(array_union, array1 array2);
array_fn!(array_except, first_array second_array);
array_fn!(array_resize, array size value);
array_fn!(cardinality, array);
array_fn!(flatten, array);
array_fn!(range, start stop step);

Expand Down Expand Up @@ -1030,6 +1031,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(array_sort))?;
m.add_wrapped(wrap_pyfunction!(array_slice))?;
m.add_wrapped(wrap_pyfunction!(flatten))?;
m.add_wrapped(wrap_pyfunction!(cardinality))?;

// Window Functions
m.add_wrapped(wrap_pyfunction!(lead))?;
Expand Down

0 comments on commit 0ce6cec

Please sign in to comment.