diff --git a/eve/__init__.py b/eve/__init__.py index bd17463..cbde5e4 100644 --- a/eve/__init__.py +++ b/eve/__init__.py @@ -14,6 +14,29 @@ ) +def setup_sentry(): + sentry_dsn = os.getenv("SENTRY_DSN") + if not sentry_dsn: + return + + sentry_env = "production" if db == "PROD" else "staging" + if db == "PROD": + traces_sample_rate = 0.1 + profiles_sample_rate = 0.05 + + else: + traces_sample_rate = 1.0 + profiles_sample_rate = 1.0 + + if sentry_dsn: + sentry_sdk.init( + dsn=sentry_dsn, + traces_sample_rate=traces_sample_rate, + profiles_sample_rate=profiles_sample_rate, + environment=sentry_env, + ) + + def load_env(db): global EDEN_API_KEY @@ -40,15 +63,7 @@ def load_env(db): load_dotenv(env_file, override=True) # start sentry - sentry_dsn = os.getenv("SENTRY_DSN") - sentry_env = "production" if db == "PROD" else "staging" - if sentry_dsn: - sentry_sdk.init( - dsn=sentry_dsn, - traces_sample_rate=1.0, - profiles_sample_rate=1.0, - environment=sentry_env, - ) + setup_sentry() # load api keys EDEN_API_KEY = SecretStr(os.getenv("EDEN_API_KEY", "")) diff --git a/eve/api/api.py b/eve/api/api.py index ecb0abb..54778f2 100644 --- a/eve/api/api.py +++ b/eve/api/api.py @@ -2,24 +2,20 @@ import os import threading import json -import asyncio import modal from fastapi import FastAPI, Depends, BackgroundTasks, Request -from fastapi import Request from fastapi.middleware.cors import CORSMiddleware from fastapi.security import APIKeyHeader, HTTPBearer -from ably import AblyRealtime from apscheduler.schedulers.background import BackgroundScheduler from pathlib import Path from contextlib import asynccontextmanager from eve import auth -from eve.api.helpers import load_existing_triggers from eve.postprocessing import ( generate_lora_thumbnails, cancel_stuck_tasks, download_nsfw_models, - run_nsfw_detection + run_nsfw_detection, ) from eve.api.handlers import ( handle_create, @@ -63,28 +59,17 @@ # FastAPI setup @asynccontextmanager async def lifespan(app: FastAPI): - from eve.api.handlers import handle_chat - - # Startup watch_thread = threading.Thread(target=convert_tasks2_to_tasks3, daemon=True) watch_thread.start() app.state.watch_thread = watch_thread - app.state.ably_client = AblyRealtime( - key=os.getenv("ABLY_PUBLISHER_KEY"), - ) - - # Load existing triggers - await load_existing_triggers(scheduler, app.state.ably_client, handle_chat) - - yield - # Shutdown - if hasattr(app.state, "watch_thread"): - app.state.watch_thread.join(timeout=5) - if hasattr(app.state, "scheduler"): - app.state.scheduler.shutdown(wait=True) - if hasattr(app.state, "ably_client"): - await app.state.ably_client.close() + try: + yield + finally: + if hasattr(app.state, "watch_thread"): + app.state.watch_thread.join(timeout=5) + if hasattr(app.state, "scheduler"): + app.state.scheduler.shutdown(wait=True) web_app = FastAPI(lifespan=lifespan) @@ -118,7 +103,7 @@ async def replicate_webhook(request: Request): # Get raw body for signature verification body = await request.body() print(body) - + # Parse JSON body try: data = json.loads(body) @@ -147,7 +132,7 @@ async def chat( background_tasks: BackgroundTasks, _: dict = Depends(auth.authenticate_admin), ): - return await handle_chat(request, background_tasks, web_app.state.ably_client) + return await handle_chat(request, background_tasks) @web_app.post("/chat/stream") @@ -177,7 +162,7 @@ async def deployment_delete( async def trigger_create( request: CreateTriggerRequest, _: dict = Depends(auth.authenticate_admin) ): - return await handle_trigger_create(request, scheduler, web_app.state.ably_client) + return await handle_trigger_create(request, scheduler) @web_app.post("/triggers/delete") @@ -204,18 +189,13 @@ async def trigger_delete( .env({"DB": db, "MODAL_SERVE": os.getenv("MODAL_SERVE")}) .apt_install("git", "libmagic1", "ffmpeg", "wget") .pip_install_from_pyproject(str(root_dir / "pyproject.toml")) - .pip_install( - "numpy<2.0", - "torch==2.0.1", - "torchvision", - "transformers", - "Pillow" - ) + .pip_install("numpy<2.0", "torch==2.0.1", "torchvision", "transformers", "Pillow") .run_commands(["playwright install"]) .run_function(download_nsfw_models) .copy_local_dir(str(workflows_dir), "/workflows") ) + @app.function( image=image, keep_warm=1, @@ -232,10 +212,7 @@ def fastapi_app(): @app.function( - image=image, - concurrency_limit=1, - schedule=modal.Period(minutes=15), - timeout=3600 + image=image, concurrency_limit=1, schedule=modal.Period(minutes=15), timeout=3600 ) async def postprocessing(): try: diff --git a/eve/api/handlers.py b/eve/api/handlers.py index eb15c95..faab0bf 100644 --- a/eve/api/handlers.py +++ b/eve/api/handlers.py @@ -3,7 +3,7 @@ import os import time from bson import ObjectId -from fastapi import BackgroundTasks, Request +from fastapi import BackgroundTasks from fastapi.responses import StreamingResponse from ably import AblyRealtime from apscheduler.schedulers.background import BackgroundScheduler @@ -20,9 +20,9 @@ ) from eve.api.helpers import ( emit_update, - get_update_channel, serialize_for_json, setup_chat, + AblyConnectionPool, ) from eve.deploy import ( create_modal_secrets, @@ -70,25 +70,15 @@ async def handle_replicate_webhook(body: dict): task = Task.from_handler_id(body["id"]) tool = Tool.load(task.tool) _ = replicate_update_task( - task, - body["status"], - body["error"], - body["output"], - tool.output_handler + task, body["status"], body["error"], body["output"], tool.output_handler ) async def handle_chat( - request: ChatRequest, - background_tasks: BackgroundTasks, - ably_client: AblyRealtime + request: ChatRequest, + background_tasks: BackgroundTasks, ): user, agent, thread, tools = await setup_chat(request, background_tasks) - update_channel = ( - await get_update_channel(request.update_config, ably_client) - if request.update_config and request.update_config.sub_channel_name - else None - ) async def run_prompt(): try: @@ -117,12 +107,11 @@ async def run_prompt(): elif update.type == UpdateType.ERROR: data["error"] = update.error if hasattr(update, "error") else None - await emit_update(request.update_config, update_channel, data) + await emit_update(request.update_config, data) except Exception as e: logger.error("Error in run_prompt", exc_info=True) await emit_update( request.update_config, - update_channel, {"type": "error", "error": str(e)}, ) @@ -202,7 +191,6 @@ async def handle_deployment_delete(request: DeleteDeploymentRequest): async def handle_trigger_create( request: CreateTriggerRequest, scheduler: BackgroundScheduler, - ably_client: AblyRealtime, ): from eve.trigger import create_chat_trigger @@ -215,7 +203,6 @@ async def handle_trigger_create( schedule=request.schedule.to_cron_dict(), update_config=request.update_config, scheduler=scheduler, - ably_client=ably_client, trigger_id=trigger_id, handle_chat_fn=handle_chat, ) diff --git a/eve/api/helpers.py b/eve/api/helpers.py index 794a05c..69fc293 100644 --- a/eve/api/helpers.py +++ b/eve/api/helpers.py @@ -4,9 +4,11 @@ import aiohttp from bson import ObjectId from fastapi import BackgroundTasks -from ably import AblyRealtime +from ably import AblyRest from apscheduler.schedulers.background import BackgroundScheduler import traceback +import asyncio +from contextlib import asynccontextmanager from eve.tool import Tool from eve.user import User @@ -21,8 +23,8 @@ async def get_update_channel( - update_config: UpdateConfig, ably_client: AblyRealtime -) -> Optional[AblyRealtime]: + update_config: UpdateConfig, ably_client: AblyRest +) -> Optional[AblyRest]: return ably_client.channels.get(str(update_config.sub_channel_name)) @@ -53,20 +55,19 @@ def serialize_for_json(obj): return obj -async def emit_update( - update_config: UpdateConfig, - update_channel: AblyRealtime, - data: dict -): - if update_config: - if update_config.update_endpoint and update_config.sub_channel_name: - raise ValueError( - "update_endpoint and sub_channel_name cannot be used together" - ) - elif update_config.update_endpoint: - await emit_http_update(update_config, data) - elif update_config.sub_channel_name: - await emit_channel_update(update_channel, data) +async def emit_update(update_config: UpdateConfig, data: dict): + if not update_config: + return + + if update_config.update_endpoint: + await emit_http_update(update_config, data) + elif update_config.sub_channel_name: + try: + client = AblyRest(os.getenv("ABLY_PUBLISHER_KEY")) + channel = client.channels.get(update_config.sub_channel_name) + await channel.publish(update_config.sub_channel_name, data) + except Exception as e: + logger.error(f"Failed to publish to Ably: {str(e)}") async def emit_http_update(update_config: UpdateConfig, data: dict): @@ -85,20 +86,8 @@ async def emit_http_update(update_config: UpdateConfig, data: dict): logger.error(f"Error sending update to endpoint: {str(e)}") -async def emit_channel_update( - update_channel: AblyRealtime, - data: dict -): - try: - await update_channel.publish("update", data) - except Exception as e: - logger.error(f"Failed to publish to Ably: {str(e)}") - - async def load_existing_triggers( - scheduler: BackgroundScheduler, - ably_client: AblyRealtime, - handle_chat_fn + scheduler: BackgroundScheduler, ably_client: AblyRest, handle_chat_fn ): """Load all existing triggers from the database and add them to the scheduler""" from ..trigger import create_chat_trigger @@ -116,7 +105,9 @@ async def load_existing_triggers( agent_id=str(trigger.agent), message=trigger.message, schedule=trigger.schedule, - update_config=UpdateConfig(**trigger.update_config) if trigger.update_config else None, + update_config=UpdateConfig(**trigger.update_config) + if trigger.update_config + else None, scheduler=scheduler, ably_client=ably_client, trigger_id=trigger.trigger_id,