diff --git a/storey/sources.py b/storey/sources.py
index 795ee5b4..b6c217da 100644
--- a/storey/sources.py
+++ b/storey/sources.py
@@ -982,8 +982,14 @@ class ParquetSource(DataframeSource):
     :parameter end_filter: datetime. If not None, the results will be filtered by partitions 'filter_column' <= end_filter.
         Default is None.
     :parameter filter_column: Optional. if not None, the results will be filtered by this column and before and/or after
+        datetime column.
     :param key_field: column to be used as key for events. can be list of columns
     :param id_field: column to be used as ID for events.
+    :param additional_filters: additional filters to apply when reading the parquet.
+        Supported operators: '=', '>=', '<=', '>', '<'.
+        Example: [('Product', '=', 'Computer')]
+        For all supported filters, please see:
+        https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html
     """

     def __init__(
@@ -993,6 +999,7 @@ def __init__(
         self,
         paths: Union[str, Iterable[str]],
         start_filter: Optional[datetime] = None,
         end_filter: Optional[datetime] = None,
         filter_column: Optional[str] = None,
+        additional_filters: Optional[list[tuple]] = None,
         **kwargs,
     ):
         if start_filter or end_filter:
@@ -1009,6 +1016,13 @@ def __init__(
             if filter_column is None:
                 raise TypeError("Filter column is required when passing start/end filters")

+        additional_filters = additional_filters or []
+
+        if not all(isinstance(item, tuple) for item in additional_filters):
+            raise ValueError(
+                f"ParquetSource supports additional_filters only as a list of tuples."
+                f" Current additional_filters: {additional_filters}"
+            )
         self._paths = paths
         if isinstance(paths, str):
@@ -1018,6 +1032,7 @@ def __init__(
         self._end_filter = end_filter
         self._filter_column = filter_column
         self._storage_options = kwargs.get("storage_options")
+        self._additional_filters = additional_filters
         super().__init__([], **kwargs)

     def _read_filtered_parquet(self, path):
@@ -1032,6 +1047,8 @@ def _read_filtered_parquet(self, path):
                 filters,
                 self._filter_column,
             )
+            if filters and self._additional_filters:
+                filters[0] += self._additional_filters
             try:
                 return pandas.read_parquet(
                     path,
@@ -1058,6 +1075,8 @@ def _read_filtered_parquet(self, path):
                 filters,
                 self._filter_column,
             )
+            if filters and self._additional_filters:
+                filters[0] += self._additional_filters

             return pandas.read_parquet(
                 path,
@@ -1070,7 +1089,7 @@ def _init(self):
         super()._init()
         self._dfs = []
         for path in self._paths:
-            if self._start_filter or self._end_filter:
+            if self._start_filter or self._end_filter or self._additional_filters:
                 df = self._read_filtered_parquet(path)
             else:
                 df = pandas.read_parquet(path, columns=self._columns, storage_options=self._storage_options)
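For context, a minimal usage sketch of the new additional_filters parameter. The path, filter column, and filter values are illustrative assumptions, not taken from the patch; the imports mirror those used by the tests below.

# Usage sketch: combine the datetime window with an extra partition filter.
# "/path/to/parquet", "timestamp", and the Product predicate are hypothetical.
from datetime import datetime

from storey import ParquetSource, ReduceToDataFrame, build_flow

source = ParquetSource(
    "/path/to/parquet",
    start_filter=datetime(2020, 1, 1),
    end_filter=datetime(2021, 1, 1),
    filter_column="timestamp",
    additional_filters=[("Product", "=", "Computer")],  # must be a flat list of tuples
)
controller = build_flow([source, ReduceToDataFrame()]).run()
df = controller.await_termination()

Note that __init__ rejects anything other than a flat list of tuples (for example, a list of lists), which is exactly what test_filters_type below exercises.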
diff --git a/tests/test_flow.py b/tests/test_flow.py
index 390d73f0..fd8afc6c 100644
--- a/tests/test_flow.py
+++ b/tests/test_flow.py
@@ -17,6 +17,7 @@
 import math
 import os
 import queue
+import tempfile
 import time
 import traceback
 import uuid
@@ -4580,3 +4581,42 @@ def test_empty_filter_result():
         pd.testing.assert_frame_equal(read_back_result, pd.DataFrame({}))
     finally:
         os.remove(path)
+
+
+@pytest.mark.parametrize("include_datetime_filter", [True, False])
+def test_filter_by_filters(include_datetime_filter):
+    columns = ["my_string", "my_time", "my_city"]
+    tel_aviv_data = [
+        ["dina", pd.Timestamp("2019-07-01 00:00:00"), "tel aviv"],
+        ["uri", pd.Timestamp("2018-12-30 09:00:00"), "tel aviv"],
+    ]
+    df = pd.DataFrame(
+        [
+            *tel_aviv_data,
+            ["katya", pd.Timestamp("2020-12-31 14:00:00"), "hod hasharon"],
+        ],
+        columns=columns,
+    )
+    with tempfile.TemporaryDirectory() as temp_dir:
+        df.to_parquet(temp_dir, partition_cols=["my_city"])
+        source_kwargs = {"additional_filters": [("my_city", "=", "tel aviv")]}
+        expected_df = pd.DataFrame(tel_aviv_data, columns=columns)
+        if include_datetime_filter:
+            source_kwargs["start_filter"] = pd.Timestamp("2019-01-01 00:00:00")
+            source_kwargs["end_filter"] = pd.Timestamp("2021-01-01 00:00:00")
+            source_kwargs["filter_column"] = "my_time"
+            expected_df = pd.DataFrame([tel_aviv_data[0]], columns=columns)
+        expected_df.set_index("my_string", inplace=True)
+
+        controller = build_flow([ParquetSource(temp_dir, **source_kwargs), ReduceToDataFrame(index="my_string")]).run()
+        read_back_result = controller.await_termination()
+        pd.testing.assert_frame_equal(read_back_result, expected_df)
+
+
+def test_filters_type():
+    with pytest.raises(ValueError, match="ParquetSource supports additional_filters only as a list of tuples."):
+        ParquetSource(
+            "/my_dir",
+            additional_filters=[[("city", "=", "Tel Aviv")], [("age", ">=", "40")]],
+            filter_column="start_time",
+        )
diff --git a/tests/test_utils.py b/tests/test_utils.py
index b03e26c0..75599197 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -40,3 +40,6 @@ def test_find_filters():
     filters = []
     find_filters([], None, datetime.datetime.max, filters, "time")
     assert filters == [[("time", "<=", datetime.datetime.max)]]
+    filters = []
+    find_filters([], None, None, filters, None)
+    assert filters == [[]]
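To make the merge step concrete, here is a standalone sketch of how the find_filters output and additional_filters combine before being handed to pandas.read_parquet. The column names and datetimes are made up; only the filters[0] += additional_filters step mirrors the patch, and the '>' operator on the start bound is an assumption (the '<=' end bound matches the docstring and test_find_filters above).

# Standalone illustration of the filter-merging step in _read_filtered_parquet.
import datetime

start = datetime.datetime(2019, 1, 1)
end = datetime.datetime(2021, 1, 1)

# find_filters emits pyarrow-style DNF filters: a list of AND-groups
# (here a single group, as in the test_find_filters assertions).
filters = [[("my_time", ">", start), ("my_time", "<=", end)]]
additional_filters = [("my_city", "=", "tel aviv")]

# The patch appends the extra predicates to the first AND-group, so every
# condition must hold for a row to be read:
if filters and additional_filters:
    filters[0] += additional_filters

assert filters == [[("my_time", ">", start), ("my_time", "<=", end), ("my_city", "=", "tel aviv")]]
# pandas.read_parquet(path, filters=filters) then pushes these predicates
# down to the pyarrow dataset reader, which is why partition columns such
# as my_city can be filtered without reading the excluded files.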