From 9a1b9f8fb67bddf852ffdbc3374eafd2d4fbb4e8 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Thu, 9 May 2024 15:45:13 +0800 Subject: [PATCH] Allow user to specify table on a per-event basis in `TDEngineTarget` (#519) * Allow user to specify table on a per-event basis in `TDEngineTarget` [ML-6367](https://iguazio.atlassian.net/browse/ML-6367) * Fix test time zone handling * Allow for user to override `drop_key_field` if they really want to --- integration/test_tdengine.py | 59 ++++++++++++++++++++++-------------- storey/flow.py | 16 ++++++---- storey/targets.py | 30 +++++++++++++++--- 3 files changed, 72 insertions(+), 33 deletions(-) diff --git a/integration/test_tdengine.py b/integration/test_tdengine.py index 43421c25..00f8b749 100644 --- a/integration/test_tdengine.py +++ b/integration/test_tdengine.py @@ -2,6 +2,7 @@ from datetime import datetime import pytest +import pytz import taosrest from taosrest import ConnectError from taosws import QueryError @@ -60,8 +61,9 @@ def tdengine(): connection.close() +@pytest.mark.parametrize("dynamic_table", [None, "$key", "table"]) @pytest.mark.skipif(not has_tdengine_credentials, reason="Missing TDEngine URL, user, and/or password") -def test_tdengine_target(tdengine): +def test_tdengine_target(tdengine, dynamic_table): connection, url, user, password, db_name, table_name, db_prefix = tdengine time_format = "%d/%m/%y %H:%M:%S UTC%z" controller = build_flow( @@ -72,7 +74,8 @@ def test_tdengine_target(tdengine): user=user, password=password, database=db_name, - table=table_name, + table=None if dynamic_table else table_name, + dynamic_table=dynamic_table, time_col="time", columns=["my_int", "my_string"], time_format=time_format, @@ -84,33 +87,43 @@ def test_tdengine_target(tdengine): date_time_str = "18/09/19 01:55:1" for i in range(9): timestamp = f"{date_time_str}{i} UTC-0000" - controller.emit({"time": timestamp, "my_int": i, "my_string": f"hello{i}"}) + event_body = {"time": timestamp, "my_int": i, "my_string": 
f"hello{i}"} + event_key = None + if dynamic_table == "$key": + event_key = table_name + elif dynamic_table: + event_body[dynamic_table] = table_name + controller.emit(event_body, event_key) controller.terminate() controller.await_termination() result = connection.query(f"SELECT * FROM {db_prefix}{table_name};") - if url.startswith("taosws"): - result_list = [] - for row in result: - row = list(row) - for field_index, field in enumerate(result.fields): - if field.type() == "TIMESTAMP": + result_list = [] + for row in result: + row = list(row) + for field_index, field in enumerate(result.fields): + typ = field.type() if url.startswith("taosws") else field["type"] + if typ == "TIMESTAMP": + if url.startswith("taosws"): t = datetime.fromisoformat(row[field_index]) - # REST API returns a naive timestamp, but websocket returns a timestamp with a time zone - t = t.replace(tzinfo=None) + # websocket returns a timestamp with the local time zone + t = t.astimezone(pytz.UTC).replace(tzinfo=None) row[field_index] = t - result_list.append(row) - else: - result_list = result.data + else: + t = row[field_index] + # REST API returns a naive timestamp matching the local time zone + t = t.astimezone(pytz.UTC).replace(tzinfo=None) + row[field_index] = t + result_list.append(row) assert result_list == [ - [datetime(2019, 9, 18, 9, 55, 10), 0, "hello0"], - [datetime(2019, 9, 18, 9, 55, 11), 1, "hello1"], - [datetime(2019, 9, 18, 9, 55, 12), 2, "hello2"], - [datetime(2019, 9, 18, 9, 55, 13), 3, "hello3"], - [datetime(2019, 9, 18, 9, 55, 14), 4, "hello4"], - [datetime(2019, 9, 18, 9, 55, 15), 5, "hello5"], - [datetime(2019, 9, 18, 9, 55, 16), 6, "hello6"], - [datetime(2019, 9, 18, 9, 55, 17), 7, "hello7"], - [datetime(2019, 9, 18, 9, 55, 18), 8, "hello8"], + [datetime(2019, 9, 18, 1, 55, 10), 0, "hello0"], + [datetime(2019, 9, 18, 1, 55, 11), 1, "hello1"], + [datetime(2019, 9, 18, 1, 55, 12), 2, "hello2"], + [datetime(2019, 9, 18, 1, 55, 13), 3, "hello3"], + [datetime(2019, 9, 18, 1, 
55, 14), 4, "hello4"], + [datetime(2019, 9, 18, 1, 55, 15), 5, "hello5"], + [datetime(2019, 9, 18, 1, 55, 16), 6, "hello6"], + [datetime(2019, 9, 18, 1, 55, 17), 7, "hello7"], + [datetime(2019, 9, 18, 1, 55, 18), 8, "hello8"], ] diff --git a/storey/flow.py b/storey/flow.py index 837d1f12..461a8cbb 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -1036,6 +1036,7 @@ def __init__( max_events: Optional[int] = None, flush_after_seconds: Union[int, float, None] = None, key_field: Optional[Union[str, Callable[[Event], str]]] = None, + drop_key_field=False, **kwargs, ): if max_events: @@ -1052,7 +1053,7 @@ def __init__( if self._flush_after_seconds is not None and self._flush_after_seconds < 0: raise ValueError("flush_after_seconds cannot be negative") - self._extract_key: Optional[Callable[[Event], str]] = self._create_key_extractor(key_field) + self._extract_key: Optional[Callable[[Event], str]] = self._create_key_extractor(key_field, drop_key_field) def _init(self): super()._init() @@ -1065,14 +1066,17 @@ def _init(self): self._timeout_task: Optional[Task] = None @staticmethod - def _create_key_extractor(key_field) -> Callable: + def _create_key_extractor(key_field, drop_key_field) -> Callable: if key_field is None: return lambda event: None elif callable(key_field): return key_field elif isinstance(key_field, str): - if key_field == "$key": - return lambda event: event.key + if key_field.startswith("$"): + attribute = key_field[1:] + return lambda event: getattr(event, attribute) + elif drop_key_field: + return lambda event: event.body.pop(key_field) else: return lambda event: event.body[key_field] else: @@ -1159,8 +1163,8 @@ class Batch(_Batching): :param flush_after_seconds: Maximum number of seconds to wait before a batch is emitted. :param key: The key by which events are grouped. By default (None), events are not grouped. Other options may be: - Set a '$key' to group events by the Event.key property. - set a 'str' key to group events by Event.body[str]. 
+ Set to '$x' to group events by the x attribute of the event. E.g. "$key" or "$path". + set to other string 'str' to group events by Event.body[str]. set a Callable[Any, Any] to group events by a a custom key extractor. """ diff --git a/storey/targets.py b/storey/targets.py index 1eb5f132..cbaf8753 100644 --- a/storey/targets.py +++ b/storey/targets.py @@ -783,7 +783,12 @@ class TDEngineTarget(_Batching, _Writer): :param password: Password with which to connect. This is ignored when url is a Websocket URL, which should already contain the password. :param database: Name of the database where events will be written. - :param table: Name of the table in the database where events will be written. + :param table: Name of the table in the database where events will be written. This is a fixed name that applies to + every event. To set the table dynamically on a per-event basis, leave this parameter unset (None) and use the + dynamic_table parameter instead. + :param dynamic_table: Alternative to the table parameter (exactly one of these must be set). The name of the field + in the event body to use for the table, or the name of the event attribute preceded by a dollar sign (e.g. + $key or $path). :param time_col: Name of the time column. :param columns: List of column names to be passed to the DataFrame constructor. Use = notation for renaming fields (e.g. write_this=event_field). Use $ notation to refer to metadata ($key, event_time=$time). 
@@ -804,7 +809,8 @@ def __init__( user: Optional[str], password: Optional[str], database: Optional[str], - table: str, + table: Optional[str], + dynamic_table: Optional[str], time_col: str, columns: List[str], timeout: Optional[int] = None, @@ -815,6 +821,12 @@ def __init__( if parsed_url.scheme not in ("taosws", "http", "https"): raise ValueError("URL must start with taosws://, http://, or https://") + if table and dynamic_table: + raise ValueError("Cannot set both table and dynamic_table") + + if not table and not dynamic_table: + raise ValueError("table or dynamic_table must be set") + kwargs["url"] = url kwargs["user"] = user kwargs["password"] = password @@ -826,6 +838,14 @@ def __init__( kwargs["timeout"] = timeout if time_format: kwargs["time_format"] = time_format + + self._table = table + + if dynamic_table: + kwargs["key_field"] = dynamic_table + if kwargs.get("drop_key_field") is None: + kwargs["drop_key_field"] = True + _Batching.__init__(self, **kwargs) self._time_col = time_col _Writer.__init__( @@ -841,7 +861,6 @@ def __init__( self._user = user self._password = password self._database = database - self._table = table self._timeout = timeout self._connection = None @@ -879,7 +898,10 @@ async def _emit(self, batch, batch_key, batch_time, batch_events, last_event_tim if not self._using_websocket: b.write(self._database) b.write(".") - b.write(self._table) + if self._table: + b.write(self._table) + else: # table is dynamic + b.write(batch_key) b.write(" VALUES ") for record in batch: b.write("(")