diff --git a/tracecat/actions.py b/tracecat/actions.py index 40b6413d5..b78dc3d19 100644 --- a/tracecat/actions.py +++ b/tracecat/actions.py @@ -25,6 +25,7 @@ import random import re from collections.abc import Awaitable, Callable, Iterable +from datetime import UTC, datetime from enum import StrEnum, auto from functools import partial from typing import TYPE_CHECKING, Any, Literal, TypeVar @@ -32,12 +33,14 @@ import httpx import jsonpath_ng +import tantivy from jsonpath_ng.exceptions import JsonPathParserError from pydantic import BaseModel, Field, validator from tenacity import retry, stop_after_attempt, wait_exponential from tracecat.condition import ConditionRuleValidator, ConditionRuleVariant from tracecat.config import MAX_RETRIES, TRACECAT__OAUTH2_GMAIL_PATH +from tracecat.db import create_events_index from tracecat.llm import ( DEFAULT_MODEL_TYPE, ModelType, @@ -446,8 +449,28 @@ async def start_action_run( # Store the result in the action result store. # Every action has its own result and the trail of actions that led to it. # The schema is { : , ...} - action_result_store[ar_id] = action_trail | {ar_id: result} - custom_logger.debug(f"Action run {ar_id!r} completed with result {result}.") + action_trail = action_trail | {ar_id: result} + action_result_store[ar_id] = action_trail + custom_logger.debug( + f"Action run {ar_id!r} completed with trail: {action_trail}." + ) + + # Add trail to events store + writer = create_events_index().writer() + writer.add_document( + tantivy.Document( + # NOTE: Not sure where to get the action metadata from... + action_id=action_ref.id, + action_run_id=ar_id, + action_title=action_ref.title, + action_type=action_ref.type, + workflow_id=workflow_ref.id, + workflow_title=workflow_ref.title, + workflow_run_id=action_run.run_id, + data={ar_id: trail.data for trail in action_trail.items()}, + published_at=datetime.now(UTC).replace(tzinfo=None), + ) + ) if not result.should_continue: custom_logger.info(f"Action run {ar_id!r} stopping due to stop signal.") diff --git a/tracecat/api.py b/tracecat/api.py index c84408972..2c1ae7009 100644 --- a/tracecat/api.py +++ b/tracecat/api.py @@ -4,6 +4,7 @@ from typing import Any, Literal import polars as pl +import tantivy from fastapi import FastAPI, HTTPException, status from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel @@ -591,21 +592,24 @@ def authenticate_webhook(webhook_id: str, secret: str) -> AuthenticateWebhookRes ### Events Management -class EventParams(BaseModel): - id: str # The action run "key" - workflow_id: str - workflow_run_id: str +class Event(BaseModel): + published_at: datetime action_id: str + action_run_id: str + action_title: str action_type: str - published_at: datetime - event: dict[str, Any] + workflow_id: str + workflow_title: str + workflow_run_id: str + data: dict[str, Any] @app.post("/events") -def index_event(event: EventParams): - with create_events_index().writer() as writer: - writer.add_document(event.model_dump()) - writer.commit() +def index_event(event: Event): + index = create_events_index() + writer = index.writer() + writer.add_document(event.model_dump()) + writer.commit() SUPPORTED_EVENT_AGGS = { @@ -622,29 +626,23 @@ def index_event(event: EventParams): class EventSearchParams(BaseModel): workflow_id: str + limit: int = 1000 + order_by: str = "pubished_at" workflow_run_id: str | None = None query: str | None = None group_by: list[str] | None = None agg: str | None = None -class EventSearchResponse(BaseModel): - id: str - workflow_id: str - workflow_run_id: str - action_id: str - action_type: str - published_at: datetime - event: dict[str, Any] - - @app.get("/events/search") -def search_events(params: EventSearchParams) -> list[EventSearchResponse]: - # Filter by workflow_id - # Filter by workflow_run_id (if non-null) - # Run query - # Return results - pass +def search_events(params: EventSearchParams) -> list[Event]: + index = create_events_index() + index.reload() + query = index.parse_query(params.workflow_id, ["workflow_id"]) + searcher = index.searcher() + searcher.search( + query, order_by_field=tantivy.field(params.order_by), limit=params.limit + ) ### Case Management diff --git a/tracecat/db.py b/tracecat/db.py index dcda07db3..e541bca24 100644 --- a/tracecat/db.py +++ b/tracecat/db.py @@ -83,12 +83,12 @@ def create_db_engine(): return engine -def create_events_index(): +def build_events_index(): index_path = STORAGE_PATH / "event_index" index_path.mkdir(parents=True, exist_ok=True) event_schema = ( tantivy.SchemaBuilder() - .add_date_field("published_at", stored=True) + .add_date_field("published_at", fast=True, stored=True) .add_text_field("action_id", stored=True) .add_text_field("action_run_id", stored=True) .add_text_field("action_title", stored=True) @@ -99,11 +99,15 @@ def create_events_index(): .add_json_field("data", stored=True) .build() ) - index = tantivy.Index(event_schema, path=str(index_path)) - return index + tantivy.Index(event_schema, path=str(index_path)) -def create_vdb_conn(): +def create_events_index() -> tantivy.Index: + index_path = STORAGE_PATH / "event_index" + return tantivy.Index.open(str(index_path)) + + +def create_vdb_conn() -> lancedb.DBConnection: db = lancedb.connect(STORAGE_PATH / "vector.db") return db @@ -146,4 +150,4 @@ def initialize_db() -> None: db.create_table("tasks", schema=TaskSchema, exist_ok=True) # Search - create_events_index() + build_events_index()