[Sources] Add filters to ParquetSource (#514)
* add filters to parquetsource + tests + double usage validator

* change error message

* fix test

* halfway

* finished filters in parquetsource + validators + tests.

* merge all filters validators to 1 test.

* lint - there is a difference between local lint and ci lint...

* split input error tests.

* fix message

* rename filters name

* rename var + add docs.

* fix PR comments

* lint

* fix test.

* remove default values

* supports datetime in additional_filters

* Update storey/sources.py

Co-authored-by: Gal Topper <[email protected]>

* remove double usage of filter column limitation

* remove combine_filters.

---------

Co-authored-by: Gal Topper <[email protected]>
tomerm-iguazio and gtopper authored Apr 25, 2024
1 parent 5c9ff31 commit 19d0d11
Showing 3 changed files with 63 additions and 1 deletion.
21 changes: 20 additions & 1 deletion storey/sources.py
@@ -982,8 +982,14 @@ class ParquetSource(DataframeSource):
:parameter end_filter: datetime. If not None, the results will be filtered by partitions
'filter_column' <= end_filter. Default is None.
:parameter filter_column: Optional. If not None, the results will be filtered by this column according to the
start_filter and/or end_filter datetimes.
:param key_field: column to be used as key for events. can be list of columns
:param id_field: column to be used as ID for events.
:param additional_filters: Additional filters to apply while reading the parquet.
Supported operators: '=', '>=', '<=', '>', '<'.
Example: [('Product', '=', 'Computer')]
For all supported filters, please see:
https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html
"""

def __init__(
@@ -993,6 +999,7 @@ def __init__(
start_filter: Optional[datetime] = None,
end_filter: Optional[datetime] = None,
filter_column: Optional[str] = None,
additional_filters: Optional[list[tuple]] = None,
**kwargs,
):
if start_filter or end_filter:
@@ -1009,6 +1016,13 @@

if filter_column is None:
raise TypeError("Filter column is required when passing start/end filters")
additional_filters = additional_filters or []

if not all(isinstance(item, tuple) for item in additional_filters):
raise ValueError(
f"ParquetSource supports additional_filters only as a list of tuples."
f" Current additional_filters: {additional_filters}"
)

self._paths = paths
if isinstance(paths, str):
@@ -1018,6 +1032,7 @@
self._end_filter = end_filter
self._filter_column = filter_column
self._storage_options = kwargs.get("storage_options")
self._additional_filters = additional_filters
super().__init__([], **kwargs)

def _read_filtered_parquet(self, path):
@@ -1032,6 +1047,8 @@ def _read_filtered_parquet(self, path):
filters,
self._filter_column,
)
if filters and self._additional_filters:
filters[0] += self._additional_filters
try:
return pandas.read_parquet(
path,
@@ -1058,6 +1075,8 @@ def _read_filtered_parquet(self, path):
filters,
self._filter_column,
)
if filters and self._additional_filters:
filters[0] += self._additional_filters

return pandas.read_parquet(
path,
@@ -1070,7 +1089,7 @@ def _init(self):
super()._init()
self._dfs = []
for path in self._paths:
if self._start_filter or self._end_filter:
if self._start_filter or self._end_filter or self._additional_filters:
df = self._read_filtered_parquet(path)
else:
df = pandas.read_parquet(path, columns=self._columns, storage_options=self._storage_options)
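A minimal usage sketch of the new parameter, mirroring the test added below (hedged: the path, column names, and timestamps are illustrative, and the top-level storey imports are assumed):

import pandas as pd

from storey import ParquetSource, ReduceToDataFrame, build_flow

# Keep only rows whose my_city partition equals "tel aviv", restricted to a time window.
# additional_filters uses pyarrow-style (column, op, value) tuples.
source = ParquetSource(
    "/path/to/data",  # illustrative path
    start_filter=pd.Timestamp("2019-01-01 00:00:00"),
    end_filter=pd.Timestamp("2021-01-01 00:00:00"),
    filter_column="my_time",
    additional_filters=[("my_city", "=", "tel aviv")],
)
controller = build_flow([source, ReduceToDataFrame(index="my_string")]).run()
result_df = controller.await_termination()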
40 changes: 40 additions & 0 deletions tests/test_flow.py
@@ -17,6 +17,7 @@
import math
import os
import queue
import tempfile
import time
import traceback
import uuid
@@ -4580,3 +4581,42 @@ def test_empty_filter_result():
pd.testing.assert_frame_equal(read_back_result, pd.DataFrame({}))
finally:
os.remove(path)


@pytest.mark.parametrize("include_datetime_filter", [True, False])
def test_filter_by_filters(include_datetime_filter):
columns = ["my_string", "my_time", "my_city"]
tel_aviv_data = [
["dina", pd.Timestamp("2019-07-01 00:00:00"), "tel aviv"],
["uri", pd.Timestamp("2018-12-30 09:00:00"), "tel aviv"],
]
df = pd.DataFrame(
[
*tel_aviv_data,
["katya", pd.Timestamp("2020-12-31 14:00:00"), "hod hasharon"],
],
columns=columns,
)
with tempfile.TemporaryDirectory() as temp_dir:
df.to_parquet(temp_dir, partition_cols=["my_city"])
source_kwargs = {"additional_filters": [("my_city", "=", "tel aviv")]}
expected_df = pd.DataFrame(tel_aviv_data, columns=columns)
if include_datetime_filter:
source_kwargs["start_filter"] = pd.Timestamp("2019-01-01 00:00:00")
source_kwargs["end_filter"] = pd.Timestamp("2021-01-01 00:00:00")
source_kwargs["filter_column"] = "my_time"
expected_df = pd.DataFrame([tel_aviv_data[0]], columns=columns)
expected_df.set_index("my_string", inplace=True)

controller = build_flow([ParquetSource(temp_dir, **source_kwargs), ReduceToDataFrame(index="my_string")]).run()
read_back_result = controller.await_termination()
pd.testing.assert_frame_equal(read_back_result, expected_df)


def test_filters_type():
with pytest.raises(ValueError, match="ParquetSource supports additional_filters only as a list of tuples."):
ParquetSource(
"/my_dir",
additional_filters=[[("city", "=", "Tel Aviv")], [("age", ">=", "40")]],
filter_column="start_time",
)
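For contrast with the rejected nested form in test_filters_type above, a sketch of the shape the validator accepts (hedged: column names and values are made up, and constructing the source without running it is assumed to have no side effects):

from storey import ParquetSource

# Accepted: a flat list of (column, op, value) tuples.
ParquetSource("/my_dir", additional_filters=[("city", "=", "Tel Aviv"), ("age", ">=", 40)])

# Rejected: wrapping the tuples in inner lists raises
# ValueError: ParquetSource supports additional_filters only as a list of tuples.
# ParquetSource("/my_dir", additional_filters=[[("city", "=", "Tel Aviv")]])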
3 changes: 3 additions & 0 deletions tests/test_utils.py
@@ -40,3 +40,6 @@ def test_find_filters():
filters = []
find_filters([], None, datetime.datetime.max, filters, "time")
assert filters == [[("time", "<=", datetime.datetime.max)]]
filters = []
find_filters([], None, None, filters, None)
assert filters == [[]]
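
A short sketch of how these pieces combine (hedged, based on the test above and the __init__/_read_filtered_parquet changes): find_filters emits filters in pyarrow's list-of-AND-groups form, and ParquetSource appends additional_filters to the first AND-group, so they are ANDed with the start/end range.

import datetime

# Range filter produced by find_filters, as in the test above: one AND-group.
filters = [[("time", "<=", datetime.datetime.max)]]

# What ParquetSource does with additional_filters:
filters[0] += [("city", "=", "Tel Aviv")]
# filters is now [[("time", "<=", datetime.datetime.max), ("city", "=", "Tel Aviv")]]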
