From ac527494455874e654231aff5f4eb137f977f23b Mon Sep 17 00:00:00 2001 From: genekogan Date: Thu, 9 Jan 2025 01:19:36 -0800 Subject: [PATCH] go --- eve/llm.py | 22 +++++++++++++++++++++- eve/task.py | 3 ++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/eve/llm.py b/eve/llm.py index 7172c32..b5d112f 100644 --- a/eve/llm.py +++ b/eve/llm.py @@ -17,6 +17,7 @@ from . import sentry_sdk from .tool import Tool +from .task import Creation from .user import User from .agent import Agent from .thread import UserMessage, AssistantMessage, ToolCall, Thread @@ -574,8 +575,26 @@ async def async_prompt_thread( result = await tool.async_wait(task) thread.update_tool_call(assistant_message.id, t, result) - # yield update + # task completed if result["status"] == "completed": + + # make a Creation + name = task.args.get("prompt") or task.args.get("text_input") + filename = result.get("output", [{}])[0].get("filename") + media_attributes = result.get("output", [{}])[0].get("mediaAttributes") + if filename and media_attributes: + new_creation = Creation( + user=task.user, + requester=task.requester, + task=task.id, + tool=task.tool, + filename=filename, + mediaAttributes=media_attributes, + name=name + ) + new_creation.save() + + # yield update yield ThreadUpdate( type=UpdateType.TOOL_COMPLETE, tool_name=tool_call.tool, @@ -583,6 +602,7 @@ async def async_prompt_thread( result=result, ) else: + # yield error yield ThreadUpdate( type=UpdateType.ERROR, tool_name=tool_call.tool, diff --git a/eve/task.py b/eve/task.py index b5b5a68..b1e1f4c 100644 --- a/eve/task.py +++ b/eve/task.py @@ -140,7 +140,8 @@ async def _task_handler(func, *args, **kwargs): result = eden_utils.upload_result(result, save_thumbnails=True, save_blurhash=True) for output in result["output"]: - name = preprocess_result.get("name") or task_args.get("prompt") or args.get("text_input") + # name = preprocess_result.get("name") or task_args.get("prompt") or args.get("text_input") + name = task_args.get("prompt") or args.get("text_input") if not name: name = args.get("interpolation_prompts") or args.get("interpolation_texts") if name: