From 6bcf5a3a6404fad0794598d73d0319ba4e91c772 Mon Sep 17 00:00:00 2001 From: Christopher Lo <46541035+topher-lo@users.noreply.github.com> Date: Mon, 4 Mar 2024 06:05:05 +0000 Subject: [PATCH] feat(engine): Add status column to Action and Workflow tables --- tracecat/api.py | 50 ++++++++++++++++++++++++++++++++++--------------- tracecat/db.py | 2 ++ 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/tracecat/api.py b/tracecat/api.py index 61d018456..21486c136 100644 --- a/tracecat/api.py +++ b/tracecat/api.py @@ -41,6 +41,7 @@ class ActionResponse(BaseModel): id: str title: str description: str + status: str inputs: dict[str, Any] | None @@ -48,6 +49,7 @@ class WorkflowResponse(BaseModel): id: str title: str description: str + status: str actions: dict[str, list[ActionResponse]] graph: dict[str, list[str]] | None # Adjacency list of Action IDs @@ -58,12 +60,14 @@ class ActionMetadataResponse(BaseModel): type: str title: str description: str + status: str class WorkflowMetadataResponse(BaseModel): id: str title: str description: str + status: str ### Workflows @@ -78,7 +82,10 @@ def list_workflows() -> list[WorkflowMetadataResponse]: workflows = results.all() workflow_metadata = [ WorkflowMetadataResponse( - id=workflow.id, title=workflow.title, description=workflow.description + id=workflow.id, + title=workflow.title, + description=workflow.description, + status=workflow.status, ) for workflow in workflows ] @@ -101,7 +108,10 @@ def create_workflow(params: CreateWorkflowParams) -> WorkflowMetadataResponse: session.refresh(workflow) return WorkflowMetadataResponse( - id=workflow.id, title=params.title, description=params.description + id=workflow.id, + title=params.title, + description=params.description, + status=params.status, ) @@ -138,6 +148,7 @@ def get_workflow(workflow_id: str) -> WorkflowResponse: id=workflow.id, title=workflow.title, description=workflow.description, + status=workflow.status, actions=actions_responses, graph=graph, ) @@ -147,6 +158,7 @@ def get_workflow(workflow_id: str) -> WorkflowResponse: class UpdateWorkflowParams(BaseModel): title: str | None = None description: str | None = None + status: str | None = None object: str | None = None @@ -166,6 +178,8 @@ def update_workflow( workflow.title = params.title if params.description is not None: workflow.description = params.description + if params.status is not None: + workflow.status = params.status if params.object is not None: workflow.object = params.object @@ -190,6 +204,7 @@ def list_actions(workflow_id: str) -> list[ActionMetadataResponse]: type=action.type, title=action.title, description=action.description, + status=action.status, ) for action in actions ] @@ -238,29 +253,34 @@ def get_action(action_id: str, workflow_id: int) -> ActionResponse: id=action.id, title=action.title, description=action.description, + status=action.status, inputs=json.loads(action.inputs) if action.inputs else None, ) +class UpdateWorkflowParams(BaseModel): + title: str | None = None + description: str | None = None + status: str | None = None + object: str | None = None + + @app.get("/actions/{action_id}", status_code=204) -def update_action( - action_id: str | None, - title: str | None, - description: str | None, - inputs: str | None, # JSON-serialized string -) -> None: +def update_action(params: UpdateWorkflowParams) -> None: with Session(create_db_engine()) as session: # Fetch the action by id - statement = select(Action).where(Action.id == action_id) + statement = select(Action).where(Action.id == params.action_id) result = session.exec(statement) action = result.one() - if title is not None: - action.title = title - if description is not None: - action.description = description - if inputs is not None: - action.inputs = inputs + if params.title is not None: + action.title = params.title + if params.description is not None: + action.description = params.description + if params.status is not None: + action.status = params.status + if params.inputs is not None: + action.inputs = params.inputs session.add(action) session.commit() diff --git a/tracecat/db.py b/tracecat/db.py index a1c7bd311..8c4c0cb66 100644 --- a/tracecat/db.py +++ b/tracecat/db.py @@ -9,6 +9,7 @@ class Workflow(SQLModel, table=True): id: str | None = Field(default_factory=lambda: uuid4().hex, primary_key=True) title: str description: str + status: str = "offline" # "online" or "offline" object: str | None = None # JSON-serialized String of react flow object actions: list["Action"] | None = Relationship(back_populates="workflow") @@ -18,6 +19,7 @@ class Action(SQLModel, table=True): type: str title: str description: str + status: str = "offline" # "online" or "offline" inputs: str | None = None # JSON-serialized String of inputs workflow_id: str = Field(foreign_key="workflow.id") workflow: Workflow = Relationship(back_populates="actions")