Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
jmilldotdev committed Jan 7, 2025
2 parents 17152aa + d04a7cf commit 1aa3777
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 22 deletions.
20 changes: 19 additions & 1 deletion comfyui.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,29 @@
test_workflows = os.getenv("WORKFLOWS")
root_workflows_folder = "../private_workflows" if os.getenv("PRIVATE") else "../workflows"
test_all = True if os.getenv("TEST_ALL") else False
specific_test = os.getenv("SPECIFIC_TEST") if os.getenv("SPECIFIC_TEST") else ""
skip_tests = os.getenv("SKIP_TESTS")

# Run a bunch of checks to verify input args:
if test_all and specific_test:
print(f"WARNING: can't have both TEST_ALL and SPECIFIC_TEST at the same time...")
print(f"Running TEST_ALL instead")
specific_test = ""

print("========================================")
print(f"db: {db}")
print(f"workspace: {workspace_name}")
print(f"test_workflows: {test_workflows}")
print(f"test_all: {test_all}")
print(f"specific_test: {specific_test}")
print(f"skip_tests: {skip_tests}")
print("========================================")

if not test_workflows and workspace_name and not test_all:
print("\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("!!!! WARNING: You are deploying a workspace without TEST_ALL !!!!")
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n")

def install_comfyui():
snapshot = json.load(open("/root/workspace/snapshot.json", 'r'))
comfyui_commit_sha = snapshot["comfyui"]
Expand Down Expand Up @@ -224,6 +237,7 @@ def download_files(force_redownload=False):
modal.Image.debian_slim(python_version="3.11")
.env({"COMFYUI_PATH": "/root", "COMFYUI_MODEL_PATH": "/root/models"})
.env({"TEST_ALL": os.getenv("TEST_ALL")})
.env({"SPECIFIC_TEST": os.getenv("SPECIFIC_TEST")})
.apt_install("git", "git-lfs", "libgl1-mesa-glx", "libglib2.0-0", "libmagic1", "ffmpeg", "libegl1")
.pip_install_from_pyproject(str(root_dir / "pyproject.toml"))
.pip_install("diffusers==0.31.0")
Expand Down Expand Up @@ -319,10 +333,11 @@ def downloads(self):

@modal.build()
def test_workflows(self):
print(" ==== TESTING WORKFLOWS ====")
if os.getenv("SKIP_TESTS"):
print("Skipping tests")
return

print(" ==== TESTING WORKFLOWS ====")

t1 = time.time()
self._start()
Expand All @@ -348,8 +363,11 @@ def test_workflows(self):
test_all = os.getenv("TEST_ALL", False)
if test_all:
tests = glob.glob(f"/root/workspace/workflows/{workflow}/test*.json")
elif specific_test:
tests = [f"/root/workspace/workflows/{workflow}/{specific_test}"]
else:
tests = [f"/root/workspace/workflows/{workflow}/test.json"]
print("\n\n-----------------------------------------------------------")
print(f"====> Running tests for {workflow}: ", tests)
for test in tests:
tool = Tool.from_yaml(f"/root/workspace/workflows/{workflow}/api.yaml")
Expand Down
146 changes: 145 additions & 1 deletion eve/api.py
Original file line number Diff line number Diff line change
@@ -1 +1,145 @@

import os
import threading
import modal
from fastapi import FastAPI, Depends, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader, HTTPBearer
import logging
from ably import AblyRealtime
from apscheduler.schedulers.background import BackgroundScheduler
from pathlib import Path

from eve import auth
from eve.api.handlers import (
handle_cancel,
handle_chat,
handle_create,
handle_deployment,
handle_schedule,
handle_stream_chat,
)
from eve.api.requests import (
CancelRequest,
ChatRequest,
ScheduleRequest,
TaskRequest,
DeployRequest,
)
from eve.deploy import (
authenticate_modal_key,
check_environment_exists,
create_environment,
)
from eve import deploy
from eve.tools.comfyui_tool import convert_tasks2_to_tasks3

# Config and logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

db = os.getenv("DB", "STAGE").upper()
if db not in ["PROD", "STAGE"]:
raise Exception(f"Invalid environment: {db}. Must be PROD or STAGE")
app_name = "api-prod" if db == "PROD" else "api-stage"

# FastAPI setup
web_app = FastAPI()
web_app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
scheduler = BackgroundScheduler()
scheduler.start()

api_key_header = APIKeyHeader(name="X-Api-Key", auto_error=False)
bearer_scheme = HTTPBearer(auto_error=False)
background_tasks: BackgroundTasks = BackgroundTasks()


@web_app.on_event("startup")
async def startup_event():
watch_thread = threading.Thread(target=convert_tasks2_to_tasks3, daemon=True)
watch_thread.start()
web_app.state.ably_client = AblyRealtime(os.getenv("ABLY_PUBLISHER_KEY"))


@web_app.post("/create")
async def create(request: TaskRequest, _: dict = Depends(auth.authenticate_admin)):
return await handle_create(request)


@web_app.post("/cancel")
async def cancel(request: CancelRequest, _: dict = Depends(auth.authenticate_admin)):
return await handle_cancel(request)


@web_app.post("/chat")
async def chat(
request: ChatRequest,
background_tasks: BackgroundTasks,
_: dict = Depends(auth.authenticate_admin),
):
return await handle_chat(request, background_tasks, web_app.state.ably_client)


@web_app.post("/chat/stream")
async def stream_chat(
request: ChatRequest,
background_tasks: BackgroundTasks,
_: dict = Depends(auth.authenticate_admin),
):
return await handle_stream_chat(request, background_tasks)


@web_app.post("/deployment")
async def deployment(
request: DeployRequest, _: dict = Depends(auth.authenticate_admin)
):
return await handle_deployment(request)


@web_app.post("/schedule")
async def schedule(
request: ScheduleRequest, _: dict = Depends(auth.authenticate_admin)
):
return await handle_schedule(request)


# Modal app setup
app = modal.App(
name=app_name,
secrets=[
modal.Secret.from_name("eve-secrets"),
modal.Secret.from_name(f"eve-secrets-{db}"),
],
)

root_dir = Path(__file__).parent.parent
workflows_dir = root_dir / ".." / "workflows"

image = (
modal.Image.debian_slim(python_version="3.11")
.env({"DB": db, "MODAL_SERVE": os.getenv("MODAL_SERVE")})
.apt_install("git", "libmagic1", "ffmpeg", "wget")
.pip_install_from_pyproject(str(root_dir / "pyproject.toml"))
.run_commands(["playwright install"])
.copy_local_dir(str(workflows_dir), "/workflows")
)


@app.function(
image=image,
keep_warm=1,
concurrency_limit=10,
container_idle_timeout=60,
timeout=3600,
)
@modal.asgi_app()
def fastapi_app():
authenticate_modal_key()
if not check_environment_exists(deploy.DEPLOYMENT_ENV_NAME):
create_environment(deploy.DEPLOYMENT_ENV_NAME)
return web_app
3 changes: 2 additions & 1 deletion eve/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ def calculate_cost(self, args):
def prepare_args(self, args: dict):
unrecognized_args = set(args.keys()) - set(self.model.model_fields.keys())
if unrecognized_args:
raise ValueError(
# raise ValueError(
print(
f"Unrecognized arguments provided for {self.key}: {', '.join(unrecognized_args)}"
)

Expand Down
47 changes: 28 additions & 19 deletions eve/tools/comfyui_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ async def async_run(self, args: Dict):
env=db
)
result = {"output": result}
print(result)
return result

@Tool.handle_start_task
Expand All @@ -114,23 +113,33 @@ async def async_start_task(self, task: Task):
env=db
)
return job.object_id

@Tool.handle_wait
async def async_wait(self, task: Task):
# hack to accommodate legacy comfyui tasks
# 1) get completed task from tasks2 collection
# 2) copy over canonical test
# 3) return new task
fc = modal.functions.FunctionCall.from_id(task.handler_id)
await fc.get.aio()
task.reload()

tasks2 = get_collection("tasks2")
task2 = tasks2.find_one({"_id": task.id})
task.update(
status=task2["status"],
error=task2["error"],
result=task2["result"]
)

return task.model_dump(include={"status", "error", "result"})
def convert_tasks2_to_tasks3():
"""
This is hack to retrofit legacy ComfyUI tasks in tasks2 collection to new tasks3 records
"""
pipeline = [
{
"$match": {
"operationType": {"$in": ["insert", "update", "replace"]}
}
}
]
try:
tasks2 = get_collection("tasks2")
with tasks2.watch(pipeline) as stream:
for change in stream:
task_id = change["documentKey"]["_id"]
if "updateDescription" not in change:
continue
update = change["updateDescription"]["updatedFields"]
task = Task.from_mongo(task_id)
task.reload()
task.update(
status=update.get("status", task.status),
error=update.get("error", task.error),
result=update.get("result", task.result)
)
except Exception as e:
print(f"Error in watch_tasks2 thread: {e}")

0 comments on commit 1aa3777

Please sign in to comment.