Skip to content

Commit

Permalink
Merge pull request #49 from edenartlab/staging
Browse files Browse the repository at this point in the history
Staging
  • Loading branch information
genekogan authored Jan 7, 2025
2 parents b5f4701 + f3901e0 commit d73e1fa
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 22 deletions.
13 changes: 12 additions & 1 deletion comfyui.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,20 @@
test_workflows = os.getenv("WORKFLOWS")
root_workflows_folder = "../private_workflows" if os.getenv("PRIVATE") else "../workflows"
test_all = True if os.getenv("TEST_ALL") else False
specific_test = os.getenv("SPECIFIC_TEST") if os.getenv("SPECIFIC_TEST") else ""
skip_tests = os.getenv("SKIP_TESTS")

if test_all and specific_test:
print(f"WARNING: can't have both TEST_ALL and SPECIFIC_TEST at the same time...")
print(f"Running TEST_ALL instead")
specific_test = ""

print("========================================")
print(f"db: {db}")
print(f"workspace: {workspace_name}")
print(f"test_workflows: {test_workflows}")
print(f"test_all: {test_all}")
print(f"specific_test: {specific_test}")
print(f"skip_tests: {skip_tests}")
print("========================================")

Expand Down Expand Up @@ -224,6 +231,7 @@ def download_files(force_redownload=False):
modal.Image.debian_slim(python_version="3.11")
.env({"COMFYUI_PATH": "/root", "COMFYUI_MODEL_PATH": "/root/models"})
.env({"TEST_ALL": os.getenv("TEST_ALL")})
.env({"SPECIFIC_TEST": os.getenv("SPECIFIC_TEST")})
.apt_install("git", "git-lfs", "libgl1-mesa-glx", "libglib2.0-0", "libmagic1", "ffmpeg", "libegl1")
.pip_install_from_pyproject(str(root_dir / "pyproject.toml"))
.pip_install("diffusers==0.31.0")
Expand Down Expand Up @@ -319,10 +327,11 @@ def downloads(self):

@modal.build()
def test_workflows(self):
print(" ==== TESTING WORKFLOWS ====")
if os.getenv("SKIP_TESTS"):
print("Skipping tests")
return

print(" ==== TESTING WORKFLOWS ====")

t1 = time.time()
self._start()
Expand All @@ -348,6 +357,8 @@ def test_workflows(self):
test_all = os.getenv("TEST_ALL", False)
if test_all:
tests = glob.glob(f"/root/workspace/workflows/{workflow}/test*.json")
elif specific_test:
tests = [f"/root/workspace/workflows/{workflow}/{specific_test}"]
else:
tests = [f"/root/workspace/workflows/{workflow}/test.json"]
print(f"====> Running tests for {workflow}: ", tests)
Expand Down
7 changes: 6 additions & 1 deletion eve/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import json
import modal
import threading
from fastapi import FastAPI, Depends, BackgroundTasks
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
Expand Down Expand Up @@ -32,6 +33,7 @@
from eve.agent import Agent
from eve.user import User
from eve.task import Task
from eve.tools.comfyui_tool import convert_tasks2_to_tasks3
from eve import deploy

# Config and logging setup
Expand Down Expand Up @@ -62,7 +64,9 @@
@web_app.on_event("startup")
async def startup_event():
web_app.state.ably_client = AblyRealtime(os.getenv("ABLY_PUBLISHER_KEY"))

watch_thread = threading.Thread(target=convert_tasks2_to_tasks3, daemon=True)
watch_thread.start()
logger.info("Started tasks2 watch thread.")

class TaskRequest(BaseModel):
tool: str
Expand Down Expand Up @@ -141,6 +145,7 @@ async def task_admin(
request: TaskRequest,
_: dict = Depends(auth.authenticate_admin)
):
print("===== this is the request", request)
result = await handle_task(request.tool, request.user_id, request.args)
return serialize_document(result.model_dump(by_alias=True))

Expand Down
3 changes: 2 additions & 1 deletion eve/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ def calculate_cost(self, args):
def prepare_args(self, args: dict):
unrecognized_args = set(args.keys()) - set(self.model.model_fields.keys())
if unrecognized_args:
raise ValueError(
# raise ValueError(
print(
f"Unrecognized arguments provided for {self.key}: {', '.join(unrecognized_args)}"
)

Expand Down
46 changes: 27 additions & 19 deletions eve/tools/comfyui_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from bson import ObjectId
from pydantic import BaseModel, Field
from typing import List, Optional, Dict
import asyncio

from ..mongo import get_collection
from ..tool import Tool
Expand Down Expand Up @@ -89,7 +90,6 @@ async def async_run(self, args: Dict):
env=db
)
result = {"output": result}
print(result)
return result

@Tool.handle_start_task
Expand All @@ -114,23 +114,31 @@ async def async_start_task(self, task: Task):
env=db
)
return job.object_id

@Tool.handle_wait
async def async_wait(self, task: Task):
# hack to accommodate legacy comfyui tasks
# 1) get completed task from tasks2 collection
# 2) copy over canonical test
# 3) return new task
fc = modal.functions.FunctionCall.from_id(task.handler_id)
await fc.get.aio()
task.reload()

tasks2 = get_collection("tasks2")
task2 = tasks2.find_one({"_id": task.id})
task.update(
status=task2["status"],
error=task2["error"],
result=task2["result"]
)

return task.model_dump(include={"status", "error", "result"})
def convert_tasks2_to_tasks3():
"""
This is hack to retrofit legacy ComfyUI tasks in tasks2 collection to new tasks3 records
"""
pipeline = [
{
"$match": {
"operationType": {"$in": ["insert", "update", "replace"]}
}
}
]
try:
tasks2 = get_collection("tasks2")
with tasks2.watch(pipeline) as stream:
for change in stream:
task_id = change["documentKey"]["_id"]
update = change["updateDescription"]["updatedFields"]
task = Task.from_mongo(task_id)
task.reload()
task.update(
status=update.get("status", task.status),
error=update.get("error", task.error),
result=update.get("result", task.result)
)
except Exception as e:
print(f"Error in watch_tasks2 thread: {e}")

0 comments on commit d73e1fa

Please sign in to comment.