Skip to content

Commit

Permalink
add support for multiple key fields (#186)
Browse files Browse the repository at this point in the history
* add support for multiple key fields

* PR comments

* pr comment
  • Loading branch information
katyakats authored Mar 18, 2021
1 parent 92aecb8 commit bb3651d
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 36 deletions.
50 changes: 49 additions & 1 deletion integration/test_aggregation_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from datetime import datetime, timedelta
import pytest
import math
import pandas as pd

from storey import build_flow, Source, Reduce, Table, V3ioDriver, MapWithState, AggregateByKey, FieldAggregator, \
QueryByKey, WriteToTable, Context
QueryByKey, WriteToTable, Context, DataframeSource

from storey.dtypes import SlidingWindows, FixedWindows
from storey.utils import _split_path
Expand Down Expand Up @@ -1129,3 +1130,50 @@ def test_write_to_table_reuse(setup_teardown_test):
controller.terminate()
actual = controller.await_termination()
assert actual == expected_results[iteration]


def test_aggregate_multiple_keys(setup_teardown_test):
current_time = pd.Timestamp.now()
data = pd.DataFrame(
{
"first_name": ["moshe", "yosi", "yosi"],
"last_name": ["cohen", "levi", "levi"],
"some_data": [1, 2, 3],
"time": [current_time - pd.Timedelta(minutes=25), current_time - pd.Timedelta(minutes=30), current_time - pd.Timedelta(minutes=35)]
}
)

keys = ['first_name', 'last_name']
table = Table(setup_teardown_test, V3ioDriver())
controller = build_flow([
DataframeSource(data, key_field=keys),
AggregateByKey([FieldAggregator("number_of_stuff", "some_data", ["sum"],
SlidingWindows(['1h'], '10m'))],
table),
WriteToTable(table),
]).run()

actual = controller.await_termination()

other_table = Table(setup_teardown_test, V3ioDriver())
controller = build_flow([
Source(),
QueryByKey(["number_of_stuff_sum_1h"],
other_table, keys=["first_name", "last_name"]),
Reduce([], lambda acc, x: append_return(acc, x)),
]).run()

controller.emit({'first_name': 'moshe', 'last_name': 'cohen', 'some_data': 4}, ['moshe', 'cohen'])
controller.emit({'first_name': 'moshe', 'last_name': 'levi', 'some_data': 5}, ['moshe', 'levi'])
controller.emit({'first_name': 'yosi', 'last_name': 'levi', 'some_data': 6}, ['yosi', 'levi'])

controller.terminate()
actual = controller.await_termination()
expected_results = [
{'number_of_stuff_sum_1h': 1.0, 'first_name': 'moshe', 'last_name': 'cohen', 'some_data': 4},
{'number_of_stuff_sum_1h': 0, 'first_name': 'moshe', 'last_name': 'levi', 'some_data': 5},
{'number_of_stuff_sum_1h': 5.0, 'first_name': 'yosi', 'last_name': 'levi', 'some_data': 6}
]

assert actual == expected_results, \
f'actual did not match expected. \n actual: {actual} \n expected: {expected_results}'
49 changes: 48 additions & 1 deletion integration/test_flow_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import v3io_frames as frames

from storey import Filter, JoinWithV3IOTable, SendToHttp, Map, Reduce, Source, HttpRequest, build_flow, \
WriteToV3IOStream, V3ioDriver, WriteToTSDB, Table, JoinWithTable, MapWithState, WriteToTable, DataframeSource
WriteToV3IOStream, V3ioDriver, WriteToTSDB, Table, JoinWithTable, MapWithState, WriteToTable, DataframeSource, ReduceToDataFrame, \
QueryByKey, AggregateByKey, ReadCSV
from storey.utils import hash_list
from .integration_test_utils import V3ioHeaders, append_return, test_base_time, setup_kv_teardown_test, setup_teardown_test, \
setup_stream_teardown_test

Expand Down Expand Up @@ -380,6 +382,11 @@ def enrich(event, state):

async def get_kv_item(full_path, key):
try:
if isinstance(key, list):
if len(key) >= 3:
key = key[0] + "." + str(hash_list(key[1:]))
else:
key = key[0] + "." + key[1]
headers = V3ioHeaders()
container, path = full_path.split('/', 1)

Expand Down Expand Up @@ -415,3 +422,43 @@ def test_writing_timedelta_key(setup_teardown_test):
]).run()
controller.await_termination()


def test_write_multiple_keys_to_v3io_from_df(setup_teardown_test):
table = Table(setup_teardown_test, V3ioDriver())
data = pd.DataFrame(
{
"first_name": ["moshe", "yosi"],
"last_name": ["cohen", "levi"],
"city": ["tel aviv", "ramat gan"],
}
)

keys = ['first_name', 'last_name']
controller = build_flow([
DataframeSource(data, key_field=keys),
WriteToTable(table),
]).run()
controller.await_termination()

response = asyncio.run(get_kv_item(setup_teardown_test, ['moshe', 'cohen']))
expected = {'city': 'tel aviv', 'first_name': 'moshe', 'last_name': 'cohen'}
assert response.status_code == 200
assert expected == response.output.item


def test_write_multiple_keys_to_v3io_from_csv(setup_teardown_test):
table = Table(setup_teardown_test, V3ioDriver())

controller = build_flow([
ReadCSV('tests/test.csv', header=True, key_field=['n1', 'n2'], build_dict=True),
WriteToTable(table),
]).run()
controller.await_termination()

response = asyncio.run(get_kv_item(setup_teardown_test, '1.2'))

expected = {'n1': 1, 'n2': 2, 'n3': 3}

assert response.status_code == 200
assert expected == response.output.item

6 changes: 4 additions & 2 deletions storey/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def f(element, features):
self.key_extractor = key
elif isinstance(key, str):
self.key_extractor = lambda element: element.get(key)
elif isinstance(key, list):
self.key_extractor = lambda element: [element.get(single_key) for single_key in key]
else:
raise TypeError(f'key is expected to be either a callable or string but got {type(key)}')

Expand Down Expand Up @@ -218,12 +220,12 @@ class QueryByKey(AggregateByKey):
:param table: A Table object or name for persistence of aggregations.
If a table name is provided, it will be looked up in the context object passed in kwargs.
:param key: Key field to aggregate by, accepts either a string representing the key field or a key extracting function.
Defaults to the key in the event's metadata. (Optional)
Defaults to the key in the event's metadata. (Optional). Can be list of keys
:param augmentation_fn: Function that augments the features into the event's body. Defaults to updating a dict. (Optional)
:param aliases: Dictionary specifying aliases to the enriched columns, of the format `{'col_name': 'new_col_name'}`. (Optional)
"""

def __init__(self, features: List[str], table: Union[Table, str], key: Union[str, Callable[[Event], object], None] = None,
def __init__(self, features: List[str], table: Union[Table, str], key: Union[str, List[str], Callable[[Event], object], None] = None,
augmentation_fn: Optional[Callable[[Event, Dict[str, object]], Event]] = None,
aliases: Optional[Dict[str, str]] = None, **kwargs):
self._aggrs = []
Expand Down
6 changes: 3 additions & 3 deletions storey/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Event:
"""The basic unit of data in storey. All steps receive and emit events.
:param body: the event payload, or data
:param key: Event key. Used by steps that aggregate events by key, such as AggregateByKey. (Optional)
:param key: Event key. Used by steps that aggregate events by key, such as AggregateByKey. (Optional). Can be list
:param time: Event time. Defaults to the time the event was created, UTC. (Optional)
:param id: Event identifier. Usually a unique identifier. (Optional)
:param headers: Request headers (HTTP only) (Optional)
Expand All @@ -24,7 +24,7 @@ class Event:
:type awaitable_result: AwaitableResult (Optional)
"""

def __init__(self, body: object, key: Optional[str] = None, time: Optional[datetime] = None, id: Optional[str] = None,
def __init__(self, body: object, key: Optional[Union[str, List[str]]] = None, time: Optional[datetime] = None, id: Optional[str] = None,
headers: Optional[dict] = None, method: Optional[str] = None, path: Optional[str] = '/',
content_type=None, awaitable_result=None):
self.body = body
Expand All @@ -48,7 +48,7 @@ def __eq__(self, other):
self.method == other.method and self.path == other.path and self.content_type == other.content_type # noqa: E127

def __str__(self):
return f'Event(id={self.id}, key={self.key}, time={self.time}, body={self.body})'
return f'Event(id={self.id}, key={str(self.key)}, time={self.time}, body={self.body})'

def copy(self, body=None, key=None, time=None, id=None, headers=None, method=None, path=None, content_type=None,
awaitable_result=None,
Expand Down
39 changes: 25 additions & 14 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,22 +685,27 @@ async def _worker(self):
event = job[0]
completed = await job[1]

for event in self._pending_by_key[event.key].in_flight:
if isinstance(event.key, list):
event_key = str(event.key)
else:
event_key = event.key

for event in self._pending_by_key[event_key].in_flight:
await self._handle_completed(event, completed)
self._pending_by_key[event.key].in_flight = []
self._pending_by_key[event_key].in_flight = []

# If we got more pending events for the same key process them
if self._pending_by_key[event.key].pending:
self._pending_by_key[event.key].in_flight = self._pending_by_key[event.key].pending
self._pending_by_key[event.key].pending = []
if self._pending_by_key[event_key].pending:
self._pending_by_key[event_key].in_flight = self._pending_by_key[event_key].pending
self._pending_by_key[event_key].pending = []

task = self._safe_process_events(self._pending_by_key[event.key].in_flight)
task = self._safe_process_events(self._pending_by_key[event_key].in_flight)
tail_position = received_job_count + self._q.qsize()
jobs_at_tail = self_sent_jobs.get(tail_position, [])
jobs_at_tail.append((event, asyncio.get_running_loop().create_task(task)))
self_sent_jobs[tail_position] = jobs_at_tail
else:
del self._pending_by_key[event.key]
del self._pending_by_key[event_key]
except BaseException as ex:
if event and event is not _termination_obj and event._awaitable_result:
event._awaitable_result._set_error(ex)
Expand All @@ -726,16 +731,22 @@ async def _do(self, event):
return await self._do_downstream(_termination_obj)
else:
# Initializing the key with 2 lists. One for pending requests and one for requests that an update request has been issued for.
if event.key not in self._pending_by_key:
self._pending_by_key[event.key] = _PendingEvent()
if isinstance(event.key, list):
# list can't be key in a dictionary
event_key = str(event.key)
else:
event_key = event.key

if event_key not in self._pending_by_key:
self._pending_by_key[event_key] = _PendingEvent()

# If there is a current update in flight for the key, add the event to the pending list. Otherwise update the key.
self._pending_by_key[event.key].pending.append(event)
if len(self._pending_by_key[event.key].in_flight) == 0:
self._pending_by_key[event.key].in_flight = self._pending_by_key[event.key].pending
self._pending_by_key[event.key].pending = []
self._pending_by_key[event_key].pending.append(event)
if len(self._pending_by_key[event_key].in_flight) == 0:
self._pending_by_key[event_key].in_flight = self._pending_by_key[event_key].pending
self._pending_by_key[event_key].pending = []

task = self._safe_process_events(self._pending_by_key[event.key].in_flight)
task = self._safe_process_events(self._pending_by_key[event_key].in_flight)
await self._q.put((event, asyncio.get_running_loop().create_task(task)))
if self._worker_awaitable.done():
await self._worker_awaitable
Expand Down
40 changes: 26 additions & 14 deletions storey/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _set_error(self, ex):


class FlowControllerBase:
def __init__(self, key_field: Optional[str], time_field: Optional[str]):
def __init__(self, key_field: Optional[Union[str, List[str]]], time_field: Optional[str]):
self._key_field = key_field
self._time_field = time_field
self._current_uuid_base = None
Expand Down Expand Up @@ -90,12 +90,12 @@ def __init__(self, emit_fn, await_termination_fn, return_awaitable_result, key_f
self._await_termination_fn = await_termination_fn
self._return_awaitable_result = return_awaitable_result

def emit(self, element: object, key: Optional[str] = None, event_time: Optional[datetime] = None,
def emit(self, element: object, key: Optional[Union[str, List[str]]] = None, event_time: Optional[datetime] = None,
return_awaitable_result: Optional[bool] = None):
"""Emits an event into the associated flow.
:param element: The event data, or payload. To set metadata as well, pass an Event object.
:param key: The event key (optional)
:param key: The event key(s) (optional) #add to async
:param event_time: The event time (default to current time, UTC).
:param return_awaitable_result: Deprecated! An awaitable result object will be returned if a Complete step appears in the flow.
Expand Down Expand Up @@ -266,12 +266,12 @@ def __init__(self, emit_fn, loop_task, await_result, key_field: Optional[str] =
self._time_field = time_field
self._await_result = await_result

async def emit(self, element: object, key: Optional[str] = None, event_time: Optional[datetime] = None,
async def emit(self, element: object, key: Optional[Union[str, List[str]]] = None, event_time: Optional[datetime] = None,
await_result: Optional[bool] = None) -> object:
"""Emits an event into the associated flow.
:param element: The event data, or payload. To set metadata as well, pass an Event object.
:param key: The event key (optional)
:param key: The event key(s) (optional)
:param event_time: The event time (default to current time, UTC).
:param await_result: Deprecated. Will await a result if a Complete step appears in the flow.
Expand Down Expand Up @@ -430,7 +430,7 @@ class ReadCSV(_IterableSource):
:parameter build_dict: whether to format each record produced from the input file as a dictionary (as opposed to a list).
Default to False.
:parameter key_field: the CSV field to be use as the key for events. May be an int (field index) or string (field name) if
with_header is True. Defaults to None (no key).
with_header is True. Defaults to None (no key). Can be a list of keys
:parameter time_field: the CSV field to be parsed as the timestamp for events. May be an int (field index) or string (field name)
if with_header is True. Defaults to None (no timestamp field).
:parameter timestamp_format: timestamp format as defined in datetime.strptime(). Default to ISO-8601 as defined in
Expand All @@ -442,7 +442,7 @@ class ReadCSV(_IterableSource):
"""

def __init__(self, paths: Union[List[str], str], header: bool = False, build_dict: bool = False,
key_field: Union[int, str, None] = None, time_field: Union[int, str, None] = None,
key_field: Union[int, str, List[int], List[str], None] = None, time_field: Union[int, str, None] = None,
timestamp_format: Optional[str] = None, type_inference: bool = True, **kwargs):
kwargs['paths'] = paths
kwargs['header'] = header
Expand Down Expand Up @@ -558,10 +558,17 @@ def _blocking_io_loop(self):
for i in range(len(parsed_line)):
element[header[i]] = parsed_line[i]
if self._key_field:
key_field = self._key_field
if self._with_header and isinstance(key_field, str):
key_field = field_name_to_index[key_field]
key = parsed_line[key_field]
if isinstance(self._key_field, list):
key = []
for single_key_field in self._key_field:
if self._with_header and isinstance(single_key_field, str):
single_key_field = field_name_to_index[single_key_field]
key.append(parsed_line[single_key_field])
else:
key_field = self._key_field
if self._with_header and isinstance(key_field, str):
key_field = field_name_to_index[key_field]
key = parsed_line[key_field]
if self._time_field:
time_field = self._time_field
if self._with_header and isinstance(time_field, str):
Expand Down Expand Up @@ -601,14 +608,14 @@ class DataframeSource(_IterableSource):
"""Use pandas dataframe as input source for a flow.
:param dfs: A pandas dataframe, or dataframes, to be used as input source for the flow.
:param key_field: column to be used as key for events.
:param key_field: column to be used as key for events. can be list of columns
:param time_field: column to be used as time for events.
:param id_field: column to be used as ID for events.
for additional params, see documentation of :class:`~storey.flow.Flow`
"""

def __init__(self, dfs: Union[pandas.DataFrame, Iterable[pandas.DataFrame]], key_field: Optional[str] = None,
def __init__(self, dfs: Union[pandas.DataFrame, Iterable[pandas.DataFrame]], key_field: Optional[Union[str, List[str]]] = None,
time_field: Optional[str] = None, id_field: Optional[str] = None, **kwargs):
if key_field is not None:
kwargs['key_field'] = key_field
Expand Down Expand Up @@ -637,7 +644,12 @@ async def _run_loop(self):

key = None
if self._key_field:
key = body[self._key_field]
if isinstance(self._key_field, list):
key = []
for key_field in self._key_field:
key.append(body[key_field])
else:
key = body[self._key_field]
time = None
if self._time_field:
time = body[self._time_field]
Expand Down
Loading

0 comments on commit bb3651d

Please sign in to comment.