From 8ab9bcefb28e240e9c07fe9425c75c756f79085c Mon Sep 17 00:00:00 2001 From: Christopher Lo <46541035+topher-lo@users.noreply.github.com> Date: Sat, 9 Mar 2024 11:18:55 +0000 Subject: [PATCH] fix(engine): AI db schemas --- tracecat/db.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/tracecat/db.py b/tracecat/db.py index 0987c13ab..dcda07db3 100644 --- a/tracecat/db.py +++ b/tracecat/db.py @@ -84,19 +84,22 @@ def create_db_engine(): def create_events_index(): - index_path = STORAGE_PATH / "events_index" + index_path = STORAGE_PATH / "event_index" index_path.mkdir(parents=True, exist_ok=True) event_schema = ( - tantivy.Schema() - .add_string_field("id", stored=True) - .add_string_field("workflow_id", stored=True) - .add_string_field("workflow_run_id", stored=True) - .add_string_field("action_id", stored=True) - .add_string_field("action_type", stored=True) + tantivy.SchemaBuilder() .add_date_field("published_at", stored=True) - .add_json_field("event", stored=True) + .add_text_field("action_id", stored=True) + .add_text_field("action_run_id", stored=True) + .add_text_field("action_title", stored=True) + .add_text_field("action_type", stored=True) + .add_text_field("workflow_id", stored=True) + .add_text_field("workflow_title", stored=True) + .add_text_field("workflow_run_id", stored=True) + .add_json_field("data", stored=True) + .build() ) - index = tantivy.Index(event_schema, path=index_path) + index = tantivy.Index(event_schema, path=str(index_path)) return index @@ -110,10 +113,10 @@ def create_vdb_conn(): pa.field("id", pa.int64(), nullable=False), pa.field("workflow_id", pa.int64(), nullable=False), pa.field("title", pa.string(), nullable=False), - pa.field("payload", pa.dictionary(pa.string(), pa.string()), nullable=False), + pa.field("payload", pa.string(), nullable=False), # JSON-serialized pa.field("malice", pa.string(), nullable=False), - pa.field("context", pa.dictionary(pa.string(), pa.string()), nullable=False), - pa.field("suppression", pa.dictionary(pa.string(), pa.bool_()), nullable=False), + pa.field("context", pa.string(), nullable=False), # JSON-serialized + pa.field("suppression", pa.string(), nullable=False), # JSON-serialized pa.field("status", pa.string(), nullable=False), pa.field("priority", pa.string(), nullable=False), pa.field("_payload_vector", pa.list_(pa.float32(), list_size=EMBEDDINGS_SIZE)), @@ -139,8 +142,8 @@ def initialize_db() -> None: # VectorDB db = create_vdb_conn() - db.create_table("cases", schema=CaseSchema) - db.create_table("tasks", schema=TaskSchema) + db.create_table("cases", schema=CaseSchema, exist_ok=True) + db.create_table("tasks", schema=TaskSchema, exist_ok=True) # Search create_events_index()