Skip to content

Commit

Permalink
Merge pull request #994 from julep-ai/x/fix-temporal-workflows
Browse files Browse the repository at this point in the history
fix(agents-api, memory-store): fix temporal workflows, queries, and migrations
  • Loading branch information
whiterabbit1983 authored Dec 26, 2024
2 parents 4e1c975 + b5a8313 commit 9b9bd73
Show file tree
Hide file tree
Showing 20 changed files with 241 additions and 170 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,8 @@ async def transition_step(
transition = await create_execution_transition(
developer_id=context.execution_input.developer_id,
execution_id=context.execution_input.execution.id,
task_id=context.execution_input.task.id,
data=transition_info,
task_token=transition_info.task_token,
update_execution_status=True,
)

except Exception as e:
Expand Down
9 changes: 4 additions & 5 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from datetime import timedelta
from uuid import UUID

Expand Down Expand Up @@ -92,16 +91,16 @@ async def run_task_execution_workflow(
from ..workflows.task_execution import TaskExecutionWorkflow

start: TransitionTarget = start or TransitionTarget(workflow="main", step=0)
previous_inputs: list[dict] = previous_inputs or []

client = client or (await get_client())
execution_id = execution_input.execution.id
execution_id_key = SearchAttributeKey.for_keyword("CustomStringField")

# FIXME: This is wrong logic
old_args = execution_input.arguments
execution_input.arguments = await asyncio.gather(
*[offload_if_large(arg) for arg in old_args]
)
execution_input.arguments = await offload_if_large(old_args)

previous_inputs: list[dict] = previous_inputs or [execution_input.arguments]

return await client.start_workflow(
TaskExecutionWorkflow.run,
Expand Down
16 changes: 12 additions & 4 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class PartialTransition(create_partial_model(CreateTransitionRequest)):
class ExecutionInput(BaseModel):
developer_id: UUID
execution: Execution | None = None
task: TaskSpecDef
task: TaskSpecDef | None = None
agent: Agent
agent_tools: list[Tool | CreateToolRequest]
arguments: dict[str, Any]
Expand Down Expand Up @@ -239,7 +239,11 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]:

return dump | execution_input

async def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]:
async def prepare_for_step(
self, *args, include_remote=False, **kwargs
) -> dict[str, Any]:
# FIXME: include_remote is deprecated

current_input = self.current_input
inputs = self.inputs

Expand All @@ -266,7 +270,9 @@ class StepOutcome(BaseModel):
def task_to_spec(
task: Task | CreateTaskRequest | UpdateTaskRequest | PatchTaskRequest, **model_opts
) -> TaskSpecDef | PartialTaskSpecDef:
task_data = task.model_dump(**model_opts, exclude={"task_id", "id", "agent_id"})
task_data = task.model_dump(
**model_opts, exclude={"version", "developer_id", "task_id", "id", "agent_id"}
)

if "tools" in task_data:
del task_data["tools"]
Expand Down Expand Up @@ -305,7 +311,9 @@ def spec_to_task_data(spec: dict) -> dict:
workflows_dict = {workflow["name"]: workflow["steps"] for workflow in workflows}

tools = spec.pop("tools", [])
tools = [{tool["type"]: tool.pop("spec"), **tool} for tool in tools]
tools = [
{tool["type"]: tool.pop("spec"), **tool} for tool in tools if tool is not None
]

return {
"id": task_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@
$8, -- metadata
$9 -- default_settings
)
ON CONFLICT (developer_id, agent_id) DO UPDATE SET
canonical_name = EXCLUDED.canonical_name,
name = EXCLUDED.name,
about = EXCLUDED.about,
instructions = EXCLUDED.instructions,
model = EXCLUDED.model,
metadata = EXCLUDED.metadata,
default_settings = EXCLUDED.default_settings
RETURNING *;
""").sql(pretty=True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ..utils import (
partialclass,
Expand Down
48 changes: 26 additions & 22 deletions agents-api/agents_api/queries/executions/prepare_execution_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,38 +19,30 @@
agent_id = (
SELECT agent_id FROM tasks
WHERE developer_id = $1 AND task_id = $2
LIMIT 1
)
LIMIT 1
) a
) AS agent,
(
SELECT jsonb_agg(r) AS tools FROM (
SELECT COALESCE(jsonb_agg(r), '[]'::jsonb) AS tools FROM (
SELECT * FROM tools
WHERE
developer_id = $1 AND
task_id = $2
) r
) AS tools,
(
SELECT to_jsonb(t) AS task FROM (
SELECT * FROM tasks
SELECT to_jsonb(e) AS execution FROM (
SELECT * FROM latest_executions
WHERE
developer_id = $1 AND
task_id = $2
developer_id = $1 AND
task_id = $2 AND
execution_id = $3
LIMIT 1
) t
) AS task;
) e
) AS execution;
"""
# (
# SELECT to_jsonb(e) AS execution FROM (
# SELECT * FROM latest_executions
# WHERE
# developer_id = $1 AND
# task_id = $2 AND
# execution_id = $3
# LIMIT 1
# ) e
# ) AS execution;


# @rewrap_exceptions(
Expand All @@ -70,13 +62,25 @@
one=True,
transform=lambda d: {
**d,
"task": {
"tools": d["tools"],
**d["task"],
# "task": {
# "tools": d["tools"],
# **d["task"],
# },
"developer_id": d["agent"]["developer_id"],
"agent": {
"id": d["agent"]["agent_id"],
**d["agent"],
},
"agent_tools": [
{tool["type"]: tool.pop("spec"), **tool} for tool in d["tools"]
{tool["type"]: tool.pop("spec"), **tool}
for tool in d["tools"]
if tool is not None
],
"arguments": d["execution"]["input"],
"execution": {
"id": d["execution"]["execution_id"],
**d["execution"],
},
},
)
@pg_query
Expand All @@ -103,6 +107,6 @@ async def prepare_execution_input(
[
str(developer_id),
str(task_id),
# str(execution_id),
str(execution_id),
],
)
11 changes: 11 additions & 0 deletions agents-api/agents_api/queries/tasks/create_or_update_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
$7, -- description
$8 -- spec
)
ON CONFLICT (agent_id, task_id, name) DO UPDATE SET
type = EXCLUDED.type,
description = EXCLUDED.description,
spec = EXCLUDED.spec
""").sql(pretty=True)

# Define the raw SQL query for creating or updating a task
Expand Down Expand Up @@ -86,6 +90,13 @@
$8::jsonb, -- input_schema
$9::jsonb -- metadata
FROM current_version
ON CONFLICT (developer_id, task_id, "version") DO UPDATE SET
version = tasks.version + 1,
name = EXCLUDED.name,
description = EXCLUDED.description,
inherit_tools = EXCLUDED.inherit_tools,
input_schema = EXCLUDED.input_schema,
metadata = EXCLUDED.metadata
RETURNING *, (SELECT next_version FROM current_version) as next_version;
""").sql(pretty=True)

Expand Down
32 changes: 21 additions & 11 deletions agents-api/agents_api/queries/tasks/get_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,40 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ...common.protocol.tasks import spec_to_task
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class

# Define the raw SQL query for getting a task
get_task_query = parse_one("""
get_task_query = """
SELECT
t.*,
COALESCE(
jsonb_agg(
CASE WHEN w.name IS NOT NULL THEN
jsonb_build_object(
'name', w.name,
'steps', jsonb_build_array(w.step_definition)
jsonb_agg(
DISTINCT jsonb_build_object(
'name', w.name,
'steps', (
SELECT jsonb_agg(step_definition ORDER BY step_idx)
FROM workflows w2
WHERE w2.developer_id = w.developer_id
AND w2.task_id = w.task_id
AND w2.version = w.version
AND w2.name = w.name
)
END
) FILTER (WHERE w.name IS NOT NULL),
)
) FILTER (WHERE w.name IS NOT NULL),
'[]'::jsonb
) as workflows
) as workflows,
COALESCE(
jsonb_agg(tl) FILTER (WHERE tl IS NOT NULL),
'[]'::jsonb
) as tools
FROM
tasks t
LEFT JOIN
workflows w ON t.developer_id = w.developer_id AND t.task_id = w.task_id AND t.version = w.version
LEFT JOIN
tools tl ON t.developer_id = tl.developer_id AND t.task_id = tl.task_id
WHERE
t.developer_id = $1 AND t.task_id = $2
AND t.version = (
Expand All @@ -36,7 +46,7 @@
WHERE developer_id = $1 AND task_id = $2
)
GROUP BY t.developer_id, t.task_id, t.canonical_name, t.agent_id, t.version;
""").sql(pretty=True)
"""


@rewrap_exceptions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
from uuid import UUID

from beartype import beartype
from sqlglot import parse_one

from ..utils import (
pg_query,
wrap_in_class,
)

# Define the raw SQL query for getting tool args from metadata
tools_args_for_task_query = parse_one("""
tools_args_for_task_query = """
SELECT COALESCE(agents_md || tasks_md, agents_md, tasks_md, '{}') as values FROM (
SELECT
CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args'
Expand All @@ -28,10 +27,11 @@
WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md
FROM tasks
WHERE task_id = $2 AND developer_id = $4 LIMIT 1
) AS tasks_md""").sql(pretty=True)
) AS tasks_md"""

# Define the raw SQL query for getting tool args from metadata for a session
tool_args_for_session_query = parse_one("""SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM (
tool_args_for_session_query = """
SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM (
SELECT
CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args'
WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args'
Expand All @@ -48,7 +48,7 @@
WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md
FROM sessions
WHERE session_id = $2 AND developer_id = $4 LIMIT 1
) AS sessions_md""").sql(pretty=True)
) AS sessions_md"""


# @rewrap_exceptions(
Expand Down
16 changes: 16 additions & 0 deletions agents-api/agents_api/routers/tasks/create_task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
CreateTransitionRequest,
Execution,
ResourceCreatedResponse,
TransitionTarget,
)
from ...clients.temporal import run_task_execution_workflow
from ...common.protocol.developers import Developer
from ...common.protocol.tasks import task_to_spec
from ...dependencies.developer_id import get_developer_id
from ...env import max_free_executions
from ...queries.developers.get_developer import get_developer
Expand Down Expand Up @@ -64,6 +66,14 @@ async def start_execution(
connection_pool=connection_pool,
)

task = await get_task_query(
developer_id=developer_id,
task_id=task_id,
connection_pool=connection_pool,
)

execution_input.task = task_to_spec(task)

job_id = uuid7()

try:
Expand All @@ -80,6 +90,12 @@ async def start_execution(
execution_id=execution_id,
data=CreateTransitionRequest(
type="error",
output={"error": str(e)},
current=TransitionTarget(
workflow="main",
step=0,
),
next=None,
),
connection_pool=connection_pool,
)
Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from tenacity import after_log, retry, retry_if_exception_type, wait_fixed

from ..app import app, lifespan
from ..clients import temporal
from .worker import create_worker

Expand All @@ -36,7 +37,8 @@ async def main():
worker = create_worker(client)

# Start the worker to listen for and process tasks
await worker.run()
async with lifespan(app):
await worker.run()


if __name__ == "__main__":
Expand Down
7 changes: 2 additions & 5 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,18 +140,15 @@ async def set_last_error(self, value: LastErrorInput):
async def run(
self,
execution_input: ExecutionInput,
start: TransitionTarget = TransitionTarget(workflow="main", step=0),
previous_inputs: list | None = None,
start: TransitionTarget,
previous_inputs: list,
) -> Any:
workflow.logger.info(
f"TaskExecutionWorkflow for task {execution_input.task.id}"
f" [LOC {start.workflow}.{start.step}]"
)

# FIXME: Look into saving arguments to the blob store if necessary
# 0. Prepare context
previous_inputs = previous_inputs or [execution_input.arguments]

context = StepContext(
execution_input=execution_input,
inputs=previous_inputs,
Expand Down
Loading

0 comments on commit 9b9bd73

Please sign in to comment.