
Commit

fix(agents-api): Ghost in the machine
Signed-off-by: Diwank Singh Tomer <[email protected]>
creatorrr committed Dec 26, 2024
1 parent 7011246 commit 15f5a89
Showing 8 changed files with 109 additions and 92 deletions.
8 changes: 3 additions & 5 deletions agents-api/agents_api/clients/temporal.py
@@ -1,4 +1,3 @@
import asyncio
from datetime import timedelta
from uuid import UUID

@@ -92,17 +91,16 @@ async def run_task_execution_workflow(
from ..workflows.task_execution import TaskExecutionWorkflow

start: TransitionTarget = start or TransitionTarget(workflow="main", step=0)
previous_inputs: list[dict] = previous_inputs or []

client = client or (await get_client())
execution_id = execution_input.execution.id
execution_id_key = SearchAttributeKey.for_keyword("CustomStringField")

# FIXME: This is wrong logic
old_args = execution_input.arguments
execution_input.arguments = await asyncio.gather(
*[offload_if_large(arg) for arg in old_args]
)
execution_input.arguments = await offload_if_large(old_args)

previous_inputs: list[dict] = previous_inputs or [execution_input.arguments]

return await client.start_workflow(
TaskExecutionWorkflow.run,
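Note on the change above: previously the arguments were offloaded one item at a time via asyncio.gather (iterating over the arguments container itself, which the removed FIXME flagged as wrong), and previous_inputs defaulted to an empty list. The whole arguments payload is now offloaded once, and that same (possibly offloaded) value seeds previous_inputs, so the workflow's first step sees exactly what was stored. A minimal, hypothetical sketch of what an offload_if_large-style helper could look like — the real implementation, its storage backend, and its size threshold live elsewhere in the agents-api codebase and may differ:

import json
from typing import Any

OFFLOAD_THRESHOLD_BYTES = 64 * 1024  # assumed cutoff, not taken from the repo
_blob_store: dict[str, str] = {}  # stand-in for the real S3-backed blob store


async def offload_if_large(payload: Any) -> Any:
    # Small payloads pass through untouched; large ones are stored once and
    # replaced by a lightweight reference that can be resolved later.
    serialized = json.dumps(payload, default=str)
    if len(serialized) <= OFFLOAD_THRESHOLD_BYTES:
        return payload
    key = f"payload-{abs(hash(serialized))}"
    _blob_store[key] = serialized
    return {"$remote_ref": key}  # hypothetical reference shape

With a single call, the arguments dict yields at most one blob-store reference instead of one per top-level item.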
4 changes: 3 additions & 1 deletion agents-api/agents_api/common/protocol/tasks.py
@@ -239,7 +239,9 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]:

return dump | execution_input

async def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]:
async def prepare_for_step(self, *args, include_remote=False, **kwargs) -> dict[str, Any]:
# FIXME: include_remote is deprecated

current_input = self.current_input
inputs = self.inputs

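The only change here is the signature: prepare_for_step now spells out include_remote as a keyword, and the FIXME marks it as deprecated. A hypothetical call site, assuming the flag is retained only for backward compatibility (context is a StepContext instance constructed by the workflow):

from typing import Any


async def render_step_vars(context) -> dict[str, Any]:
    # Older callers that still pass include_remote keep working for now...
    _legacy = await context.prepare_for_step(include_remote=True)
    # ...but the flag is deprecated, so new code should simply omit it.
    return await context.prepare_for_step()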
7 changes: 7 additions & 0 deletions agents-api/agents_api/routers/tasks/create_task_execution.py
@@ -15,6 +15,7 @@
CreateTransitionRequest,
Execution,
ResourceCreatedResponse,
TransitionTarget,
)
from ...clients.temporal import run_task_execution_workflow
from ...common.protocol.developers import Developer
@@ -89,6 +90,12 @@ async def start_execution(
execution_id=execution_id,
data=CreateTransitionRequest(
type="error",
output={"error": str(e)},
current=TransitionTarget(
workflow="main",
step=0,
),
next=None,
),
connection_pool=connection_pool,
)
7 changes: 2 additions & 5 deletions agents-api/agents_api/workflows/task_execution/__init__.py
@@ -140,18 +140,15 @@ async def set_last_error(self, value: LastErrorInput):
async def run(
self,
execution_input: ExecutionInput,
start: TransitionTarget = TransitionTarget(workflow="main", step=0),
previous_inputs: list | None = None,
start: TransitionTarget,
previous_inputs: list,
) -> Any:
workflow.logger.info(
f"TaskExecutionWorkflow for task {execution_input.task.id}"
f" [LOC {start.workflow}.{start.step}]"
)

# FIXME: Look into saving arguments to the blob store if necessary
# 0. Prepare context
previous_inputs = previous_inputs or [execution_input.arguments]

context = StepContext(
execution_input=execution_input,
inputs=previous_inputs,
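TaskExecutionWorkflow.run no longer provides defaults for start or previous_inputs; both must come from the caller, and run_task_execution_workflow (see temporal.py above) now seeds previous_inputs with the execution arguments. A rough sketch of the caller side — the actual start_workflow call is truncated in the diff, and the task-queue name below is an assumption:

from temporalio.client import Client

from agents_api.autogen.openapi_model import TransitionTarget
from agents_api.workflows.task_execution import TaskExecutionWorkflow


async def start_task_execution_workflow(client: Client, execution_input, execution_id):
    start = TransitionTarget(workflow="main", step=0)
    previous_inputs = [execution_input.arguments]  # seeded by the caller now

    return await client.start_workflow(
        TaskExecutionWorkflow.run,
        args=[execution_input, start, previous_inputs],  # all three are required
        id=str(execution_id),
        task_queue="agents-api-task-queue",  # assumed queue name
    )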
135 changes: 69 additions & 66 deletions agents-api/tests/test_execution_workflow.py
@@ -1,69 +1,72 @@
# # Tests for task queries

# import asyncio
# import json
# from unittest.mock import patch

# import yaml
# from google.protobuf.json_format import MessageToDict
# from litellm.types.utils import Choices, ModelResponse
# from ward import raises, skip, test

# from agents_api.autogen.openapi_model import (
# CreateExecutionRequest,
# CreateTaskRequest,
# )
# from agents_api.queries.task.create_task import create_task
# from agents_api.routers.tasks.create_task_execution import start_execution
# from tests.fixtures import (
# cozo_client,
# cozo_clients_with_migrations,
# test_agent,
# test_developer_id,
# )
# from tests.utils import patch_integration_service, patch_testing_temporal

# EMBEDDING_SIZE: int = 1024


# @test("workflow: evaluate step single")
# async def _(
# clients=cozo_clients_with_migrations,
# developer_id=test_developer_id,
# agent=test_agent,
# ):
# client, _ = clients
# data = CreateExecutionRequest(input={"test": "input"})

# task = create_task(
# developer_id=developer_id,
# agent_id=agent.id,
# data=CreateTaskRequest(
# **{
# "name": "test task",
# "description": "test task about",
# "input_schema": {"type": "object", "additionalProperties": True},
# "main": [{"evaluate": {"hello": '"world"'}}],
# }
# ),
# client=client,
# )

# async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
# execution, handle = await start_execution(
# developer_id=developer_id,
# task_id=task.id,
# data=data,
# client=client,
# )

# assert handle is not None
# assert execution.task_id == task.id
# assert execution.input == data.input
# mock_run_task_execution_workflow.assert_called_once()

# result = await handle.result()
# assert result["hello"] == "world"
# Tests for task queries


from ward import test

from agents_api.autogen.openapi_model import (
CreateExecutionRequest,
CreateTaskRequest,
)
from agents_api.clients.pg import create_db_pool
from agents_api.queries.tasks.create_task import create_task
from agents_api.routers.tasks.create_task_execution import start_execution

from .fixtures import (
test_agent,
test_developer_id,
pg_dsn,
client,
s3_client,
)
from .utils import patch_testing_temporal

EMBEDDING_SIZE: int = 1024


@test("workflow: evaluate step single")
async def _(
dsn=pg_dsn,
developer_id=test_developer_id,
agent=test_agent,
_s3_client=s3_client,  # Added because the blob store might be used
_app_client=client,
):
pool = await create_db_pool(dsn=dsn)
data = CreateExecutionRequest(input={"test": "input"})

task = await create_task(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(
**{
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [{"evaluate": {"hello": '"world"'}}],
}
),
connection_pool=pool,
)

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
execution, handle = await start_execution(
developer_id=developer_id,
task_id=task.id,
data=data,
connection_pool=pool,
)

assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input
mock_run_task_execution_workflow.assert_called_once()

try:
result = await handle.result()
assert result["hello"] == "world"
except Exception as ex:
breakpoint()
raise ex


# @test("workflow: evaluate step multiple")
13 changes: 12 additions & 1 deletion memory-store/docker-compose.yml
@@ -19,26 +19,37 @@ services:
# sed -r -i "s/[#]*\s*(shared_preload_libraries)\s*=\s*'(.*)'/\1 = 'pgaudit,\2'/;s/,'/'/" /home/postgres/pgdata/data/postgresql.conf
# && exec /docker-entrypoint.sh

healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres || exit 1"]
interval: 10s
timeout: 5s
retries: 5

vectorizer-worker:
image: timescale/pgai-vectorizer-worker:v0.3.0
environment:
- PGAI_VECTORIZER_WORKER_DB_URL=postgres://postgres:${MEMORY_STORE_PASSWORD:-postgres}@memory-store:5432/postgres
- VOYAGE_API_KEY=${VOYAGE_API_KEY}
command: [ "--poll-interval", "5s" ]
depends_on:
memory-store:
condition: service_healthy

migration:
image: migrate/migrate:latest
volumes:
- ./migrations:/migrations
command: [ "-path", "/migrations", "-database", "postgres://postgres:${MEMORY_STORE_PASSWORD:-postgres}@memory-store:5432/postgres?sslmode=disable" , "up"]

restart: "no"
develop:
watch:
- path: ./migrations
target: ./migrations
action: sync+restart
depends_on:
- memory-store
memory-store:
condition: service_healthy

volumes:
memory_store_data:
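The new healthcheck plus the condition: service_healthy dependencies make the vectorizer worker and the migration runner wait until Postgres actually accepts connections, rather than starting as soon as the container exists. For illustration only, an equivalent readiness probe in Python — asyncpg and the DSN below are assumptions, not part of this commit:

import asyncio

import asyncpg


async def wait_for_postgres(dsn: str, retries: int = 5, delay: float = 10.0) -> None:
    # Mirrors the compose healthcheck: retry until a connection succeeds,
    # then close it and return; give up once the retry budget is spent.
    for attempt in range(1, retries + 1):
        try:
            conn = await asyncpg.connect(dsn, timeout=5)
            await conn.close()
            return
        except (OSError, asyncpg.PostgresError):
            if attempt == retries:
                raise
            await asyncio.sleep(delay)


if __name__ == "__main__":
    asyncio.run(
        wait_for_postgres("postgres://postgres:postgres@localhost:5432/postgres")
    )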
4 changes: 2 additions & 2 deletions memory-store/migrations/000012_transitions.up.sql
@@ -122,8 +122,8 @@ BEGIN

IF previous_type IS NULL THEN
-- If there is no previous transition, allow only 'init' or 'init_branch'
IF NEW.type NOT IN ('init', 'init_branch') THEN
RAISE EXCEPTION 'First transition must be init or init_branch, got %', NEW.type;
IF NEW.type NOT IN ('init', 'init_branch', 'error', 'cancelled') THEN
RAISE EXCEPTION 'First transition must be init / init_branch / error / cancelled, got %', NEW.type;
END IF;
ELSE
-- Define the valid_next_types array based on previous_type
23 changes: 11 additions & 12 deletions memory-store/migrations/000013_executions_continuous_view.up.sql
@@ -22,24 +22,23 @@ WITH
SELECT
time_bucket ('1 day', created_at) AS bucket,
execution_id,
transition_id,
last(transition_id, created_at) AS transition_id,
count(*) AS total_transitions,
state_agg (created_at, to_text (type)) AS state,
state_agg(created_at, to_text(type)) AS state,
max(created_at) AS created_at,
last (type, created_at) AS type,
last (step_definition, created_at) AS step_definition,
last (step_label, created_at) AS step_label,
last (current_step, created_at) AS current_step,
last (next_step, created_at) AS next_step,
last (output, created_at) AS output,
last (task_token, created_at) AS task_token,
last (metadata, created_at) AS metadata
last(type, created_at) AS type,
last(step_definition, created_at) AS step_definition,
last(step_label, created_at) AS step_label,
last(current_step, created_at) AS current_step,
last(next_step, created_at) AS next_step,
last(output, created_at) AS output,
last(task_token, created_at) AS task_token,
last(metadata, created_at) AS metadata
FROM
transitions
GROUP BY
bucket,
execution_id,
transition_id
execution_id
WITH
no data;
