diff --git a/comfyui.py b/comfyui.py index 425fb55..84b43dc 100644 --- a/comfyui.py +++ b/comfyui.py @@ -63,13 +63,20 @@ test_workflows = os.getenv("WORKFLOWS") root_workflows_folder = "../private_workflows" if os.getenv("PRIVATE") else "../workflows" test_all = True if os.getenv("TEST_ALL") else False +specific_test = os.getenv("SPECIFIC_TEST") if os.getenv("SPECIFIC_TEST") else "" skip_tests = os.getenv("SKIP_TESTS") +if test_all and specific_test: + print(f"WARNING: can't have both TEST_ALL and SPECIFIC_TEST at the same time...") + print(f"Running TEST_ALL instead") + specific_test = "" + print("========================================") print(f"db: {db}") print(f"workspace: {workspace_name}") print(f"test_workflows: {test_workflows}") print(f"test_all: {test_all}") +print(f"specific_test: {specific_test}") print(f"skip_tests: {skip_tests}") print("========================================") @@ -224,6 +231,7 @@ def download_files(force_redownload=False): modal.Image.debian_slim(python_version="3.11") .env({"COMFYUI_PATH": "/root", "COMFYUI_MODEL_PATH": "/root/models"}) .env({"TEST_ALL": os.getenv("TEST_ALL")}) + .env({"SPECIFIC_TEST": os.getenv("SPECIFIC_TEST")}) .apt_install("git", "git-lfs", "libgl1-mesa-glx", "libglib2.0-0", "libmagic1", "ffmpeg", "libegl1") .pip_install_from_pyproject(str(root_dir / "pyproject.toml")) .pip_install("diffusers==0.31.0") @@ -319,10 +327,11 @@ def downloads(self): @modal.build() def test_workflows(self): - print(" ==== TESTING WORKFLOWS ====") if os.getenv("SKIP_TESTS"): print("Skipping tests") return + + print(" ==== TESTING WORKFLOWS ====") t1 = time.time() self._start() @@ -348,6 +357,8 @@ def test_workflows(self): test_all = os.getenv("TEST_ALL", False) if test_all: tests = glob.glob(f"/root/workspace/workflows/{workflow}/test*.json") + elif specific_test: + tests = [f"/root/workspace/workflows/{workflow}/{specific_test}"] else: tests = [f"/root/workspace/workflows/{workflow}/test.json"] print(f"====> Running tests for {workflow}: ", tests) diff --git a/eve/api.py b/eve/api.py index a06002b..60d2fc8 100644 --- a/eve/api.py +++ b/eve/api.py @@ -1,6 +1,7 @@ import os import json import modal +import threading from fastapi import FastAPI, Depends, BackgroundTasks from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware @@ -32,6 +33,7 @@ from eve.agent import Agent from eve.user import User from eve.task import Task +from eve.tools.comfyui_tool import convert_tasks2_to_tasks3 from eve import deploy # Config and logging setup @@ -62,7 +64,9 @@ @web_app.on_event("startup") async def startup_event(): web_app.state.ably_client = AblyRealtime(os.getenv("ABLY_PUBLISHER_KEY")) - + watch_thread = threading.Thread(target=convert_tasks2_to_tasks3, daemon=True) + watch_thread.start() + logger.info("Started tasks2 watch thread.") class TaskRequest(BaseModel): tool: str @@ -141,6 +145,7 @@ async def task_admin( request: TaskRequest, _: dict = Depends(auth.authenticate_admin) ): + print("===== this is the request", request) result = await handle_task(request.tool, request.user_id, request.args) return serialize_document(result.model_dump(by_alias=True)) diff --git a/eve/tool.py b/eve/tool.py index 7cb3114..f46a91a 100644 --- a/eve/tool.py +++ b/eve/tool.py @@ -287,7 +287,8 @@ def calculate_cost(self, args): def prepare_args(self, args: dict): unrecognized_args = set(args.keys()) - set(self.model.model_fields.keys()) if unrecognized_args: - raise ValueError( + # raise ValueError( + print( f"Unrecognized arguments provided for {self.key}: {', '.join(unrecognized_args)}" ) diff --git a/eve/tools/comfyui_tool.py b/eve/tools/comfyui_tool.py index 33f5b4b..7af69e9 100644 --- a/eve/tools/comfyui_tool.py +++ b/eve/tools/comfyui_tool.py @@ -3,6 +3,7 @@ from bson import ObjectId from pydantic import BaseModel, Field from typing import List, Optional, Dict +import asyncio from ..mongo import get_collection from ..tool import Tool @@ -89,7 +90,6 @@ async def async_run(self, args: Dict): env=db ) result = {"output": result} - print(result) return result @Tool.handle_start_task @@ -114,23 +114,31 @@ async def async_start_task(self, task: Task): env=db ) return job.object_id - - @Tool.handle_wait - async def async_wait(self, task: Task): - # hack to accommodate legacy comfyui tasks - # 1) get completed task from tasks2 collection - # 2) copy over canonical test - # 3) return new task - fc = modal.functions.FunctionCall.from_id(task.handler_id) - await fc.get.aio() - task.reload() - tasks2 = get_collection("tasks2") - task2 = tasks2.find_one({"_id": task.id}) - task.update( - status=task2["status"], - error=task2["error"], - result=task2["result"] - ) - return task.model_dump(include={"status", "error", "result"}) +def convert_tasks2_to_tasks3(): + """ + This is hack to retrofit legacy ComfyUI tasks in tasks2 collection to new tasks3 records + """ + pipeline = [ + { + "$match": { + "operationType": {"$in": ["insert", "update", "replace"]} + } + } + ] + try: + tasks2 = get_collection("tasks2") + with tasks2.watch(pipeline) as stream: + for change in stream: + task_id = change["documentKey"]["_id"] + update = change["updateDescription"]["updatedFields"] + task = Task.from_mongo(task_id) + task.reload() + task.update( + status=update.get("status", task.status), + error=update.get("error", task.error), + result=update.get("result", task.result) + ) + except Exception as e: + print(f"Error in watch_tasks2 thread: {e}")