Skip to content

Commit

Permalink
support connection by user, pass and url
Browse files Browse the repository at this point in the history
  • Loading branch information
Eyal-Danieli committed May 21, 2024
1 parent 3446bd2 commit 082e8d4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 29 deletions.
52 changes: 27 additions & 25 deletions integration/test_tdengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,25 @@
from storey.targets import TDEngineTarget

url = os.getenv("TDENGINE_URL")
user = os.getenv("TDENGINE_USER")
password = os.getenv("TDENGINE_PASSWORD")
has_tdengine_credentials = all([url, user, password]) or (url and url.startswith("taosws"))


@pytest.fixture()
def tdengine():
db_name = "storey"
supertable_name = "test_supertable"

connection = taosws.connect(url)
db_prefix = ""
if url.startswith("taosws"):
connection = taosws.connect(url)
else:

connection = taosws.connect(
url=url,
user=user,
password=password,
)

try:
connection.execute(f"DROP DATABASE {db_name};")
Expand All @@ -26,39 +36,35 @@ def tdengine():
raise err

connection.execute(f"CREATE DATABASE {db_name};")

if not db_prefix:
connection.execute(f"USE {db_name}")
connection.execute(f"USE {db_name}")

try:
connection.execute(f"DROP STABLE {db_prefix}{supertable_name};")
connection.execute(f"DROP STABLE {supertable_name};")
except taosws.QueryError as err: # websocket connection raises QueryError
if "STable not exist" not in str(err):
raise err

connection.execute(
f"CREATE STABLE {db_prefix}{supertable_name} (time TIMESTAMP, my_string NCHAR(10)) TAGS (my_int INT);"
)
connection.execute(f"CREATE STABLE {supertable_name} (time TIMESTAMP, my_string NCHAR(10)) TAGS (my_int INT);")

# Test runs
yield connection, url, db_name, supertable_name, db_prefix
yield connection, url, user, password, db_name, supertable_name

# Teardown
connection.execute(f"DROP DATABASE {db_name};")
connection.close()


@pytest.mark.parametrize("table_col", [None, "$key", "table"])
@pytest.mark.skipif(url is None or not url.startswith("taosws"), reason="Missing Valid TDEngine URL")
@pytest.mark.skipif(not has_tdengine_credentials, reason="Missing TDEngine URL, user, and/or password")
def test_tdengine_target(tdengine, table_col):
connection, url, db_name, supertable_name, db_prefix = tdengine
connection, url, user, password, db_name, supertable_name = tdengine
time_format = "%d/%m/%y %H:%M:%S UTC%z"

table_name = "test_table"

# Table is created automatically only when using a supertable
if not table_col:
connection.execute(f"CREATE TABLE {db_prefix}{table_name} (time TIMESTAMP, my_string NCHAR(10), my_int INT);")
connection.execute(f"CREATE TABLE {table_name} (time TIMESTAMP, my_string NCHAR(10), my_int INT);")

controller = build_flow(
[
Expand All @@ -67,6 +73,8 @@ def test_tdengine_target(tdengine, table_col):
url=url,
time_col="time",
columns=["my_string"] if table_col else ["my_string", "my_int"],
user=user,
password=password,
database=db_name,
table=None if table_col else table_name,
table_col=table_col,
Expand Down Expand Up @@ -99,23 +107,17 @@ def test_tdengine_target(tdengine, table_col):
else:
query_table = table_name
where_clause = ""
result = connection.query(f"SELECT * FROM {db_prefix}{query_table} {where_clause} ORDER BY my_int;")
result = connection.query(f"SELECT * FROM {query_table} {where_clause} ORDER BY my_int;")
result_list = []
for row in result:
row = list(row)
for field_index, field in enumerate(result.fields):
typ = field.type() if url.startswith("taosws") else field["type"]
typ = field.type()
if typ == "TIMESTAMP":
if url.startswith("taosws"):
t = datetime.fromisoformat(row[field_index])
# websocket returns a timestamp with the local time zone
t = t.astimezone(pytz.UTC).replace(tzinfo=None)
row[field_index] = t
else:
t = row[field_index]
# REST API returns a naive timestamp matching the local time zone
t = t.astimezone(pytz.UTC).replace(tzinfo=None)
row[field_index] = t
t = datetime.fromisoformat(row[field_index])
# websocket returns a timestamp with the local time zone
t = t.astimezone(pytz.UTC).replace(tzinfo=None)
row[field_index] = t
result_list.append(row)
if table_col:
expected_result = [
Expand Down
18 changes: 14 additions & 4 deletions storey/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,8 @@ class TDEngineTarget(_Batching, _Writer):
:param time_col: Name of the time column.
:param columns: List of column names to be passed to the DataFrame constructor. Use = notation for renaming fields
(e.g. write_this=event_field). Use $ notation to refer to metadata ($key, event_time=$time).
:param user: Username with which to connect.
:param password: Password with which to connect.
:param database: Name of the database where events will be written.
:param table: Name of the table in the database where events will be written. To set the table dynamically on a
per-event basis, use the $ prefix to indicate the field that should be used for the table name, or $$ prefix to
Expand All @@ -807,6 +809,8 @@ def __init__(
url: str,
time_col: str,
columns: List[str],
user: Optional[str] = None,
password: Optional[str] = None,
database: Optional[str] = None,
table: Optional[str] = None,
table_col: Optional[str] = None,
Expand All @@ -815,8 +819,6 @@ def __init__(
time_format: Optional[str] = None,
**kwargs,
):
if not url.startswith("taosws"):
raise ValueError("URL must start with taosws://")

if table and table_col:
raise ValueError("Cannot set both table and table_col")
Expand All @@ -833,6 +835,10 @@ def __init__(
kwargs["url"] = url
kwargs["time_col"] = time_col
kwargs["columns"] = columns
if user:
kwargs["user"] = user
if password:
kwargs["password"] = password

if database:
kwargs["database"] = database
Expand Down Expand Up @@ -868,15 +874,19 @@ def __init__(
time_format=time_format,
)
self._url = url
self._user = user
self._password = password
self._database = database

def _init(self):
import taosws

_Batching._init(self)
_Writer._init(self)

self._connection = taosws.connect(self._url)
if self._url.startswith("taosws://"):
self._connection = taosws.connect(self._url)
else:
self._connection = taosws.connect(url=self._url, user=self._user, password=self._password)
self._connection.execute(f"USE {self._database}")

def _event_to_batch_entry(self, event):
Expand Down

0 comments on commit 082e8d4

Please sign in to comment.