Skip to content

Commit

Permalink
feat(engine): Implement trail indexing and search
Browse files Browse the repository at this point in the history
  • Loading branch information
topher-lo committed Mar 9, 2024
1 parent 8ab9bce commit 50623a8
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 34 deletions.
27 changes: 25 additions & 2 deletions tracecat/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,22 @@
import random
import re
from collections.abc import Awaitable, Callable, Iterable
from datetime import UTC, datetime
from enum import StrEnum, auto
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, TypeVar
from uuid import uuid4

import httpx
import jsonpath_ng
import tantivy
from jsonpath_ng.exceptions import JsonPathParserError
from pydantic import BaseModel, Field, validator
from tenacity import retry, stop_after_attempt, wait_exponential

from tracecat.condition import ConditionRuleValidator, ConditionRuleVariant
from tracecat.config import MAX_RETRIES, TRACECAT__OAUTH2_GMAIL_PATH
from tracecat.db import create_events_index
from tracecat.llm import (
DEFAULT_MODEL_TYPE,
ModelType,
Expand Down Expand Up @@ -446,8 +449,28 @@ async def start_action_run(
# Store the result in the action result store.
# Every action has its own result and the trail of actions that led to it.
# The schema is {<action ID> : <action result>, ...}
action_result_store[ar_id] = action_trail | {ar_id: result}
custom_logger.debug(f"Action run {ar_id!r} completed with result {result}.")
action_trail = action_trail | {ar_id: result}
action_result_store[ar_id] = action_trail
custom_logger.debug(
f"Action run {ar_id!r} completed with trail: {action_trail}."
)

# Add trail to events store
writer = create_events_index().writer()
writer.add_document(
tantivy.Document(
# NOTE: Not sure where to get the action metadata from...
action_id=action_ref.id,
action_run_id=ar_id,
action_title=action_ref.title,
action_type=action_ref.type,
workflow_id=workflow_ref.id,
workflow_title=workflow_ref.title,
workflow_run_id=action_run.run_id,
data={ar_id: trail.data for trail in action_trail.items()},
published_at=datetime.now(UTC).replace(tzinfo=None),
)
)

if not result.should_continue:
custom_logger.info(f"Action run {ar_id!r} stopping due to stop signal.")
Expand Down
50 changes: 24 additions & 26 deletions tracecat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Literal

import polars as pl
import tantivy
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
Expand Down Expand Up @@ -591,21 +592,24 @@ def authenticate_webhook(webhook_id: str, secret: str) -> AuthenticateWebhookRes
### Events Management


class EventParams(BaseModel):
id: str # The action run "key"
workflow_id: str
workflow_run_id: str
class Event(BaseModel):
published_at: datetime
action_id: str
action_run_id: str
action_title: str
action_type: str
published_at: datetime
event: dict[str, Any]
workflow_id: str
workflow_title: str
workflow_run_id: str
data: dict[str, Any]


@app.post("/events")
def index_event(event: EventParams):
with create_events_index().writer() as writer:
writer.add_document(event.model_dump())
writer.commit()
def index_event(event: Event):
    """Add one event to the events index and commit it.

    Args:
        event: The event payload; its fields mirror the events index schema.
    """
    index = create_events_index()
    writer = index.writer()
    # The writer expects a tantivy.Document rather than a plain dict, so wrap
    # the dumped model the same way the action runner builds its documents.
    writer.add_document(tantivy.Document(**event.model_dump()))
    writer.commit()


SUPPORTED_EVENT_AGGS = {
Expand All @@ -622,29 +626,23 @@ def index_event(event: EventParams):

class EventSearchParams(BaseModel):
    """Parameters accepted by the ``/events/search`` endpoint."""

    # Parsed as the query text against the "workflow_id" index field.
    workflow_id: str
    # Maximum number of hits requested from the searcher.
    limit: int = 1000
    # Index field used to order hits. The default must match the schema's
    # "published_at" date field (previously misspelled as "pubished_at").
    order_by: str = "published_at"
    # The remaining filters are declared but not read by the search handler
    # visible alongside this model.
    workflow_run_id: str | None = None
    query: str | None = None
    group_by: list[str] | None = None
    agg: str | None = None


class EventSearchResponse(BaseModel):
id: str
workflow_id: str
workflow_run_id: str
action_id: str
action_type: str
published_at: datetime
event: dict[str, Any]


@app.get("/events/search")
def search_events(params: EventSearchParams) -> list[EventSearchResponse]:
# Filter by workflow_id
# Filter by workflow_run_id (if non-null)
# Run query
# Return results
pass
def search_events(params: EventSearchParams) -> list[Event]:
    """Return the events of one workflow, ordered by ``params.order_by``.

    Args:
        params: Search parameters; only ``workflow_id``, ``order_by`` and
            ``limit`` are used here.

    Returns:
        The matching events rebuilt from the stored index documents.
    """
    index = create_events_index()
    # Pick up segments committed since the index was opened.
    index.reload()
    query = index.parse_query(params.workflow_id, ["workflow_id"])
    searcher = index.searcher()
    # order_by_field takes the field name itself.
    results = searcher.search(
        query, order_by_field=params.order_by, limit=params.limit
    )

    events: list[Event] = []
    for _, doc_address in results.hits:
        stored = searcher.doc(doc_address).to_dict()
        # NOTE(review): tantivy documents are multi-valued, so to_dict() is
        # expected to map each field name to a list of values; the first value
        # is taken for each field. Confirm against the tantivy-py version in use.
        events.append(
            Event(**{name: values[0] for name, values in stored.items() if values})
        )
    return events


### Case Management
Expand Down
16 changes: 10 additions & 6 deletions tracecat/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ def create_db_engine():
return engine


def create_events_index():
def build_events_index():
index_path = STORAGE_PATH / "event_index"
index_path.mkdir(parents=True, exist_ok=True)
event_schema = (
tantivy.SchemaBuilder()
.add_date_field("published_at", stored=True)
.add_date_field("published_at", fast=True, stored=True)
.add_text_field("action_id", stored=True)
.add_text_field("action_run_id", stored=True)
.add_text_field("action_title", stored=True)
Expand All @@ -99,11 +99,15 @@ def create_events_index():
.add_json_field("data", stored=True)
.build()
)
index = tantivy.Index(event_schema, path=str(index_path))
return index
tantivy.Index(event_schema, path=str(index_path))


def create_vdb_conn():
def create_events_index() -> tantivy.Index:
    """Open the events index stored at ``STORAGE_PATH / "event_index"``.

    Despite the name, this opens an index directory rather than building one.
    """
    return tantivy.Index.open(str(STORAGE_PATH / "event_index"))


def create_vdb_conn() -> lancedb.DBConnection:
    """Connect to the LanceDB database at ``STORAGE_PATH / "vector.db"``."""
    return lancedb.connect(STORAGE_PATH / "vector.db")

Expand Down Expand Up @@ -146,4 +150,4 @@ def initialize_db() -> None:
db.create_table("tasks", schema=TaskSchema, exist_ok=True)

# Search
create_events_index()
build_events_index()

0 comments on commit 50623a8

Please sign in to comment.