Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Staging #49

Merged
merged 3 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion comfyui.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,20 @@
test_workflows = os.getenv("WORKFLOWS")
root_workflows_folder = "../private_workflows" if os.getenv("PRIVATE") else "../workflows"
test_all = True if os.getenv("TEST_ALL") else False
specific_test = os.getenv("SPECIFIC_TEST") if os.getenv("SPECIFIC_TEST") else ""
skip_tests = os.getenv("SKIP_TESTS")

if test_all and specific_test:
print(f"WARNING: can't have both TEST_ALL and SPECIFIC_TEST at the same time...")
print(f"Running TEST_ALL instead")
specific_test = ""

print("========================================")
print(f"db: {db}")
print(f"workspace: {workspace_name}")
print(f"test_workflows: {test_workflows}")
print(f"test_all: {test_all}")
print(f"specific_test: {specific_test}")
print(f"skip_tests: {skip_tests}")
print("========================================")

Expand Down Expand Up @@ -224,6 +231,7 @@ def download_files(force_redownload=False):
modal.Image.debian_slim(python_version="3.11")
.env({"COMFYUI_PATH": "/root", "COMFYUI_MODEL_PATH": "/root/models"})
.env({"TEST_ALL": os.getenv("TEST_ALL")})
.env({"SPECIFIC_TEST": os.getenv("SPECIFIC_TEST")})
.apt_install("git", "git-lfs", "libgl1-mesa-glx", "libglib2.0-0", "libmagic1", "ffmpeg", "libegl1")
.pip_install_from_pyproject(str(root_dir / "pyproject.toml"))
.pip_install("diffusers==0.31.0")
Expand Down Expand Up @@ -319,10 +327,11 @@ def downloads(self):

@modal.build()
def test_workflows(self):
print(" ==== TESTING WORKFLOWS ====")
if os.getenv("SKIP_TESTS"):
print("Skipping tests")
return

print(" ==== TESTING WORKFLOWS ====")

t1 = time.time()
self._start()
Expand All @@ -348,6 +357,8 @@ def test_workflows(self):
test_all = os.getenv("TEST_ALL", False)
if test_all:
tests = glob.glob(f"/root/workspace/workflows/{workflow}/test*.json")
elif specific_test:
tests = [f"/root/workspace/workflows/{workflow}/{specific_test}"]
else:
tests = [f"/root/workspace/workflows/{workflow}/test.json"]
print(f"====> Running tests for {workflow}: ", tests)
Expand Down
7 changes: 6 additions & 1 deletion eve/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import json
import modal
import threading
from fastapi import FastAPI, Depends, BackgroundTasks
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
Expand Down Expand Up @@ -32,6 +33,7 @@
from eve.agent import Agent
from eve.user import User
from eve.task import Task
from eve.tools.comfyui_tool import convert_tasks2_to_tasks3
from eve import deploy

# Config and logging setup
Expand Down Expand Up @@ -62,7 +64,9 @@
@web_app.on_event("startup")
async def startup_event():
web_app.state.ably_client = AblyRealtime(os.getenv("ABLY_PUBLISHER_KEY"))

watch_thread = threading.Thread(target=convert_tasks2_to_tasks3, daemon=True)
watch_thread.start()
logger.info("Started tasks2 watch thread.")

class TaskRequest(BaseModel):
tool: str
Expand Down Expand Up @@ -141,6 +145,7 @@ async def task_admin(
request: TaskRequest,
_: dict = Depends(auth.authenticate_admin)
):
print("===== this is the request", request)
result = await handle_task(request.tool, request.user_id, request.args)
return serialize_document(result.model_dump(by_alias=True))

Expand Down
3 changes: 2 additions & 1 deletion eve/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ def calculate_cost(self, args):
def prepare_args(self, args: dict):
unrecognized_args = set(args.keys()) - set(self.model.model_fields.keys())
if unrecognized_args:
raise ValueError(
# raise ValueError(
print(
f"Unrecognized arguments provided for {self.key}: {', '.join(unrecognized_args)}"
)

Expand Down
46 changes: 27 additions & 19 deletions eve/tools/comfyui_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from bson import ObjectId
from pydantic import BaseModel, Field
from typing import List, Optional, Dict
import asyncio

from ..mongo import get_collection
from ..tool import Tool
Expand Down Expand Up @@ -89,7 +90,6 @@ async def async_run(self, args: Dict):
env=db
)
result = {"output": result}
print(result)
return result

@Tool.handle_start_task
Expand All @@ -114,23 +114,31 @@ async def async_start_task(self, task: Task):
env=db
)
return job.object_id

@Tool.handle_wait
async def async_wait(self, task: Task):
# hack to accommodate legacy comfyui tasks
# 1) get completed task from tasks2 collection
# 2) copy over canonical test
# 3) return new task
fc = modal.functions.FunctionCall.from_id(task.handler_id)
await fc.get.aio()
task.reload()

tasks2 = get_collection("tasks2")
task2 = tasks2.find_one({"_id": task.id})
task.update(
status=task2["status"],
error=task2["error"],
result=task2["result"]
)

return task.model_dump(include={"status", "error", "result"})
def convert_tasks2_to_tasks3():
"""
This is hack to retrofit legacy ComfyUI tasks in tasks2 collection to new tasks3 records
"""
pipeline = [
{
"$match": {
"operationType": {"$in": ["insert", "update", "replace"]}
}
}
]
try:
tasks2 = get_collection("tasks2")
with tasks2.watch(pipeline) as stream:
for change in stream:
task_id = change["documentKey"]["_id"]
update = change["updateDescription"]["updatedFields"]
task = Task.from_mongo(task_id)
task.reload()
task.update(
status=update.get("status", task.status),
error=update.get("error", task.error),
result=update.get("result", task.result)
)
except Exception as e:
print(f"Error in watch_tasks2 thread: {e}")
Loading