Skip to content

Commit

Permalink
fix: Fix template rendering for different kinds of messages
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Aug 24, 2024
1 parent 8ec70ae commit af121a2
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 24 deletions.
71 changes: 60 additions & 11 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,45 @@
from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import ChatSettings, InputChatMLMessage
from ...autogen.openapi_model import (
ChatSettings,
Content,
ContentModel,
InputChatMLMessage,
)
from ...clients import (
litellm, # We don't directly import `acompletion` so we can mock it
)
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.utils.template import render_template


def _content_to_dict(
content: str | list[str] | list[Content | ContentModel], role: str
) -> str | list[dict]:
if isinstance(content, str):
return content

result = []
for s in content:
if isinstance(s, str):
result.append({"content": {"type": "text", "text": s, "role": role}})
elif isinstance(s, Content):
result.append({"content": {"type": s.type, "text": s.text, "role": role}})
elif isinstance(s, ContentModel):
result.append(
{
"content": {
"type": s.type,
"image_url": {"url": s.image_url.url},
"role": role,
}
}
)

return result


@activity.defn
@beartype
async def prompt_step(context: StepContext) -> StepOutcome:
Expand All @@ -27,26 +58,44 @@ async def prompt_step(context: StepContext) -> StepOutcome:
template_messages: list[InputChatMLMessage] = prompt
messages = await asyncio.gather(
*[
render_template(msg.content, context_data, skip_vars=["developer_id"])
render_template(
_content_to_dict(msg.content, msg.role),
context_data,
skip_vars=["developer_id"],
)
for msg in template_messages
]
)

messages = [
(
InputChatMLMessage(role="user", content=m)
if isinstance(m, str)
else InputChatMLMessage(**m)
)
for m in messages
]
result_messages = []
for m in messages:
if isinstance(m, str):
msg = InputChatMLMessage(role="user", content=m)
else:
msg = []
for d in m:
role = d["content"].get("role")
d["content"] = [d["content"]]
d["role"] = role
msg.append(InputChatMLMessage(**d))

result_messages.append(msg)

# messages = [
# (
# InputChatMLMessage(role="user", content=m)
# if isinstance(m, str)
# else [InputChatMLMessage(**d) for d in m]
# )
# for m in messages
# ]

# Get settings and run llm
settings: ChatSettings = context.current_step.settings or ChatSettings()
settings_data: dict = settings.model_dump()

response = await litellm.acompletion(
messages=messages,
messages=result_messages,
**settings_data,
)

Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"load_json": json.loads,
}


@beartype
def get_evaluator(names: dict[str, Any]) -> SimpleEval:
evaluator = EvalWithCompoundTypes(names=names, functions=ALLOWED_FUNCTIONS)
Expand Down
6 changes: 5 additions & 1 deletion agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ async def render_template_parts(
# Parse template
# FIXME: should template_strings contain a list of ChatMLTextContentPart? Should we handle it somehow?
templates = [
(jinja_env.from_string(msg["text"]) if msg["type"] == "text" else None)
(
jinja_env.from_string(msg["content"]["text"])
if msg["content"]["type"] == "text"
else None
)
for msg in template_strings
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def start_execution(
client=client,
)

job_id=uuid4()
job_id = uuid4()

try:
handle = await run_task_execution_workflow(
Expand Down
12 changes: 2 additions & 10 deletions agents-api/tests/sample_tasks/test_find_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ async def _(
agent_id = str(agent.id)
task_id = str(uuid4())


with patch_embed_acompletion(), open(
f"{this_dir}/find_selector.yaml", "r"
) as sample_file:
Expand All @@ -106,12 +105,7 @@ async def _(

input = dict(
screenshot_base64="iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAACXBIWXMAAAsTAAALEwEAmpwYAAAA",
network_requests=[{
"request": {},
"response": {
"body": "Lady Gaga"
}
}],
network_requests=[{"request": {}, "response": {"body": "Lady Gaga"}}],
parameters=["name"],
)
execution_data = dict(input=input)
Expand All @@ -122,8 +116,6 @@ async def _(
json=execution_data,
).json()

handle = temporal_client.get_workflow_handle(
execution_created["jobs"][0]
)
handle = temporal_client.get_workflow_handle(execution_created["jobs"][0])

await handle.result()
2 changes: 1 addition & 1 deletion agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ async def _(
) as task_file:
task_definition = yaml.safe_load(task_file)
acompletion.return_value = mock_model_response
data = CreateExecutionRequest(input={"test": "input"})
data = CreateExecutionRequest(input={"parameters": ["param1", "param2"]})

task = create_task(
developer_id=developer_id,
Expand Down

0 comments on commit af121a2

Please sign in to comment.