Skip to content

Commit

Permalink
Allow user to specify table on a per-event basis in TDEngineTarget (#…
Browse files Browse the repository at this point in the history
…519)

* Allow user to specify table on a per-event basis in `TDEngineTarget`

[ML-6367](https://iguazio.atlassian.net/browse/ML-6367)

* Fix test time zone handling

* Allow for user to override `drop_key_field` if they really want to
  • Loading branch information
gtopper authored May 9, 2024
1 parent 58ce43c commit 9a1b9f8
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 33 deletions.
59 changes: 36 additions & 23 deletions integration/test_tdengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import datetime

import pytest
import pytz
import taosrest
from taosrest import ConnectError
from taosws import QueryError
Expand Down Expand Up @@ -60,8 +61,9 @@ def tdengine():
connection.close()


@pytest.mark.parametrize("dynamic_table", [None, "$key", "table"])
@pytest.mark.skipif(not has_tdengine_credentials, reason="Missing TDEngine URL, user, and/or password")
def test_tdengine_target(tdengine):
def test_tdengine_target(tdengine, dynamic_table):
connection, url, user, password, db_name, table_name, db_prefix = tdengine
time_format = "%d/%m/%y %H:%M:%S UTC%z"
controller = build_flow(
Expand All @@ -72,7 +74,8 @@ def test_tdengine_target(tdengine):
user=user,
password=password,
database=db_name,
table=table_name,
table=None if dynamic_table else table_name,
dynamic_table=dynamic_table,
time_col="time",
columns=["my_int", "my_string"],
time_format=time_format,
Expand All @@ -84,33 +87,43 @@ def test_tdengine_target(tdengine):
date_time_str = "18/09/19 01:55:1"
for i in range(9):
timestamp = f"{date_time_str}{i} UTC-0000"
controller.emit({"time": timestamp, "my_int": i, "my_string": f"hello{i}"})
event_body = {"time": timestamp, "my_int": i, "my_string": f"hello{i}"}
event_key = None
if dynamic_table == "$key":
event_key = table_name
elif dynamic_table:
event_body[dynamic_table] = table_name
controller.emit(event_body, event_key)

controller.terminate()
controller.await_termination()

result = connection.query(f"SELECT * FROM {db_prefix}{table_name};")
if url.startswith("taosws"):
result_list = []
for row in result:
row = list(row)
for field_index, field in enumerate(result.fields):
if field.type() == "TIMESTAMP":
result_list = []
for row in result:
row = list(row)
for field_index, field in enumerate(result.fields):
typ = field.type() if url.startswith("taosws") else field["type"]
if typ == "TIMESTAMP":
if url.startswith("taosws"):
t = datetime.fromisoformat(row[field_index])
# REST API returns a naive timestamp, but websocket returns a timestamp with a time zone
t = t.replace(tzinfo=None)
# websocket returns a timestamp with the local time zone
t = t.astimezone(pytz.UTC).replace(tzinfo=None)
row[field_index] = t
result_list.append(row)
else:
result_list = result.data
else:
t = row[field_index]
# REST API returns a naive timestamp matching the local time zone
t = t.astimezone(pytz.UTC).replace(tzinfo=None)
row[field_index] = t
result_list.append(row)
assert result_list == [
[datetime(2019, 9, 18, 9, 55, 10), 0, "hello0"],
[datetime(2019, 9, 18, 9, 55, 11), 1, "hello1"],
[datetime(2019, 9, 18, 9, 55, 12), 2, "hello2"],
[datetime(2019, 9, 18, 9, 55, 13), 3, "hello3"],
[datetime(2019, 9, 18, 9, 55, 14), 4, "hello4"],
[datetime(2019, 9, 18, 9, 55, 15), 5, "hello5"],
[datetime(2019, 9, 18, 9, 55, 16), 6, "hello6"],
[datetime(2019, 9, 18, 9, 55, 17), 7, "hello7"],
[datetime(2019, 9, 18, 9, 55, 18), 8, "hello8"],
[datetime(2019, 9, 18, 1, 55, 10), 0, "hello0"],
[datetime(2019, 9, 18, 1, 55, 11), 1, "hello1"],
[datetime(2019, 9, 18, 1, 55, 12), 2, "hello2"],
[datetime(2019, 9, 18, 1, 55, 13), 3, "hello3"],
[datetime(2019, 9, 18, 1, 55, 14), 4, "hello4"],
[datetime(2019, 9, 18, 1, 55, 15), 5, "hello5"],
[datetime(2019, 9, 18, 1, 55, 16), 6, "hello6"],
[datetime(2019, 9, 18, 1, 55, 17), 7, "hello7"],
[datetime(2019, 9, 18, 1, 55, 18), 8, "hello8"],
]
16 changes: 10 additions & 6 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,7 @@ def __init__(
max_events: Optional[int] = None,
flush_after_seconds: Union[int, float, None] = None,
key_field: Optional[Union[str, Callable[[Event], str]]] = None,
drop_key_field=False,
**kwargs,
):
if max_events:
Expand All @@ -1052,7 +1053,7 @@ def __init__(
if self._flush_after_seconds is not None and self._flush_after_seconds < 0:
raise ValueError("flush_after_seconds cannot be negative")

self._extract_key: Optional[Callable[[Event], str]] = self._create_key_extractor(key_field)
self._extract_key: Optional[Callable[[Event], str]] = self._create_key_extractor(key_field, drop_key_field)

def _init(self):
super()._init()
Expand All @@ -1065,14 +1066,17 @@ def _init(self):
self._timeout_task: Optional[Task] = None

@staticmethod
def _create_key_extractor(key_field) -> Callable:
def _create_key_extractor(key_field, drop_key_field) -> Callable:
if key_field is None:
return lambda event: None
elif callable(key_field):
return key_field
elif isinstance(key_field, str):
if key_field == "$key":
return lambda event: event.key
if key_field.startswith("$"):
attribute = key_field[1:]
return lambda event: getattr(event, attribute)
elif drop_key_field:
return lambda event: event.body.pop(key_field)
else:
return lambda event: event.body[key_field]
else:
Expand Down Expand Up @@ -1159,8 +1163,8 @@ class Batch(_Batching):
:param flush_after_seconds: Maximum number of seconds to wait before a batch is emitted.
:param key: The key by which events are grouped. By default (None), events are not grouped.
Other options may be:
Set a '$key' to group events by the Event.key property.
set a 'str' key to group events by Event.body[str].
Set to '$x' to group events by the x attribute of the event. E.g. "$key" or "$path".
set to other string 'str' to group events by Event.body[str].
set a Callable[Any, Any] to group events by a a custom key extractor.
"""

Expand Down
30 changes: 26 additions & 4 deletions storey/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,12 @@ class TDEngineTarget(_Batching, _Writer):
:param password: Password with which to connect. This is ignored when url is a Websocket URL, which should already
contain the password.
:param database: Name of the database where events will be written.
:param table: Name of the table in the database where events will be written.
:param table: Name of the table in the database where events will be written. To set the table dynamically on a
per-event basis, use the $ prefix to indicate the field that should be used for the table name, or $$ prefix to
indicate the event attribute (e.g. key or path) that should be used.
:param dynamic_table: Alternative to the table parameter (exactly one of these must be set). The name of the field
in the event body to use for the table, or the name of the event attribute preceded by a dollar sign (e.g.
$key or $path).
:param time_col: Name of the time column.
:param columns: List of column names to be passed to the DataFrame constructor. Use = notation for renaming fields
(e.g. write_this=event_field). Use $ notation to refer to metadata ($key, event_time=$time).
Expand All @@ -804,7 +809,8 @@ def __init__(
user: Optional[str],
password: Optional[str],
database: Optional[str],
table: str,
table: Optional[str],
dynamic_table: Optional[str],
time_col: str,
columns: List[str],
timeout: Optional[int] = None,
Expand All @@ -815,6 +821,12 @@ def __init__(
if parsed_url.scheme not in ("taosws", "http", "https"):
raise ValueError("URL must start with taosws://, http://, or https://")

if table and dynamic_table:
raise ValueError("Cannot set both table and dynamic_table")

if not table and not dynamic_table:
raise ValueError("table or dynamic_table must be set")

kwargs["url"] = url
kwargs["user"] = user
kwargs["password"] = password
Expand All @@ -826,6 +838,14 @@ def __init__(
kwargs["timeout"] = timeout
if time_format:
kwargs["time_format"] = time_format

self._table = table

if dynamic_table:
kwargs["key_field"] = dynamic_table
if kwargs.get("drop_key_field") is None:
kwargs["drop_key_field"] = True

_Batching.__init__(self, **kwargs)
self._time_col = time_col
_Writer.__init__(
Expand All @@ -841,7 +861,6 @@ def __init__(
self._user = user
self._password = password
self._database = database
self._table = table
self._timeout = timeout

self._connection = None
Expand Down Expand Up @@ -879,7 +898,10 @@ async def _emit(self, batch, batch_key, batch_time, batch_events, last_event_tim
if not self._using_websocket:
b.write(self._database)
b.write(".")
b.write(self._table)
if self._table:
b.write(self._table)
else: # table is dynamic
b.write(batch_key)
b.write(" VALUES ")
for record in batch:
b.write("(")
Expand Down

0 comments on commit 9a1b9f8

Please sign in to comment.