Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Staging #47

Merged
merged 2 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions eve/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"local",
"modal",
"comfyui",
"comfyui_legacy",
"replicate",
"gcp"
]
Expand Down Expand Up @@ -100,7 +101,7 @@ def _get_schema(cls, key, from_yaml=False) -> dict:
with open(api_file, "r") as f:
schema = yaml.safe_load(f)

if schema.get("handler") == "comfyui":
if schema.get("handler") in ["comfyui", "comfyui_legacy"]:
schema["workspace"] = schema.get("workspace") or api_file.split("/")[-4]
else:
# MongoDB path
Expand All @@ -117,7 +118,7 @@ def get_sub_class(
) -> type:
from .tools.local_tool import LocalTool
from .tools.modal_tool import ModalTool
from .tools.comfyui_tool import ComfyUITool
from .tools.comfyui_tool import ComfyUITool, ComfyUIToolLegacy
from .tools.replicate_tool import ReplicateTool
from .tools.gcp_tool import GCPTool

Expand All @@ -132,6 +133,7 @@ def get_sub_class(
"local": LocalTool,
"modal": ModalTool,
"comfyui": ComfyUITool,
"comfyui_legacy": ComfyUIToolLegacy, # private/legacy workflows
"replicate": ReplicateTool,
"gcp": GCPTool,
None: LocalTool,
Expand Down Expand Up @@ -530,7 +532,7 @@ def get_api_files(root_dir: str = None) -> List[str]:
eve_root = os.path.dirname(os.path.abspath(__file__))
root_dirs = [
os.path.join(eve_root, tools_dir)
for tools_dir in ["tools", "../../workflows"]
for tools_dir in ["tools", "../../workflows", "../../private_workflows"]
]

api_files = {}
Expand Down
67 changes: 66 additions & 1 deletion eve/tools/comfyui_tool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import modal
import os
from bson import ObjectId
from pydantic import BaseModel, Field
from typing import List, Optional, Dict

from ..mongo import get_collection
from ..tool import Tool
from ..task import Task

Expand Down Expand Up @@ -56,7 +58,7 @@ async def async_start_task(self, task: Task):
)
job = await cls().run_task.spawn.aio(task)
return job.object_id

@Tool.handle_wait
async def async_wait(self, task: Task):
fc = modal.functions.FunctionCall.from_id(task.handler_id)
Expand All @@ -69,3 +71,66 @@ async def async_cancel(self, task: Task):
fc = modal.functions.FunctionCall.from_id(task.handler_id)
await fc.cancel.aio()


class ComfyUIToolLegacy(ComfyUITool):
"""For legacy/private workflows"""

@Tool.handle_run
async def async_run(self, args: Dict):
db = os.getenv("DB")
cls = modal.Cls.lookup(
f"comfyui-{self.key}",
"ComfyUI",
environment_name="main"
)
result = await cls().run.remote.aio(
workflow_name=self.key,
args=args,
env=db
)
result = {"output": result}
print(result)
return result

@Tool.handle_start_task
async def async_start_task(self, task: Task):
# hack to accommodate legacy comfyui tasks
# 1) copy task to tasks2 collection (rename tool to workflow)
# 2) spawn new job, env=DB
db = os.getenv("DB")

task_data = task.model_dump(by_alias=True)
task_data["workflow"] = task_data.pop("tool")
tasks2 = get_collection("tasks2")
tasks2.insert_one(task_data)

cls = modal.Cls.lookup(
f"comfyui-{self.key}",
"ComfyUI",
environment_name="main"
)
job = await cls().run_task.spawn.aio(
task_id=ObjectId(task_data["_id"]),
env=db
)
return job.object_id

@Tool.handle_wait
async def async_wait(self, task: Task):
# hack to accommodate legacy comfyui tasks
# 1) get completed task from tasks2 collection
# 2) copy over canonical test
# 3) return new task
fc = modal.functions.FunctionCall.from_id(task.handler_id)
await fc.get.aio()
task.reload()

tasks2 = get_collection("tasks2")
task2 = tasks2.find_one({"_id": task.id})
task.update(
status=task2["status"],
error=task2["error"],
result=task2["result"]
)

return task.model_dump(include={"status", "error", "result"})
Loading