Skip to content

Commit

Permalink
go
Browse files Browse the repository at this point in the history
  • Loading branch information
genekogan committed Jan 9, 2025
1 parent 6f7478e commit ac52749
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
22 changes: 21 additions & 1 deletion eve/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from . import sentry_sdk
from .tool import Tool
from .task import Creation
from .user import User
from .agent import Agent
from .thread import UserMessage, AssistantMessage, ToolCall, Thread
Expand Down Expand Up @@ -574,15 +575,34 @@ async def async_prompt_thread(
result = await tool.async_wait(task)
thread.update_tool_call(assistant_message.id, t, result)

# yield update
# task completed
if result["status"] == "completed":

# make a Creation
name = task.args.get("prompt") or task.args.get("text_input")
filename = result.get("output", [{}])[0].get("filename")
media_attributes = result.get("output", [{}])[0].get("mediaAttributes")
if filename and media_attributes:
new_creation = Creation(
user=task.user,
requester=task.requester,
task=task.id,
tool=task.tool,
filename=filename,
mediaAttributes=media_attributes,
name=name
)
new_creation.save()

# yield update
yield ThreadUpdate(
type=UpdateType.TOOL_COMPLETE,
tool_name=tool_call.tool,
tool_index=t,
result=result,
)
else:
# yield error
yield ThreadUpdate(
type=UpdateType.ERROR,
tool_name=tool_call.tool,
Expand Down
3 changes: 2 additions & 1 deletion eve/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ async def _task_handler(func, *args, **kwargs):
result = eden_utils.upload_result(result, save_thumbnails=True, save_blurhash=True)

for output in result["output"]:
name = preprocess_result.get("name") or task_args.get("prompt") or args.get("text_input")
# name = preprocess_result.get("name") or task_args.get("prompt") or args.get("text_input")
name = task_args.get("prompt") or args.get("text_input")
if not name:
name = args.get("interpolation_prompts") or args.get("interpolation_texts")
if name:
Expand Down

0 comments on commit ac52749

Please sign in to comment.