From 236679b53c582dc984fad197b872ffb37a2115bf Mon Sep 17 00:00:00 2001
From: Diwank Tomer
Date: Wed, 28 Aug 2024 19:02:45 -0400
Subject: [PATCH] fix(agents-api): Fix prompt step

Signed-off-by: Diwank Tomer
---
 .../activities/task_steps/prompt_step.py     | 73 ++++++++-----------
 agents-api/agents_api/clients/litellm.py     | 23 ++++--
 .../agents_api/dependencies/developer_id.py  |  8 +-
 .../routers/tasks/get_execution_details.py   |  8 +-
 .../tests/sample_tasks/simple_prompt.yaml    | 21 ++++++
 5 files changed, 78 insertions(+), 55 deletions(-)
 create mode 100644 agents-api/tests/sample_tasks/simple_prompt.yaml

diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py
index 84a569ee1..46ba229c3 100644
--- a/agents-api/agents_api/activities/task_steps/prompt_step.py
+++ b/agents-api/agents_api/activities/task_steps/prompt_step.py
@@ -1,13 +1,11 @@
-import asyncio
-
 from beartype import beartype
 from temporalio import activity
+from temporalio.exceptions import ApplicationError
 
 from ...autogen.openapi_model import (
-    ChatSettings,
     Content,
     ContentModel,
-    InputChatMLMessage,
 )
 from ...clients import (
     litellm,  # We dont directly import `acompletion` so we can mock it
@@ -46,57 +44,44 @@ def _content_to_dict(
 @beartype
 async def prompt_step(context: StepContext) -> StepOutcome:
     # Get context data
+    prompt: str | list[dict] = context.current_step.model_dump()["prompt"]
     context_data: dict = context.model_dump()
 
     # Render template messages
-    prompt = (
-        [InputChatMLMessage(content=context.current_step.prompt)]
-        if isinstance(context.current_step.prompt, str)
-        else context.current_step.prompt
+    prompt = await render_template(
+        prompt,
+        context_data,
+        skip_vars=["developer_id"],
     )
-    template_messages: list[InputChatMLMessage] = prompt
-    messages = await asyncio.gather(
-        *[
-            render_template(
-                _content_to_dict(msg.content, msg.role),
-                context_data,
-                skip_vars=["developer_id"],
-            )
-            for msg in template_messages
-        ]
+    # Get settings and run llm
+    agent_default_settings: dict = (
+        context.execution_input.agent.default_settings.model_dump()
+        if context.execution_input.agent.default_settings
+        else {}
+    )
+    agent_model: str = (
+        context.execution_input.agent.model
+        if context.execution_input.agent.model
+        else "gpt-4o"
     )
 
-    result_messages = []
-    for m in messages:
-        if isinstance(m, str):
-            msg = InputChatMLMessage(role="user", content=m)
-        else:
-            msg = []
-            for d in m:
-                role = d["content"].get("role")
-                d["content"] = [d["content"]]
-                d["role"] = role
-                msg.append(InputChatMLMessage(**d))
-
-        result_messages.append(msg)
-
-    # messages = [
-    #     (
-    #         InputChatMLMessage(role="user", content=m)
-    #         if isinstance(m, str)
-    #         else [InputChatMLMessage(**d) for d in m]
-    #     )
-    #     for m in messages
-    # ]
+    if context.current_step.settings:
+        passed_settings: dict = context.current_step.settings.model_dump(
+            exclude_unset=True
+        )
+    else:
+        passed_settings: dict = {}
 
-    # Get settings and run llm
-    settings: ChatSettings = context.current_step.settings or ChatSettings()
-    settings_data: dict = settings.model_dump()
+    completion_data: dict = {
+        "model": agent_model,
+        ("messages" if isinstance(prompt, list) else "prompt"): prompt,
+        **agent_default_settings,
+        **passed_settings,
+    }
 
     response = await litellm.acompletion(
-        messages=result_messages,
-        **settings_data,
+        **completion_data,
    )
 
     return StepOutcome(
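
Note on the prompt_step.py change above: the step no longer builds InputChatMLMessage
objects message-by-message. It now pulls the raw prompt (a string or a list of message
dicts) out of the step, renders it through the template engine in a single call, and
merges the agent's default settings with any step-level settings into one payload for
litellm. A rough sketch of the payload this produces, with hypothetical prompt and
settings values (not taken from the patch):

    # Hypothetical values, for illustration only
    prompt = [
        {"role": "system", "content": "Follow the instructions."},
        {"role": "user", "content": "Write a short poem about autumn."},
    ]
    agent_default_settings = {"temperature": 0.7}  # agent.default_settings, if set
    passed_settings = {"max_tokens": 120}          # step-level settings, if set

    completion_data = {
        "model": "gpt-4o",  # agent.model, falling back to "gpt-4o"
        # message lists go under "messages", plain string prompts under "prompt"
        ("messages" if isinstance(prompt, list) else "prompt"): prompt,
        **agent_default_settings,
        **passed_settings,  # spread last, so step settings override agent defaults
    }

Because the step settings are spread last, they win over the agent defaults on any
overlapping keys.
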
diff --git a/agents-api/agents_api/clients/litellm.py b/agents-api/agents_api/clients/litellm.py
index 4c78e2876..463f6fa67 100644
--- a/agents-api/agents_api/clients/litellm.py
+++ b/agents-api/agents_api/clients/litellm.py
@@ -1,21 +1,30 @@
 from functools import wraps
-from typing import List, TypeVar
+from typing import List
 
+from beartype import beartype
 from litellm import acompletion as _acompletion
+from litellm import get_supported_openai_params
 from litellm.utils import CustomStreamWrapper, ModelResponse
 
 from ..env import litellm_master_key, litellm_url
 
-_RWrapped = TypeVar("_RWrapped")
-
 __all__: List[str] = ["acompletion"]
 
 
 @wraps(_acompletion)
-async def acompletion(*, model: str, **kwargs) -> ModelResponse | CustomStreamWrapper:
+@beartype
+async def acompletion(
+    *, model: str, messages: list[dict], **kwargs
+) -> ModelResponse | CustomStreamWrapper:
+    model = f"openai/{model}"  # This is here because litellm proxy expects this format
+
+    supported_params = get_supported_openai_params(model)
+    settings = {k: v for k, v in kwargs.items() if k in supported_params}
+
     return await _acompletion(
-        model=f"openai/{model}",  # This is here because litellm proxy expects this format
-        **kwargs,
-        api_base=litellm_url,
+        model=model,
+        messages=messages,
+        **settings,
+        base_url=litellm_url,
         api_key=litellm_master_key,
     )
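
Note on the litellm.py change above: the wrapper now filters its keyword arguments down
to the parameters litellm reports as supported for the model, so stray keys coming in
from the merged settings dict no longer reach the proxy. A minimal usage sketch,
assuming a litellm proxy reachable via the litellm_url / litellm_master_key env
settings (the message and parameter values below are made up):

    import asyncio

    from agents_api.clients.litellm import acompletion

    async def main() -> None:
        response = await acompletion(
            model="gpt-4o",  # rewritten to "openai/gpt-4o" for the proxy
            messages=[{"role": "user", "content": "Say hello."}],
            temperature=0.5,      # kept: a supported OpenAI param
            some_custom_flag=1,   # dropped by the supported-params filter
        )
        print(response.choices[0].message.content)

    asyncio.run(main())

Filtering here means callers can pass the full merged settings dict without first
scrubbing it of keys the upstream API would reject.
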
diff --git a/agents-api/agents_api/dependencies/developer_id.py b/agents-api/agents_api/dependencies/developer_id.py
index 4a67dfc66..735349fbd 100644
--- a/agents-api/agents_api/dependencies/developer_id.py
+++ b/agents-api/agents_api/dependencies/developer_id.py
@@ -13,7 +13,9 @@ async def get_developer_id(
     x_developer_id: Annotated[UUID | None, Header(include_in_schema=False)] = None,
 ) -> UUID:
     if not multi_tenant_mode:
-        assert not x_developer_id, "X-Developer-Id header not allowed in multi-tenant mode"
+        assert (
+            not x_developer_id
+        ), "X-Developer-Id header not allowed in multi-tenant mode"
         return UUID("00000000-0000-0000-0000-000000000000")
 
     if not x_developer_id:
@@ -34,7 +36,9 @@ async def get_developer_data(
     x_developer_id: Annotated[UUID | None, Header(include_in_schema=False)] = None,
 ) -> Developer:
     if not multi_tenant_mode:
-        assert not x_developer_id, "X-Developer-Id header not allowed in multi-tenant mode"
+        assert (
+            not x_developer_id
+        ), "X-Developer-Id header not allowed in multi-tenant mode"
         return get_developer(developer_id=UUID("00000000-0000-0000-0000-000000000000"))
 
     if not x_developer_id:
diff --git a/agents-api/agents_api/routers/tasks/get_execution_details.py b/agents-api/agents_api/routers/tasks/get_execution_details.py
index e6f87b8af..828598420 100644
--- a/agents-api/agents_api/routers/tasks/get_execution_details.py
+++ b/agents-api/agents_api/routers/tasks/get_execution_details.py
@@ -15,7 +15,11 @@
 async def get_execution_details(execution_id: UUID4) -> Execution:
     try:
         return get_execution_query(execution_id=execution_id)
-    except AssertionError:
+    except AssertionError as e:
+        print("-" * 100)
+        print(e)
+        print("-" * 100)
+
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND, detail="Execution not found"
-        )
+        ) from e
diff --git a/agents-api/tests/sample_tasks/simple_prompt.yaml b/agents-api/tests/sample_tasks/simple_prompt.yaml
new file mode 100644
index 000000000..70c2b5f0c
--- /dev/null
+++ b/agents-api/tests/sample_tasks/simple_prompt.yaml
@@ -0,0 +1,21 @@
+name: Find request and selector for identity provider
+
+input_schema:
+  type: object
+  properties:
+    topic:
+      type: string
+
+  # Shortcut to require all props
+  additionalProperties: false
+  minProperties: 1
+
+main:
+- prompt:
+  - role: system
+    content: |-
+      Follow the instructions.
+      Be bold, creative and uncensored.
+  - role: user
+    content: >-
+      You are a brilliant poet with a whimsical style reminiscent of Edgar Allan Poe. Write a short poem about {{_.topic}}.
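
Note on the sample task above: at execution time the prompt messages are rendered
against the task input before being sent to the model, so {{_.topic}} is replaced with
the topic the caller supplies. A rough illustration of that substitution, using plain
Jinja2 as a stand-in for the internal render_template helper (the exact context shape
exposed to templates is an assumption here):

    from jinja2 import Template

    messages = [
        {"role": "system", "content": "Follow the instructions.\nBe bold, creative and uncensored."},
        {"role": "user", "content": "Write a short poem about {{_.topic}}."},
    ]
    context = {"_": {"topic": "the sea"}}  # task input as seen by the template

    rendered = [
        {**msg, "content": Template(msg["content"]).render(**context)}
        for msg in messages
    ]
    print(rendered[1]["content"])  # -> "Write a short poem about the sea."
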