From bb3651ddda3e14d535375760129dd6d13ef46327 Mon Sep 17 00:00:00 2001
From: Katya Katsenelenbogen
Date: Thu, 18 Mar 2021 15:45:05 +0200
Subject: [PATCH] add support for multiple key fields (#186)

* add support for multiple key fields

* PR comments

* pr comment
---
 integration/test_aggregation_integration.py | 50 ++++++++++++++++++++-
 integration/test_flow_integration.py        | 49 +++++++++++++++++++-
 storey/aggregations.py                      |  6 ++-
 storey/dtypes.py                            |  6 +--
 storey/flow.py                              | 39 ++++++++++------
 storey/sources.py                           | 40 +++++++++++------
 storey/table.py                             |  9 +++-
 storey/utils.py                             | 17 +++++++
 8 files changed, 180 insertions(+), 36 deletions(-)

diff --git a/integration/test_aggregation_integration.py b/integration/test_aggregation_integration.py
index 15a2e700..fbe76e1f 100644
--- a/integration/test_aggregation_integration.py
+++ b/integration/test_aggregation_integration.py
@@ -2,9 +2,10 @@
 from datetime import datetime, timedelta
 import pytest
 import math
+import pandas as pd
 
 from storey import build_flow, Source, Reduce, Table, V3ioDriver, MapWithState, AggregateByKey, FieldAggregator, \
-    QueryByKey, WriteToTable, Context
+    QueryByKey, WriteToTable, Context, DataframeSource
 from storey.dtypes import SlidingWindows, FixedWindows
 from storey.utils import _split_path
 
@@ -1129,3 +1130,50 @@ def test_write_to_table_reuse(setup_teardown_test):
         controller.terminate()
         actual = controller.await_termination()
         assert actual == expected_results[iteration]
+
+
+def test_aggregate_multiple_keys(setup_teardown_test):
+    current_time = pd.Timestamp.now()
+    data = pd.DataFrame(
+        {
+            "first_name": ["moshe", "yosi", "yosi"],
+            "last_name": ["cohen", "levi", "levi"],
+            "some_data": [1, 2, 3],
+            "time": [current_time - pd.Timedelta(minutes=25), current_time - pd.Timedelta(minutes=30), current_time - pd.Timedelta(minutes=35)]
+        }
+    )
+
+    keys = ['first_name', 'last_name']
+    table = Table(setup_teardown_test, V3ioDriver())
+    controller = build_flow([
+        DataframeSource(data, key_field=keys),
+        AggregateByKey([FieldAggregator("number_of_stuff", "some_data", ["sum"],
+                                        SlidingWindows(['1h'], '10m'))],
+                       table),
+        WriteToTable(table),
+    ]).run()
+
+    controller.await_termination()
+
+    other_table = Table(setup_teardown_test, V3ioDriver())
+    controller = build_flow([
+        Source(),
+        QueryByKey(["number_of_stuff_sum_1h"],
+                   other_table, keys=["first_name", "last_name"]),
+        Reduce([], lambda acc, x: append_return(acc, x)),
+    ]).run()
+
+    controller.emit({'first_name': 'moshe', 'last_name': 'cohen', 'some_data': 4}, ['moshe', 'cohen'])
+    controller.emit({'first_name': 'moshe', 'last_name': 'levi', 'some_data': 5}, ['moshe', 'levi'])
+    controller.emit({'first_name': 'yosi', 'last_name': 'levi', 'some_data': 6}, ['yosi', 'levi'])
+
+    controller.terminate()
+    actual = controller.await_termination()
+    expected_results = [
+        {'number_of_stuff_sum_1h': 1.0, 'first_name': 'moshe', 'last_name': 'cohen', 'some_data': 4},
+        {'number_of_stuff_sum_1h': 0, 'first_name': 'moshe', 'last_name': 'levi', 'some_data': 5},
+        {'number_of_stuff_sum_1h': 5.0, 'first_name': 'yosi', 'last_name': 'levi', 'some_data': 6}
+    ]
+
+    assert actual == expected_results, \
+        f'actual did not match expected. \n actual: {actual} \n expected: {expected_results}'
\ No newline at end of file
diff --git a/integration/test_flow_integration.py b/integration/test_flow_integration.py
index 546f1505..5083b4e0 100644
--- a/integration/test_flow_integration.py
+++ b/integration/test_flow_integration.py
@@ -11,7 +11,9 @@
 import v3io_frames as frames
 
 from storey import Filter, JoinWithV3IOTable, SendToHttp, Map, Reduce, Source, HttpRequest, build_flow, \
-    WriteToV3IOStream, V3ioDriver, WriteToTSDB, Table, JoinWithTable, MapWithState, WriteToTable, DataframeSource
+    WriteToV3IOStream, V3ioDriver, WriteToTSDB, Table, JoinWithTable, MapWithState, WriteToTable, DataframeSource, ReduceToDataFrame, \
+    QueryByKey, AggregateByKey, ReadCSV
+from storey.utils import hash_list
 from .integration_test_utils import V3ioHeaders, append_return, test_base_time, setup_kv_teardown_test, setup_teardown_test, \
     setup_stream_teardown_test
 
@@ -380,6 +382,11 @@ def enrich(event, state):
 
 async def get_kv_item(full_path, key):
     try:
+        if isinstance(key, list):
+            if len(key) >= 3:
+                key = key[0] + "." + str(hash_list(key[1:]))
+            else:
+                key = key[0] + "." + key[1]
         headers = V3ioHeaders()
         container, path = full_path.split('/', 1)
 
@@ -415,3 +422,43 @@ def test_writing_timedelta_key(setup_teardown_test):
     ]).run()
 
     controller.await_termination()
+
+def test_write_multiple_keys_to_v3io_from_df(setup_teardown_test):
+    table = Table(setup_teardown_test, V3ioDriver())
+    data = pd.DataFrame(
+        {
+            "first_name": ["moshe", "yosi"],
+            "last_name": ["cohen", "levi"],
+            "city": ["tel aviv", "ramat gan"],
+        }
+    )
+
+    keys = ['first_name', 'last_name']
+    controller = build_flow([
+        DataframeSource(data, key_field=keys),
+        WriteToTable(table),
+    ]).run()
+    controller.await_termination()
+
+    response = asyncio.run(get_kv_item(setup_teardown_test, ['moshe', 'cohen']))
+    expected = {'city': 'tel aviv', 'first_name': 'moshe', 'last_name': 'cohen'}
+    assert response.status_code == 200
+    assert expected == response.output.item
+
+
+def test_write_multiple_keys_to_v3io_from_csv(setup_teardown_test):
+    table = Table(setup_teardown_test, V3ioDriver())
+
+    controller = build_flow([
+        ReadCSV('tests/test.csv', header=True, key_field=['n1', 'n2'], build_dict=True),
+        WriteToTable(table),
+    ]).run()
+    controller.await_termination()
+
+    response = asyncio.run(get_kv_item(setup_teardown_test, '1.2'))
+
+    expected = {'n1': 1, 'n2': 2, 'n3': 3}
+
+    assert response.status_code == 200
+    assert expected == response.output.item
+
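Note: the tests above require a live V3IO backend. As a quick way to exercise the new composite-key support locally, here is a self-contained sketch that runs the same multi-key aggregation fully in memory. It is not part of this patch and assumes storey's NoopDriver (which skips persistence); names and values are illustrative.

    # Hypothetical sketch: aggregate by the composite key
    # ['first_name', 'last_name'] with no V3IO setup.
    import pandas as pd
    from storey import build_flow, DataframeSource, AggregateByKey, FieldAggregator, Reduce, Table, NoopDriver
    from storey.dtypes import SlidingWindows

    df = pd.DataFrame({
        'first_name': ['moshe', 'yosi', 'yosi'],
        'last_name': ['cohen', 'levi', 'levi'],
        'some_data': [1, 2, 3],
        'time': [pd.Timestamp.now()] * 3,
    })

    table = Table('example', NoopDriver())
    controller = build_flow([
        DataframeSource(df, key_field=['first_name', 'last_name'], time_field='time'),
        AggregateByKey([FieldAggregator('number_of_stuff', 'some_data', ['sum'],
                                        SlidingWindows(['1h'], '10m'))], table),
        Reduce([], lambda acc, x: acc + [x]),
    ]).run()
    # should print each row's body enriched with the running per-key sum
    print(controller.await_termination())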
diff --git a/storey/aggregations.py b/storey/aggregations.py
index 39cf52e8..95ccc629 100644
--- a/storey/aggregations.py
+++ b/storey/aggregations.py
@@ -73,6 +73,8 @@ def f(element, features):
                 self.key_extractor = key
             elif isinstance(key, str):
                 self.key_extractor = lambda element: element.get(key)
+            elif isinstance(key, list):
+                self.key_extractor = lambda element: [element.get(single_key) for single_key in key]
             else:
                 raise TypeError(f'key is expected to be either a callable or string but got {type(key)}')
 
@@ -218,12 +220,12 @@ class QueryByKey(AggregateByKey):
     :param table: A Table object or name for persistence of aggregations. If a table name is provided, it will be looked up in
         the context object passed in kwargs.
     :param key: Key field to aggregate by, accepts either a string representing the key field or a key extracting function.
-        Defaults to the key in the event's metadata. (Optional)
+        Defaults to the key in the event's metadata. Can also be a list of key fields. (Optional)
     :param augmentation_fn: Function that augments the features into the event's body. Defaults to updating a dict. (Optional)
     :param aliases: Dictionary specifying aliases to the enriched columns, of the format `{'col_name': 'new_col_name'}`. (Optional)
     """
 
-    def __init__(self, features: List[str], table: Union[Table, str], key: Union[str, Callable[[Event], object], None] = None,
+    def __init__(self, features: List[str], table: Union[Table, str], key: Union[str, List[str], Callable[[Event], object], None] = None,
                  augmentation_fn: Optional[Callable[[Event, Dict[str, object]], Event]] = None,
                  aliases: Optional[Dict[str, str]] = None, **kwargs):
         self._aggrs = []
diff --git a/storey/dtypes.py b/storey/dtypes.py
index 807bcfd3..0e8c983f 100644
--- a/storey/dtypes.py
+++ b/storey/dtypes.py
@@ -13,7 +13,7 @@ class Event:
     """The basic unit of data in storey. All steps receive and emit events.
 
     :param body: the event payload, or data
-    :param key: Event key. Used by steps that aggregate events by key, such as AggregateByKey. (Optional)
+    :param key: Event key. Used by steps that aggregate events by key, such as AggregateByKey. Can also be a list of keys. (Optional)
     :param time: Event time. Defaults to the time the event was created, UTC. (Optional)
     :param id: Event identifier. Usually a unique identifier. (Optional)
     :param headers: Request headers (HTTP only) (Optional)
@@ -24,7 +24,7 @@
     :type awaitable_result: AwaitableResult (Optional)
     """
 
-    def __init__(self, body: object, key: Optional[str] = None, time: Optional[datetime] = None, id: Optional[str] = None,
+    def __init__(self, body: object, key: Optional[Union[str, List[str]]] = None, time: Optional[datetime] = None, id: Optional[str] = None,
                  headers: Optional[dict] = None, method: Optional[str] = None, path: Optional[str] = '/', content_type=None,
                  awaitable_result=None):
         self.body = body
@@ -48,7 +48,7 @@ def __eq__(self, other):
             self.method == other.method and self.path == other.path and self.content_type == other.content_type  # noqa: E127
 
     def __str__(self):
-        return f'Event(id={self.id}, key={self.key}, time={self.time}, body={self.body})'
+        return f'Event(id={self.id}, key={str(self.key)}, time={self.time}, body={self.body})'
 
     def copy(self, body=None, key=None, time=None, id=None, headers=None, method=None, path=None, content_type=None,
              awaitable_result=None,
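Note: an Event may now carry a list as its key. Python lists are unhashable, so they cannot index the _pending_by_key dict directly; the flow.py change below works around this by keying on str(event.key). A two-line illustration of the underlying constraint (standalone, not part of the patch):

    pending = {}
    try:
        pending[['moshe', 'cohen']] = 'job'   # raises TypeError
    except TypeError as e:
        print(e)                              # unhashable type: 'list'
    pending[str(['moshe', 'cohen'])] = 'job'  # the str() form is hashable
    print(pending)                            # {"['moshe', 'cohen']": 'job'}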
diff --git a/storey/flow.py b/storey/flow.py
index 853888ae..4f1d2e2e 100644
--- a/storey/flow.py
+++ b/storey/flow.py
@@ -685,22 +685,27 @@ async def _worker(self):
                         event = job[0]
                         completed = await job[1]
 
-                        for event in self._pending_by_key[event.key].in_flight:
+                        if isinstance(event.key, list):
+                            event_key = str(event.key)
+                        else:
+                            event_key = event.key
+
+                        for event in self._pending_by_key[event_key].in_flight:
                             await self._handle_completed(event, completed)
-                        self._pending_by_key[event.key].in_flight = []
+                        self._pending_by_key[event_key].in_flight = []
 
                         # If we got more pending events for the same key, process them
-                        if self._pending_by_key[event.key].pending:
-                            self._pending_by_key[event.key].in_flight = self._pending_by_key[event.key].pending
-                            self._pending_by_key[event.key].pending = []
+                        if self._pending_by_key[event_key].pending:
+                            self._pending_by_key[event_key].in_flight = self._pending_by_key[event_key].pending
+                            self._pending_by_key[event_key].pending = []
 
-                            task = self._safe_process_events(self._pending_by_key[event.key].in_flight)
+                            task = self._safe_process_events(self._pending_by_key[event_key].in_flight)
                             tail_position = received_job_count + self._q.qsize()
                             jobs_at_tail = self_sent_jobs.get(tail_position, [])
                             jobs_at_tail.append((event, asyncio.get_running_loop().create_task(task)))
                             self_sent_jobs[tail_position] = jobs_at_tail
                         else:
-                            del self._pending_by_key[event.key]
+                            del self._pending_by_key[event_key]
         except BaseException as ex:
             if event and event is not _termination_obj and event._awaitable_result:
                 event._awaitable_result._set_error(ex)
@@ -726,16 +731,22 @@ async def _do(self, event):
             return await self._do_downstream(_termination_obj)
         else:
            # Initializing the key with 2 lists. One for pending requests and one for requests that an update request has been issued for.
-            if event.key not in self._pending_by_key:
-                self._pending_by_key[event.key] = _PendingEvent()
+            if isinstance(event.key, list):
+                # a list can't be used as a dictionary key
+                event_key = str(event.key)
+            else:
+                event_key = event.key
+
+            if event_key not in self._pending_by_key:
+                self._pending_by_key[event_key] = _PendingEvent()
 
             # If there is a current update in flight for the key, add the event to the pending list. Otherwise update the key.
-            self._pending_by_key[event.key].pending.append(event)
-            if len(self._pending_by_key[event.key].in_flight) == 0:
-                self._pending_by_key[event.key].in_flight = self._pending_by_key[event.key].pending
-                self._pending_by_key[event.key].pending = []
+            self._pending_by_key[event_key].pending.append(event)
+            if len(self._pending_by_key[event_key].in_flight) == 0:
+                self._pending_by_key[event_key].in_flight = self._pending_by_key[event_key].pending
+                self._pending_by_key[event_key].pending = []
 
-                task = self._safe_process_events(self._pending_by_key[event.key].in_flight)
+                task = self._safe_process_events(self._pending_by_key[event_key].in_flight)
                 await self._q.put((event, asyncio.get_running_loop().create_task(task)))
                 if self._worker_awaitable.done():
                     await self._worker_awaitable
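Note: in the sources.py changes below, ReadCSV and DataframeSource accept a list for key_field, producing a list-valued event key. As a self-contained illustration of the CSV path (hypothetical; it mirrors the integration test above but needs only the local filesystem, no V3IO):

    import os
    import tempfile
    from storey import build_flow, ReadCSV, Reduce

    # write a throwaway CSV with header n1,n2,n3 and one row
    with tempfile.NamedTemporaryFile('w', suffix='.csv', delete=False) as f:
        f.write('n1,n2,n3\n1,2,3\n')
        path = f.name

    controller = build_flow([
        ReadCSV(path, header=True, key_field=['n1', 'n2'], build_dict=True),
        # collect the composite key of each event
        Reduce([], lambda acc, event: acc + [event.key], full_event=True),
    ]).run()
    print(controller.await_termination())  # expected: [[1, 2]] (type inference parses ints)
    os.remove(path)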
diff --git a/storey/sources.py b/storey/sources.py
index 6087965c..ae6ef442 100644
--- a/storey/sources.py
+++ b/storey/sources.py
@@ -41,7 +41,7 @@ def _set_error(self, ex):
 
 
 class FlowControllerBase:
-    def __init__(self, key_field: Optional[str], time_field: Optional[str]):
+    def __init__(self, key_field: Optional[Union[str, List[str]]], time_field: Optional[str]):
         self._key_field = key_field
         self._time_field = time_field
         self._current_uuid_base = None
@@ -90,12 +90,12 @@ def __init__(self, emit_fn, await_termination_fn, return_awaitable_result, key_f
         self._await_termination_fn = await_termination_fn
         self._return_awaitable_result = return_awaitable_result
 
-    def emit(self, element: object, key: Optional[str] = None, event_time: Optional[datetime] = None,
+    def emit(self, element: object, key: Optional[Union[str, List[str]]] = None, event_time: Optional[datetime] = None,
             return_awaitable_result: Optional[bool] = None):
        """Emits an event into the associated flow.
 
         :param element: The event data, or payload. To set metadata as well, pass an Event object.
-        :param key: The event key (optional)
+        :param key: The event key, or a list of keys (optional)
         :param event_time: The event time (default to current time, UTC).
         :param return_awaitable_result: Deprecated! An awaitable result object will be returned if a Complete step appears in the flow.
 
@@ -266,12 +266,12 @@ def __init__(self, emit_fn, loop_task, await_result, key_field: Optional[str] =
         self._time_field = time_field
         self._await_result = await_result
 
-    async def emit(self, element: object, key: Optional[str] = None, event_time: Optional[datetime] = None,
+    async def emit(self, element: object, key: Optional[Union[str, List[str]]] = None, event_time: Optional[datetime] = None,
                    await_result: Optional[bool] = None) -> object:
         """Emits an event into the associated flow.
 
         :param element: The event data, or payload. To set metadata as well, pass an Event object.
-        :param key: The event key (optional)
+        :param key: The event key, or a list of keys (optional)
         :param event_time: The event time (default to current time, UTC).
         :param await_result: Deprecated. Will await a result if a Complete step appears in the flow.
 
@@ -430,7 +430,7 @@ class ReadCSV(_IterableSource):
     :parameter build_dict: whether to format each record produced from the input file as a dictionary (as opposed to a list).
         Default to False.
     :parameter key_field: the CSV field to be used as the key for events. May be an int (field index) or string (field name) if
-        with_header is True. Defaults to None (no key).
+        with_header is True. Defaults to None (no key). Can also be a list of fields.
     :parameter time_field: the CSV field to be parsed as the timestamp for events. May be an int (field index) or string (field name)
         if with_header is True. Defaults to None (no timestamp field).
     :parameter timestamp_format: timestamp format as defined in datetime.strptime(). Defaults to ISO-8601 as defined in
         datetime.fromisoformat().
     """
 
     def __init__(self, paths: Union[List[str], str], header: bool = False, build_dict: bool = False,
-                 key_field: Union[int, str, None] = None, time_field: Union[int, str, None] = None,
+                 key_field: Union[int, str, List[int], List[str], None] = None, time_field: Union[int, str, None] = None,
                  timestamp_format: Optional[str] = None, type_inference: bool = True, **kwargs):
         kwargs['paths'] = paths
         kwargs['header'] = header
@@ -558,10 +558,17 @@ def _blocking_io_loop(self):
                     for i in range(len(parsed_line)):
                         element[header[i]] = parsed_line[i]
                 if self._key_field:
-                    key_field = self._key_field
-                    if self._with_header and isinstance(key_field, str):
-                        key_field = field_name_to_index[key_field]
-                    key = parsed_line[key_field]
+                    if isinstance(self._key_field, list):
+                        key = []
+                        for single_key_field in self._key_field:
+                            if self._with_header and isinstance(single_key_field, str):
+                                single_key_field = field_name_to_index[single_key_field]
+                            key.append(parsed_line[single_key_field])
+                    else:
+                        key_field = self._key_field
+                        if self._with_header and isinstance(key_field, str):
+                            key_field = field_name_to_index[key_field]
+                        key = parsed_line[key_field]
                 if self._time_field:
                     time_field = self._time_field
                     if self._with_header and isinstance(time_field, str):
@@ -601,14 +608,14 @@ class DataframeSource(_IterableSource):
     """Use pandas dataframe as input source for a flow.
 
     :param dfs: A pandas dataframe, or dataframes, to be used as input source for the flow.
-    :param key_field: column to be used as key for events.
+    :param key_field: column, or list of columns, to be used as key for events.
     :param time_field: column to be used as time for events.
     :param id_field: column to be used as ID for events.
 
     for additional params, see documentation of :class:`~storey.flow.Flow`
     """
 
-    def __init__(self, dfs: Union[pandas.DataFrame, Iterable[pandas.DataFrame]], key_field: Optional[str] = None,
+    def __init__(self, dfs: Union[pandas.DataFrame, Iterable[pandas.DataFrame]], key_field: Optional[Union[str, List[str]]] = None,
                  time_field: Optional[str] = None, id_field: Optional[str] = None, **kwargs):
         if key_field is not None:
             kwargs['key_field'] = key_field
@@ -637,7 +644,12 @@ async def _run_loop(self):
 
                 key = None
                 if self._key_field:
-                    key = body[self._key_field]
+                    if isinstance(self._key_field, list):
+                        key = []
+                        for key_field in self._key_field:
+                            key.append(body[key_field])
+                    else:
+                        key = body[self._key_field]
                 time = None
                 if self._time_field:
                     time = body[self._time_field]
diff --git a/storey/table.py b/storey/table.py
index b71c7093..1bc86c29 100644
--- a/storey/table.py
+++ b/storey/table.py
@@ -2,7 +2,7 @@
 import copy
 from datetime import datetime
 from .drivers import Driver
-from .utils import _split_path
+from .utils import _split_path, get_hashed_key
 from .dtypes import FieldAggregator, SlidingWindows, FixedWindows
 from .aggregation_utils import is_raw_aggregate, get_virtual_aggregation_func, get_implied_aggregates, get_all_raw_aggregates, \
     get_all_raw_aggregates_with_hidden
@@ -39,6 +39,7 @@ def _update_static_attrs(self, key, data):
             self._set_static_attrs(key, data)
 
     async def _lazy_load_key_with_aggregates(self, key, timestamp=None):
+        key = get_hashed_key(key)
         if self._aggregations_read_only or not self._get_aggregations_attrs(key):
             # Try load from the store, and create a new one only if the key really is new
             aggregate_initial_data, additional_data = await self._storage._load_aggregates_by_key(self._container, self._table_path, key)
@@ -51,6 +52,7 @@ async def _lazy_load_key_with_aggregates(self, key, timestamp=None):
             self._update_static_attrs(key, additional_data)
 
     async def _get_or_load_static_attributes_by_key(self, key, attributes='*'):
+        key = get_hashed_key(key)
         attrs = self._get_static_attrs(key)
         if not attrs:
             res = await self._storage._load_by_key(self._container, self._table_path, key, attributes)
@@ -65,6 +67,7 @@ def _set_aggregation_metadata(self, aggregates: List[FieldAggregator], use_windo
         self._aggregates = aggregates
 
     async def _persist_key(self, key, event_data_to_persist):
+        key = get_hashed_key(key)
         aggr_by_key = self._get_aggregations_attrs(key)
         additional_data_persist = self._get_static_attrs(key)
         if event_data_to_persist:
@@ -186,24 +189,28 @@ def _aggregates_to_schema(self):
         return schema
 
     def _get_aggregations_attrs(self, key):
+        key = get_hashed_key(key)
         if key in self._attrs_cache:
             return self._attrs_cache[key].aggregations
         else:
             return None
 
     def _set_aggregations_attrs(self, key, element):
+        key = get_hashed_key(key)
         if key in self._attrs_cache:
             self._attrs_cache[key].aggregations = element
         else:
             self._attrs_cache[key] = _CacheElement({}, element)
 
     def _get_static_attrs(self, key):
+        key = get_hashed_key(key)
         if key in self._attrs_cache:
             return self._attrs_cache[key].static_attrs
         else:
             return None
 
     def _set_static_attrs(self, key, value):
+        key = get_hashed_key(key)
         if key in self._attrs_cache:
             self._attrs_cache[key].static_attrs = value
         else:
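Note: table.py now funnels every cache and storage access through get_hashed_key (added in utils.py below), so a composite key is flattened to a single storage key before it ever reaches the driver. The convention, sketched here assuming the patch is applied:

    from storey.utils import get_hashed_key

    # plain string or single-element list: used as-is
    print(get_hashed_key('moshe'))                     # moshe
    print(get_hashed_key(['moshe']))                   # moshe
    # two elements: joined with '.'
    print(get_hashed_key(['moshe', 'cohen']))          # moshe.cohen
    # three or more: first element, '.', then a hash of the rest
    print(get_hashed_key(['moshe', 'cohen', 'levi']))  # e.g. moshe.<hash>

This is why test_write_multiple_keys_to_v3io_from_csv above looks the record up under the flattened key '1.2'.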
diff --git a/storey/utils.py b/storey/utils.py
index 72a5d08a..f151fdf0 100644
--- a/storey/utils.py
+++ b/storey/utils.py
@@ -161,3 +161,20 @@ def update_in(obj, key, value):
 
     last_key = parts[-1]
     obj[last_key] = value
+
+
+def hash_list(list_to_hash):
+    str_concatted = ''.join(list_to_hash)
+    hash_value = hash(str_concatted)
+    return hash_value
+
+
+def get_hashed_key(key_list):
+    if isinstance(key_list, list):
+        if len(key_list) >= 3:
+            return str(key_list[0]) + "." + str(hash_list(key_list[1:]))
+        if len(key_list) == 2:
+            return str(key_list[0]) + "." + str(key_list[1])
+        return key_list[0]
+    else:
+        return key_list
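One caveat worth noting on hash_list: Python's built-in hash() is salted per process for strings (PYTHONHASHSEED), so a three-or-more-field key can map to a different storage key on every run, and ''.join assumes all elements are strings. A stable alternative, sketched for illustration only (stable_hash_list is a hypothetical name, not part of this patch):

    import hashlib

    def stable_hash_list(list_to_hash):
        # join as strings, then use a content-derived digest that is
        # identical across processes and runs, unlike the builtin hash()
        joined = ''.join(str(item) for item in list_to_hash)
        return hashlib.sha1(joined.encode('utf-8')).hexdigest()

    print(stable_hash_list(['cohen', 'levi']))  # same output in every process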