Skip to content

Commit

Permalink
fix(agents-api): Fixed create update queries and get task
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahmad-mtos committed Dec 26, 2024
1 parent b87f1f8 commit cde3773
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,8 @@ async def transition_step(
transition = await create_execution_transition(
developer_id=context.execution_input.developer_id,
execution_id=context.execution_input.execution.id,
task_id=context.execution_input.task.id,
data=transition_info,
task_token=transition_info.task_token,
update_execution_status=True,
)

except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ class StepOutcome(BaseModel):
def task_to_spec(
task: Task | CreateTaskRequest | UpdateTaskRequest | PatchTaskRequest, **model_opts
) -> TaskSpecDef | PartialTaskSpecDef:
task_data = task.model_dump(**model_opts, exclude={"task_id", "id", "agent_id"})
task_data = task.model_dump(**model_opts, exclude={"version","developer_id", "task_id", "id", "agent_id"})

if "tools" in task_data:
del task_data["tools"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@
$8, -- metadata
$9 -- default_settings
)
ON CONFLICT (developer_id, agent_id) DO UPDATE SET
canonical_name = EXCLUDED.canonical_name,
name = EXCLUDED.name,
about = EXCLUDED.about,
instructions = EXCLUDED.instructions,
model = EXCLUDED.model,
metadata = EXCLUDED.metadata,
default_settings = EXCLUDED.default_settings
RETURNING *;
""").sql(pretty=True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
agent_id = (
SELECT agent_id FROM tasks
WHERE developer_id = $1 AND task_id = $2
LIMIT 1
)
LIMIT 1
) a
Expand All @@ -42,27 +43,6 @@
) e
) AS execution;
"""
# (
# SELECT to_jsonb(e) AS execution FROM (
# SELECT * FROM latest_executions
# WHERE
# developer_id = $1 AND
# task_id = $2 AND
# execution_id = $3
# LIMIT 1
# ) e
# ) AS execution;

# (
# SELECT to_jsonb(t) AS task FROM (
# SELECT * FROM tasks
# WHERE
# developer_id = $1 AND
# task_id = $2
# LIMIT 1
# ) t
# ) AS task;


# @rewrap_exceptions(
# {
Expand Down
11 changes: 11 additions & 0 deletions agents-api/agents_api/queries/tasks/create_or_update_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
$7, -- description
$8 -- spec
)
ON CONFLICT (agent_id, task_id, name) DO UPDATE SET
type = EXCLUDED.type,
description = EXCLUDED.description,
spec = EXCLUDED.spec
""").sql(pretty=True)

# Define the raw SQL query for creating or updating a task
Expand Down Expand Up @@ -86,6 +90,13 @@
$8::jsonb, -- input_schema
$9::jsonb -- metadata
FROM current_version
ON CONFLICT (developer_id, task_id, "version") DO UPDATE SET
version = tasks.version + 1,
name = EXCLUDED.name,
description = EXCLUDED.description,
inherit_tools = EXCLUDED.inherit_tools,
input_schema = EXCLUDED.input_schema,
metadata = EXCLUDED.metadata
RETURNING *, (SELECT next_version FROM current_version) as next_version;
""").sql(pretty=True)

Expand Down
12 changes: 7 additions & 5 deletions agents-api/agents_api/queries/tasks/get_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ...common.protocol.tasks import spec_to_task
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class

# Define the raw SQL query for getting a task
get_task_query = parse_one("""
get_task_query ="""
SELECT
t.*,
COALESCE(
Expand All @@ -23,11 +22,14 @@
END
) FILTER (WHERE w.name IS NOT NULL),
'[]'::jsonb
) as workflows
) as workflows,
jsonb_agg(tl) as tools
FROM
tasks t
LEFT JOIN
INNER JOIN
workflows w ON t.developer_id = w.developer_id AND t.task_id = w.task_id AND t.version = w.version
INNER JOIN
tools tl ON t.developer_id = tl.developer_id AND t.task_id = tl.task_id
WHERE
t.developer_id = $1 AND t.task_id = $2
AND t.version = (
Expand All @@ -36,7 +38,7 @@
WHERE developer_id = $1 AND task_id = $2
)
GROUP BY t.developer_id, t.task_id, t.canonical_name, t.agent_id, t.version;
""").sql(pretty=True)
"""


@rewrap_exceptions(
Expand Down
6 changes: 4 additions & 2 deletions agents-api/agents_api/routers/tasks/create_task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from temporalio.client import WorkflowHandle
from uuid_extensions import uuid7

from ...common.protocol.tasks import task_to_spec

from ...autogen.openapi_model import (
CreateExecutionRequest,
CreateTransitionRequest,
Expand Down Expand Up @@ -64,13 +66,13 @@ async def start_execution(
connection_pool=connection_pool,
)

execution_input.task = await get_task_query(
task = await get_task_query(
developer_id=developer_id,
task_id=task_id,
connection_pool=connection_pool,
)

execution_input.task.workflows = execution_input.task.main
execution_input.task = task_to_spec(task)

job_id = uuid7()

Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ..clients import temporal
from .worker import create_worker
from ..app import lifespan, app

logger = logging.getLogger(__name__)
h = logging.StreamHandler()
Expand All @@ -36,7 +37,8 @@ async def main():
worker = create_worker(client)

# Start the worker to listen for and process tasks
await worker.run()
async with lifespan(app):
await worker.run()


if __name__ == "__main__":
Expand Down

0 comments on commit cde3773

Please sign in to comment.