diff --git a/py/server/deephaven/agg.py b/py/server/deephaven/agg.py index 45f7c4668ab..8de070d32af 100644 --- a/py/server/deephaven/agg.py +++ b/py/server/deephaven/agg.py @@ -119,7 +119,9 @@ def count_where(col: str, filters: Union[str, Filter, Sequence[str], Sequence[Fi filters. Args: - col (str): the column to hold the counts of each distinct group + col (str): the column to hold the counts of rows that pass the filter condition + filters (Union[str, Filter, Sequence[str], Sequence[Filter]], optional): the filter condition + expression(s) or Filter object(s) Returns: an aggregation diff --git a/py/server/deephaven/updateby.py b/py/server/deephaven/updateby.py index 8c68f5ffbe7..39553c7b9ce 100644 --- a/py/server/deephaven/updateby.py +++ b/py/server/deephaven/updateby.py @@ -3,13 +3,14 @@ # """This module supports building various operations for use with the update-by Table operation.""" from enum import Enum -from typing import Union, List +from typing import Union, List, Sequence import jpy from deephaven import DHError from deephaven._wrapper import JObjectWrapper from deephaven.jcompat import to_sequence +from deephaven.filters import Filter, and_ _JUpdateByOperation = jpy.get_type("io.deephaven.api.updateby.UpdateByOperation") _JBadDataBehavior = jpy.get_type("io.deephaven.api.updateby.BadDataBehavior") @@ -563,6 +564,31 @@ def cum_max(cols: Union[str, List[str]]) -> UpdateByOperation: raise DHError(e, "failed to create a cumulative maximum UpdateByOperation.") from e +def cum_count_where(col: str, filters: Union[str, Filter, Sequence[str], Sequence[Filter]]) -> UpdateByOperation: + """Creates a cumulative count where UpdateByOperation that counts the number of values that pass the provided + filters. + + Args: + col (str): the column to hold the counts of rows that pass the filter condition columns. 
filters (Union[str, Filter, Sequence[str], Sequence[Filter]]): the filter condition
+ + Here are some examples of window values: + | `rev_ticks = 1, fwd_ticks = 0` - contains only the current row + | `rev_ticks = 10, fwd_ticks = 0` - contains 9 previous rows and the current row + | `rev_ticks = 0, fwd_ticks = 10` - contains the following 10 rows, excludes the current row + | `rev_ticks = 10, fwd_ticks = 10` - contains the previous 9 rows, the current row and the 10 rows following + | `rev_ticks = 10, fwd_ticks = -5` - contains 5 rows, beginning at 9 rows before, ending at 5 rows before the + current row (inclusive) + | `rev_ticks = 11, fwd_ticks = -1` - contains 10 rows, beginning at 10 rows before, ending at 1 row before the + current row (inclusive) + | `rev_ticks = -5, fwd_ticks = 10` - contains 5 rows, beginning 5 rows following, ending at 10 rows following the + current row (inclusive) + + Args: + col (str): the column to hold the counts of rows that pass the filter condition columns. + filters (Union[str, Filter, Sequence[str], Sequence[Filter]], optional): the filter condition + expression(s) or Filter object(s) + rev_ticks (int): the look-behind window size (in rows/ticks) + fwd_ticks (int): the look-forward window size (int rows/ticks), default is 0 + + Returns: + an UpdateByOperation + + Raises: + DHError + """ + if not isinstance(col, str): + raise DHError(message="count_where aggregation requires a string value for the 'col' argument.") + filters = to_sequence(filters) + + try: + return UpdateByOperation(j_updateby_op=_JUpdateByOperation.RollingCountWhere(rev_ticks, fwd_ticks, col, and_(filters).j_filter)) + except Exception as e: + raise DHError(e, "failed to create a rolling count_where UpdateByOperation.") from e + + +def rolling_count_where_time(ts_col: str, col: str, filters: Union[str, Filter, Sequence[str], Sequence[Filter]], + rev_time: Union[int, str] = 0, fwd_time: Union[int, str] = 0) -> UpdateByOperation: + """Creates a rolling count where UpdateByOperation that counts the number of values that pass the provided + filters, 
using time as the windowing unit. This function accepts nanoseconds or time strings as the reverse and + forward window parameters. Negative values are allowed and can be used to generate completely forward or completely + reverse windows. A row containing a null in the timestamp column belongs to no window and will not be considered in + the windows of other rows; its output will be null. + + Here are some examples of window values: + | `rev_time = 0, fwd_time = 0` - contains rows that exactly match the current row timestamp + | `rev_time = "PT00:10:00", fwd_time = "0"` - contains rows from 10m before through the current row timestamp ( + inclusive) + | `rev_time = 0, fwd_time = 600_000_000_000` - contains rows from the current row through 10m following the + current row timestamp (inclusive) + | `rev_time = "PT00:10:00", fwd_time = "PT00:10:00"` - contains rows from 10m before through 10m following + the current row timestamp (inclusive) + | `rev_time = "PT00:10:00", fwd_time = "-PT00:05:00"` - contains rows from 10m before through 5m before the + current row timestamp (inclusive), this is a purely backwards looking window + | `rev_time = "-PT00:05:00", fwd_time = "PT00:10:00"` - contains rows from 5m following through 10m + following the current row timestamp (inclusive), this is a purely forwards looking window + + Args: + ts_col (str): the timestamp column for determining the window + col (str): the column to hold the counts of rows that pass the filter condition columns. + filters (Union[str, Filter, Sequence[str], Sequence[Filter]], optional): the filter condition + expression(s) or Filter object(s) + rev_time (int): the look-behind window size, can be expressed as an integer in nanoseconds or a time + interval string, e.g. "PT00:00:00.001" or "PT5M" + fwd_time (int): the look-ahead window size, can be expressed as an integer in nanoseconds or a time + interval string, e.g. 
"PT00:00:00.001" or "PT5M", default is 0 + + Returns: + an UpdateByOperation + + Raises: + DHError + """ + if not isinstance(col, str): + raise DHError(message="count_where aggregation requires a string value for the 'col' argument.") + filters = to_sequence(filters) + + try: + rev_time = _JDateTimeUtils.parseDurationNanos(rev_time) if isinstance(rev_time, str) else rev_time + fwd_time = _JDateTimeUtils.parseDurationNanos(fwd_time) if isinstance(fwd_time, str) else fwd_time + return UpdateByOperation(j_updateby_op=_JUpdateByOperation.RollingCountWhere(ts_col, rev_time, fwd_time, col, and_(filters).j_filter)) + except Exception as e: + raise DHError(e, "failed to create a rolling count_where UpdateByOperation.") from e diff --git a/py/server/tests/test_updateby.py b/py/server/tests/test_updateby.py index 150d41d51e5..e15127bb3af 100644 --- a/py/server/tests/test_updateby.py +++ b/py/server/tests/test_updateby.py @@ -4,13 +4,15 @@ import unittest -from deephaven import read_csv, time_table, update_graph +from deephaven import read_csv, time_table, update_graph, empty_table from deephaven.updateby import BadDataBehavior, MathContext, OperationControl, DeltaControl, ema_tick, ema_time, \ ems_tick, ems_time, emmin_tick, emmin_time, emmax_tick, emmax_time, emstd_tick, emstd_time,\ cum_sum, cum_prod, cum_min, cum_max, forward_fill, delta, rolling_sum_tick, rolling_sum_time, \ rolling_group_tick, rolling_group_time, rolling_avg_tick, rolling_avg_time, rolling_min_tick, rolling_min_time, \ rolling_max_tick, rolling_max_time, rolling_prod_tick, rolling_prod_time, rolling_count_tick, rolling_count_time, \ - rolling_std_tick, rolling_std_time, rolling_wavg_tick, rolling_wavg_time, rolling_formula_tick, rolling_formula_time + rolling_std_tick, rolling_std_time, rolling_wavg_tick, rolling_wavg_time, rolling_formula_tick, rolling_formula_time, \ + cum_count_where, rolling_count_where_tick, rolling_count_where_time +from deephaven.pandas import to_pandas from tests.testbase import 
BaseTestCase from deephaven.execution_context import get_exec_ctx, make_user_exec_ctx @@ -81,6 +83,12 @@ def setUpClass(cls) -> None: delta(cols=simple_op_pairs, delta_control=DeltaControl.ZERO_DOMINATES), ] + cls.simple_ops_one_output = [ + cum_count_where(col='count_1', filters='a > 5'), + cum_count_where(col='count_2', filters='a > 0 && a < 5'), + cum_count_where(col='count_3', filters=['a > 0', 'a < 5']), + ] + # Rolling Operators list shared with test_rolling_ops / test_rolling_ops_proxy cls.rolling_ops = [ # rolling sum @@ -168,6 +176,11 @@ def setUpClass(cls) -> None: rolling_formula_time(formula="formula_be=sum(b) + sum(e)", ts_col="Timestamp", rev_time="PT00:00:10"), rolling_formula_time(formula="formula_be=avg(b) + avg(e)", ts_col="Timestamp", rev_time=10_000_000_000, fwd_time=-10_000_000_00), rolling_formula_time(formula="formula_be=sum(b) + sum(b)", ts_col="Timestamp", rev_time="PT30S", fwd_time="-PT00:00:20"), + rolling_count_where_tick(col="count_1", filters="a > 50", rev_ticks=10), + rolling_count_where_tick(col="count_2", filters=["a > 0", "a <= 50"], rev_ticks=10, fwd_ticks=10), + rolling_count_where_time(col="count_3", filters="a > 50", ts_col="Timestamp", rev_time="PT00:00:10"), + rolling_count_where_time(col="count_4", filters="a > 0 && a <= 50", ts_col="Timestamp", rev_time=10_000_000_000, fwd_time=-10_000_000_00), + rolling_count_where_time(col="count_5", filters="a < 0 || a > 50", ts_col="Timestamp", rev_time="PT30S", fwd_time="-PT00:00:20"), ] @@ -232,6 +245,34 @@ def test_simple_ops_proxy(self): with update_graph.exclusive_lock(self.test_update_graph): self.assertEqual(ct.size, rct.size) + def test_simple_ops_one_output(self): + for op in self.simple_ops_one_output: + with self.subTest(op): + for t in (self.static_table, self.ticking_table): + rt = t.update_by(ops=op, by="e") + self.assertTrue(rt.is_refreshing is t.is_refreshing) + self.assertEqual(len(rt.definition), 1 + len(t.definition)) + with 
update_graph.exclusive_lock(self.test_update_graph): + self.assertEqual(rt.size, t.size) + + def test_simple_ops_one_output_proxy(self): + pt_proxies = [self.static_table.partition_by("c").proxy(), + self.ticking_table.partition_by("c").proxy(), + ] + + for op in self.simple_ops_one_output: + with self.subTest(op): + for pt_proxy in pt_proxies: + rt_proxy = pt_proxy.update_by(ops=op, by="e") + + self.assertTrue(rt_proxy.is_refreshing is pt_proxy.is_refreshing) + self.assertEqual(len(rt_proxy.target.constituent_table_columns), + 1 + len(pt_proxy.target.constituent_table_columns)) + + for ct, rct in zip(pt_proxy.target.constituent_tables, rt_proxy.target.constituent_tables): + with update_graph.exclusive_lock(self.test_update_graph): + self.assertEqual(ct.size, rct.size) + def test_rolling_ops(self): # Test rolling operators that produce 2 output columns for op in self.rolling_ops: @@ -293,5 +334,55 @@ def test_multiple_ops(self): with update_graph.exclusive_lock(self.test_update_graph): self.assertEqual(rt.size, t.size) + def test_cum_count_where_output(self): + """ + Test and validation of the cum_count_where feature + """ + test_table = empty_table(4).update(["a=ii", "b=ii%2"]) + count_aggs = [ + cum_count_where(col="count1", filters="a >= 1"), + cum_count_where(col="count2", filters="a >= 1 && b == 0"), + ] + result_table = test_table.update_by(ops=count_aggs) + self.assertEqual(result_table.size, 4) + + # get the table as a local pandas dataframe + df = to_pandas(result_table) + # assert the values meet expectations + self.assertEqual(df.loc[0, "count1"], 0) + self.assertEqual(df.loc[1, "count1"], 1) + self.assertEqual(df.loc[2, "count1"], 2) + self.assertEqual(df.loc[3, "count1"], 3) + + self.assertEqual(df.loc[0, "count2"], 0) + self.assertEqual(df.loc[1, "count2"], 0) + self.assertEqual(df.loc[2, "count2"], 1) + self.assertEqual(df.loc[3, "count2"], 1) + + def test_rolling_count_where_output(self): + """ + Test and validation of the cum_count_where feature + 
""" + test_table = empty_table(4).update(["a=ii", "b=ii%2"]) + count_aggs = [ + rolling_count_where_tick(col="count1", filters="a >= 1", rev_ticks=2), + rolling_count_where_tick(col="count2", filters="a >= 1 && b == 0", rev_ticks=2), + ] + result_table = test_table.update_by(ops=count_aggs) + self.assertEqual(result_table.size, 4) + + # get the table as a local pandas dataframe + df = to_pandas(result_table) + # assert the values meet expectations + self.assertEqual(df.loc[0, "count1"], 0) + self.assertEqual(df.loc[1, "count1"], 1) + self.assertEqual(df.loc[2, "count1"], 2) + self.assertEqual(df.loc[3, "count1"], 2) + + self.assertEqual(df.loc[0, "count2"], 0) + self.assertEqual(df.loc[1, "count2"], 0) + self.assertEqual(df.loc[2, "count2"], 1) + self.assertEqual(df.loc[3, "count2"], 1) + if __name__ == '__main__': unittest.main()