From 25961c98bc516508864f3ea33da6c184011a39e1 Mon Sep 17 00:00:00 2001 From: Jonathan Miller Date: Tue, 14 Jan 2025 09:42:42 -0500 Subject: [PATCH 1/3] Add contexts for clients, tools, and middleware for API for error reporting --- eve/__init__.py | 33 ++++---- eve/api/api.py | 74 ++++++++++++++---- eve/clients/common.py | 21 +++++ eve/clients/discord/client.py | 22 ++++-- eve/clients/farcaster/client.py | 6 +- eve/clients/telegram/client.py | 48 +++++------- eve/tool.py | 78 ++++++++++--------- eve/tools/comfyui_tool.py | 59 ++++++-------- eve/tools/gcp_tool.py | 4 +- eve/tools/local_tool.py | 3 +- eve/tools/modal_tool.py | 4 +- eve/tools/replicate_tool.py | 134 ++++++++++++++++++-------------- 12 files changed, 284 insertions(+), 202 deletions(-) diff --git a/eve/__init__.py b/eve/__init__.py index 10421a0..eeb88df 100644 --- a/eve/__init__.py +++ b/eve/__init__.py @@ -17,24 +17,23 @@ def setup_sentry(): sentry_dsn = os.getenv("SENTRY_DSN") if not sentry_dsn: + print("No Sentry DSN found, skipping Sentry setup") return + print(f"Setting up sentry for {db}") + # Determine environment sentry_env = "production" if db == "PROD" else "staging" - if db == "PROD": - traces_sample_rate = 0.1 - profiles_sample_rate = 0.05 - - else: - traces_sample_rate = 1.0 - profiles_sample_rate = 1.0 - - if sentry_dsn: - sentry_sdk.init( - dsn=sentry_dsn, - traces_sample_rate=traces_sample_rate, - profiles_sample_rate=profiles_sample_rate, - environment=sentry_env, - ) + + # Set sampling rates + traces_sample_rate = 0.1 if db == "PROD" else 1.0 + profiles_sample_rate = 0.05 if db == "PROD" else 1.0 + + sentry_sdk.init( + dsn=sentry_dsn, + traces_sample_rate=traces_sample_rate, + profiles_sample_rate=profiles_sample_rate, + environment=sentry_env, + ) def load_env(db): @@ -92,6 +91,8 @@ def verify_env(): db = os.getenv("DB", "STAGE").upper() if db not in ["STAGE", "PROD", "WEB3-STAGE", "WEB3-PROD"]: - raise Exception(f"Invalid environment: {db}. Must be STAGE, PROD, WEB3-STAGE, or WEB3-PROD") + raise Exception( + f"Invalid environment: {db}. Must be STAGE, PROD, WEB3-STAGE, or WEB3-PROD" + ) load_env(db) diff --git a/eve/api/api.py b/eve/api/api.py index 2729662..7e74549 100644 --- a/eve/api/api.py +++ b/eve/api/api.py @@ -2,6 +2,7 @@ import os import threading import json +from fastapi.responses import JSONResponse import modal from fastapi import FastAPI, Depends, BackgroundTasks, Request from fastapi.middleware.cors import CORSMiddleware @@ -9,6 +10,8 @@ from apscheduler.schedulers.background import BackgroundScheduler from pathlib import Path from contextlib import asynccontextmanager +from starlette.middleware.base import BaseHTTPMiddleware +import sentry_sdk from eve import auth, db from eve.postprocessing import ( @@ -67,7 +70,34 @@ async def lifespan(app: FastAPI): app.state.scheduler.shutdown(wait=True) +class SentryContextMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + with sentry_sdk.configure_scope() as scope: + scope.set_tag("package", "eve-api") + + # Extract client context from headers + client_platform = request.headers.get("X-Client-Platform") + client_agent = request.headers.get("X-Client-Agent") + + if client_platform: + scope.set_tag("client_platform", client_platform) + if client_agent: + scope.set_tag("client_agent", client_agent) + + scope.set_context( + "api", + { + "endpoint": request.url.path, + "modal_serve": os.getenv("MODAL_SERVE"), + "client_platform": client_platform, + "client_agent": client_agent, + }, + ) + return await call_next(request) + + web_app = FastAPI(lifespan=lifespan) +web_app.add_middleware(SentryContextMiddleware) web_app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -167,6 +197,15 @@ async def trigger_delete( return await handle_trigger_delete(request, scheduler) +@web_app.exception_handler(Exception) +async def catch_all_exception_handler(request, exc): + sentry_sdk.capture_exception(exc) + return JSONResponse( + status_code=500, + content={"message": str(exc)}, + ) + + # Modal app setup app = modal.App( name=app_name, @@ -210,17 +249,24 @@ def fastapi_app(): image=image, concurrency_limit=1, schedule=modal.Period(minutes=15), timeout=3600 ) async def postprocessing(): - try: - await cancel_stuck_tasks() - except Exception as e: - print(f"Error cancelling stuck tasks: {e}") - - try: - await run_nsfw_detection() - except Exception as e: - print(f"Error running nsfw detection: {e}") - - try: - await generate_lora_thumbnails() - except Exception as e: - print(f"Error generating lora thumbnails: {e}") + with sentry_sdk.configure_scope() as scope: + scope.set_tag("component", "postprocessing") + scope.set_context("function", {"name": "postprocessing"}) + + try: + await cancel_stuck_tasks() + except Exception as e: + print(f"Error cancelling stuck tasks: {e}") + sentry_sdk.capture_exception(e) + + try: + await run_nsfw_detection() + except Exception as e: + print(f"Error running nsfw detection: {e}") + sentry_sdk.capture_exception(e) + + try: + await generate_lora_thumbnails() + except Exception as e: + print(f"Error generating lora thumbnails: {e}") + sentry_sdk.capture_exception(e) diff --git a/eve/clients/common.py b/eve/clients/common.py index ac72f32..d90bdcf 100644 --- a/eve/clients/common.py +++ b/eve/clients/common.py @@ -1,7 +1,9 @@ import os import time +import asyncio from eve.models import ClientType +from .. import sentry_sdk db = os.getenv("DB", "STAGE") @@ -53,6 +55,25 @@ day_timestamps = {} +def client_context(client_platform: str): + """Decorator to add client context to Sentry for all async methods""" + + def decorator(cls): + for name, method in cls.__dict__.items(): + if asyncio.iscoroutinefunction(method): + + async def wrapped_method(self, *args, __method=method, **kwargs): + with sentry_sdk.configure_scope() as scope: + scope.set_tag("client_platform", client_platform) + scope.set_tag("client_agent", self.agent.username) + return await __method(self, *args, **kwargs) + + setattr(cls, name, wrapped_method) + return cls + + return decorator + + def user_over_rate_limits(user): user_id = str(user.id) diff --git a/eve/clients/discord/client.py b/eve/clients/discord/client.py index de527bd..9c4cf12 100644 --- a/eve/clients/discord/client.py +++ b/eve/clients/discord/client.py @@ -40,6 +40,7 @@ def replace_mentions_with_usernames( return message_content.strip() +@common.client_context("discord") class Eden2Cog(commands.Cog): def __init__( self, @@ -285,7 +286,11 @@ async def on_message(self, message: discord.Message) -> None: async with session.post( f"{self.api_url}/chat", json=request_data, - headers={"Authorization": f"Bearer {os.getenv('EDEN_ADMIN_KEY')}"}, + headers={ + "Authorization": f"Bearer {os.getenv('EDEN_ADMIN_KEY')}", + "X-Client-Platform": "discord", + "X-Client-Agent": self.agent.username, + }, ) as response: if response.status != 200: error_msg = await response.text() @@ -394,12 +399,15 @@ def start( agent_name = os.getenv("EDEN_AGENT_USERNAME") agent = Agent.load(agent_name) - logger.info(f"Launching Discord bot {agent.username}...") - - bot_token = os.getenv("CLIENT_DISCORD_TOKEN") - bot = DiscordBot() - bot.add_cog(Eden2Cog(bot, agent, local=local)) - bot.run(bot_token) + with sentry_sdk.configure_scope() as scope: + scope.set_tag("client", "discord") + scope.set_context("discord", {"agent": agent_name, "local": local}) + logger.info(f"Launching Discord bot {agent.username}...") + + bot_token = os.getenv("CLIENT_DISCORD_TOKEN") + bot = DiscordBot() + bot.add_cog(Eden2Cog(bot, agent, local=local)) + bot.run(bot_token) except Exception as e: logger.error("Failed to start Discord bot", exc_info=True) sentry_sdk.capture_exception(e) diff --git a/eve/clients/farcaster/client.py b/eve/clients/farcaster/client.py index eadb7ce..14cdb45 100644 --- a/eve/clients/farcaster/client.py +++ b/eve/clients/farcaster/client.py @@ -214,7 +214,11 @@ async def process_webhook( async with session.post( f"{api_url}/chat", json=request_data, - headers={"Authorization": f"Bearer {os.getenv('EDEN_ADMIN_KEY')}"}, + headers={ + "Authorization": f"Bearer {os.getenv('EDEN_ADMIN_KEY')}", + "X-Client-Platform": "farcaster", + "X-Client-Agent": agent.username, + }, ) as response: if response.status != 200: raise Exception("Failed to process request") diff --git a/eve/clients/telegram/client.py b/eve/clients/telegram/client.py index f93b3b1..b7a48b0 100644 --- a/eve/clients/telegram/client.py +++ b/eve/clients/telegram/client.py @@ -16,7 +16,6 @@ ) import asyncio -from ... import load_env from ...clients import common from ...agent import Agent from ...llm import UpdateType @@ -94,10 +93,7 @@ def replace_bot_mentions(message_text: str, bot_username: str, replacement: str) async def send_response( - message_type: str, - chat_id: int, - response: list, - context: ContextTypes.DEFAULT_TYPE + message_type: str, chat_id: int, response: list, context: ContextTypes.DEFAULT_TYPE ): """ Send messages, photos, or videos based on the type of response. @@ -117,13 +113,9 @@ async def send_response( await context.bot.send_message(chat_id=chat_id, text=item) +@common.client_context("telegram") class EdenTG: - def __init__( - self, - token: str, - agent: Agent, - local: bool = False - ): + def __init__(self, token: str, agent: Agent, local: bool = False): self.token = token self.agent = agent self.tools = agent.get_tools() @@ -132,7 +124,7 @@ def __init__( if local: self.api_url = "http://localhost:8000" else: - self.api_url = os.getenv(f"EDEN_API_URL") + self.api_url = os.getenv("EDEN_API_URL") self.channel_name = common.get_ably_channel_name( agent.name, ClientType.TELEGRAM ) @@ -140,7 +132,7 @@ def __init__( # Don't initialize Ably here - we'll do it in setup_ably self.ably_client = None self.channel = None - + self.typing_tasks = {} async def initialize(self, application): @@ -161,8 +153,12 @@ async def _typing_loop(self, chat_id: int, application: Application): """Keep sending typing action until stopped""" try: while True: - await application.bot.send_chat_action(chat_id=chat_id, action=ChatAction.TYPING) - await asyncio.sleep(5) # Telegram typing status expires after ~5 seconds + await application.bot.send_chat_action( + chat_id=chat_id, action=ChatAction.TYPING + ) + await asyncio.sleep( + 5 + ) # Telegram typing status expires after ~5 seconds except asyncio.CancelledError: pass @@ -269,23 +265,18 @@ async def echo(self, update: Update, context: ContextTypes.DEFAULT_TYPE): if is_direct_message: # print author - force_reply = False # No DMs + force_reply = False # No DMs return # Lookup thread thread_key = f"telegram-{chat_id}" if thread_key not in self.known_threads: - self.known_threads[thread_key] = self.agent.request_thread( - key=thread_key - ) + self.known_threads[thread_key] = self.agent.request_thread(key=thread_key) thread = self.known_threads[thread_key] # Lookup user if user_id not in self.known_users: - self.known_users[user_id] = User.from_telegram( - user_id, - username - ) + self.known_users[user_id] = User.from_telegram(user_id, username) user = self.known_users[user_id] # Check if user rate limits @@ -333,7 +324,11 @@ async def echo(self, update: Update, context: ContextTypes.DEFAULT_TYPE): async with session.post( f"{self.api_url}/chat", json=request_data, - headers={"Authorization": f"Bearer {os.getenv('EDEN_ADMIN_KEY')}"}, + headers={ + "Authorization": f"Bearer {os.getenv('EDEN_ADMIN_KEY')}", + "X-Client-Platform": "telegram", + "X-Client-Agent": self.agent.username, + }, ) as response: print(f"Response from {self.api_url}/chat: {response.status}") # json @@ -348,10 +343,7 @@ async def echo(self, update: Update, context: ContextTypes.DEFAULT_TYPE): return -def start( - env: str, - local: bool = False -) -> None: +def start(env: str, local: bool = False) -> None: print("Starting Telegram client...") load_dotenv(env) diff --git a/eve/tool.py b/eve/tool.py index 262c134..6ae3fc7 100644 --- a/eve/tool.py +++ b/eve/tool.py @@ -20,14 +20,7 @@ OUTPUT_TYPES = Literal[ - "boolean", - "string", - "integer", - "float", - "image", - "video", - "audio", - "lora" + "boolean", "string", "integer", "float", "image", "video", "audio", "lora" ] BASE_MODELS = Literal[ @@ -44,17 +37,10 @@ "runway", "mmaudio", "librosa", - "musicgen" + "musicgen", ] -HANDLERS = Literal[ - "local", - "modal", - "comfyui", - "comfyui_legacy", - "replicate", - "gcp" -] +HANDLERS = Literal["local", "modal", "comfyui", "comfyui_legacy", "replicate", "gcp"] @Collection("tools3") @@ -89,7 +75,7 @@ class Tool(Document, ABC): @classmethod def _get_schema(cls, key, from_yaml=False) -> dict: """Get schema for a tool, with detailed performance logging.""" - + if from_yaml: # YAML path api_files = get_api_files() @@ -111,11 +97,7 @@ def _get_schema(cls, key, from_yaml=False) -> dict: return schema @classmethod - def get_sub_class( - cls, - schema, - from_yaml=False - ) -> type: + def get_sub_class(cls, schema, from_yaml=False) -> type: from .tools.local_tool import LocalTool from .tools.modal_tool import ModalTool from .tools.comfyui_tool import ComfyUITool, ComfyUIToolLegacy @@ -133,7 +115,7 @@ def get_sub_class( "local": LocalTool, "modal": ModalTool, "comfyui": ComfyUITool, - "comfyui_legacy": ComfyUIToolLegacy, # private/legacy workflows + "comfyui_legacy": ComfyUIToolLegacy, # private/legacy workflows "replicate": ReplicateTool, "gcp": GCPTool, None: LocalTool, @@ -180,7 +162,7 @@ def convert_from_yaml(cls, schema: dict, file_path: str = None) -> dict: schema["test_args"] = json.load(f) return schema - + @classmethod def convert_from_mongo(cls, schema) -> dict: schema["parameters"] = { @@ -235,7 +217,7 @@ def from_mongo(cls, document_id, cache=False): return _tool_cache[str(document_id)] else: return super().from_mongo(document_id) - + @classmethod def load(cls, key, cache=False): if cache: @@ -268,7 +250,9 @@ def openai_schema(self, exclude_hidden: bool = False) -> dict[str, Any]: schema = openai_schema(self.model).openai_schema if exclude_hidden: self._remove_hidden_fields(schema["parameters"]) - schema["description"] = schema["description"][:1024] # OpenAI tool description limit + schema["description"] = schema["description"][ + :1024 + ] # OpenAI tool description limit return {"type": "function", "function": schema} def calculate_cost(self, args): @@ -282,7 +266,9 @@ def calculate_cost(self, args): r"(\w+)\s*\?\s*([^:]+)\s*:\s*([^,\s]+)", r"\2 if \1 else \3", cost_formula ) # Ternary operator cost_estimate = eval(cost_formula, args.copy()) - assert isinstance(cost_estimate, (int, float)), f"Cost estimate ({cost_estimate}) not a number (formula: {cost_formula})" + assert isinstance( + cost_estimate, (int, float) + ), f"Cost estimate ({cost_estimate}) not a number (formula: {cost_formula})" return cost_estimate def prepare_args(self, args: dict): @@ -354,7 +340,7 @@ async def async_wrapper( try: # validate args and user manna balance args = self.prepare_args(args) - sentry_sdk.add_breadcrumb(category="handle_start_task", data=args) + sentry_sdk.add_breadcrumb(category="handle_start_task", data=args) cost = self.calculate_cost(args) user = User.from_mongo(user_id) if "freeTools" in (user.featureFlags or []): @@ -378,7 +364,9 @@ async def async_wrapper( cost=cost, ) task.save() - sentry_sdk.add_breadcrumb(category="handle_start_task", data=task.model_dump()) + sentry_sdk.add_breadcrumb( + category="handle_start_task", data=task.model_dump() + ) # start task try: @@ -401,7 +389,7 @@ async def async_wrapper( task.update(handler_id=handler_id) task.spend_manna() - + except Exception as e: print(traceback.format_exc()) task.update(status="failed", error=str(e)) @@ -445,6 +433,7 @@ async def async_wrapper(self, task: Task, force: bool = False): task.update(status="failed", error="Timed out") else: task.update(status="cancelled") + return async_wrapper @abstractmethod @@ -479,10 +468,10 @@ def cancel(self, task: Task, force: bool = False): def get_tools_from_api_files( - root_dir: str = None, - tools: List[str] = None, + root_dir: str = None, + tools: List[str] = None, include_inactive: bool = False, - cache: bool = False + cache: bool = False, ) -> Dict[str, Tool]: """Get all tools inside a directory""" @@ -505,7 +494,7 @@ def get_tools_from_mongo( cache: bool = False, ) -> Dict[str, Tool]: """Get all tools from mongo""" - + tools_collection = get_collection(Tool.collection_name) # Batch fetch all tools and their parents @@ -553,5 +542,24 @@ def get_api_files(root_dir: str = None) -> List[str]: return api_files + +def tool_context(tool_type): + def decorator(cls): + for name, method in cls.__dict__.items(): + if asyncio.iscoroutinefunction(method): + + async def wrapped_method(self, *args, __method=method, **kwargs): + with sentry_sdk.configure_scope() as scope: + scope.set_tag("package", "eve-tools") + scope.set_tag("tool_type", tool_type) + scope.set_tag("tool_name", self.key) + return await __method(self, *args, **kwargs) + + setattr(cls, name, wrapped_method) + return cls + + return decorator + + # Tool cache for fetching commonly used tools _tool_cache: Dict[str, Dict[str, Tool]] = {} diff --git a/eve/tools/comfyui_tool.py b/eve/tools/comfyui_tool.py index 98beec6..8c63cfd 100644 --- a/eve/tools/comfyui_tool.py +++ b/eve/tools/comfyui_tool.py @@ -5,16 +5,18 @@ from typing import List, Optional, Dict from ..mongo import get_collection -from ..tool import Tool +from ..tool import Tool, tool_context from ..task import Task +@tool_context("comfyui") class ComfyUIRemap(BaseModel): node_id: int field: str subfield: str map: Dict[str, str] + class ComfyUIInfo(BaseModel): node_id: int field: str @@ -22,6 +24,7 @@ class ComfyUIInfo(BaseModel): preprocessing: Optional[str] = None remap: Optional[List[ComfyUIRemap]] = None + class ComfyUITool(Tool): workspace: str comfyui_output_node_id: int @@ -31,19 +34,20 @@ class ComfyUITool(Tool): @classmethod def convert_from_yaml(cls, schema: dict, file_path: str = None) -> dict: schema["comfyui_map"] = {} - for field, props in schema.get('parameters', {}).items(): - if 'comfyui' in props: - schema["comfyui_map"][field] = props['comfyui'] - schema["workspace"] = schema.get("workspace") or file_path.replace("api.yaml", "test.json").split("/")[-4] + for field, props in schema.get("parameters", {}).items(): + if "comfyui" in props: + schema["comfyui_map"][field] = props["comfyui"] + schema["workspace"] = ( + schema.get("workspace") + or file_path.replace("api.yaml", "test.json").split("/")[-4] + ) return super().convert_from_yaml(schema, file_path) - + @Tool.handle_run async def async_run(self, args: Dict): db = os.getenv("DB") cls = modal.Cls.lookup( - f"comfyui-{self.workspace}-{db}", - "ComfyUI", - environment_name="main" + f"comfyui-{self.workspace}-{db}", "ComfyUI", environment_name="main" ) result = await cls().run.remote.aio(self.parent_tool or self.key, args) return result @@ -52,9 +56,7 @@ async def async_run(self, args: Dict): async def async_start_task(self, task: Task): db = os.getenv("DB") cls = modal.Cls.lookup( - f"comfyui-{self.workspace}-{db}", - "ComfyUI", - environment_name="main" + f"comfyui-{self.workspace}-{db}", "ComfyUI", environment_name="main" ) job = await cls().run_task.spawn.aio(task) return job.object_id @@ -65,7 +67,7 @@ async def async_wait(self, task: Task): await fc.get.aio() task.reload() return task.model_dump(include={"status", "error", "result"}) - + @Tool.handle_cancel async def async_cancel(self, task: Task): fc = modal.functions.FunctionCall.from_id(task.handler_id) @@ -74,20 +76,14 @@ async def async_cancel(self, task: Task): class ComfyUIToolLegacy(ComfyUITool): """For legacy/private workflows""" - + @Tool.handle_run async def async_run(self, args: Dict): db = os.getenv("DB") cls = modal.Cls.lookup( - f"comfyui-{self.key}", - "ComfyUI", - environment_name="main" - ) - result = await cls().run.remote.aio( - workflow_name=self.key, - args=args, - env=db + f"comfyui-{self.key}", "ComfyUI", environment_name="main" ) + result = await cls().run.remote.aio(workflow_name=self.key, args=args, env=db) result = {"output": result} return result @@ -104,14 +100,9 @@ async def async_start_task(self, task: Task): tasks2.insert_one(task_data) cls = modal.Cls.lookup( - f"comfyui-{self.key}", - "ComfyUI", - environment_name="main" - ) - job = await cls().run_task.spawn.aio( - task_id=ObjectId(task_data["_id"]), - env=db + f"comfyui-{self.key}", "ComfyUI", environment_name="main" ) + job = await cls().run_task.spawn.aio(task_id=ObjectId(task_data["_id"]), env=db) return job.object_id @@ -119,13 +110,7 @@ def convert_tasks2_to_tasks3(): """ This is hack to retrofit legacy ComfyUI tasks in tasks2 collection to new tasks3 records """ - pipeline = [ - { - "$match": { - "operationType": {"$in": ["insert", "update", "replace"]} - } - } - ] + pipeline = [{"$match": {"operationType": {"$in": ["insert", "update", "replace"]}}}] try: tasks2 = get_collection("tasks2") with tasks2.watch(pipeline) as stream: @@ -139,7 +124,7 @@ def convert_tasks2_to_tasks3(): task.update( status=update.get("status", task.status), error=update.get("error", task.error), - result=update.get("result", task.result) + result=update.get("result", task.result), ) except Exception as e: print(f"Error in watch_tasks2 thread: {e}") diff --git a/eve/tools/gcp_tool.py b/eve/tools/gcp_tool.py index ad994ba..101ff3a 100644 --- a/eve/tools/gcp_tool.py +++ b/eve/tools/gcp_tool.py @@ -4,10 +4,10 @@ from google.oauth2 import service_account from google.cloud import aiplatform -from ..tool import Tool +from ..tool import Tool, tool_context from ..task import Task - +@tool_context("gcp") class GCPTool(Tool): gcr_image_uri: str machine_type: str diff --git a/eve/tools/local_tool.py b/eve/tools/local_tool.py index 2279fbd..2b227c0 100644 --- a/eve/tools/local_tool.py +++ b/eve/tools/local_tool.py @@ -3,10 +3,11 @@ import asyncio from ..task import Task, task_handler_func -from ..tool import Tool +from ..tool import Tool, tool_context from .tool_handlers import handlers +@tool_context("local") class LocalTool(Tool): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/eve/tools/modal_tool.py b/eve/tools/modal_tool.py index f2f6a7c..cede381 100644 --- a/eve/tools/modal_tool.py +++ b/eve/tools/modal_tool.py @@ -3,9 +3,9 @@ from typing import Dict from ..task import Task -from ..tool import Tool - +from ..tool import Tool, tool_context +@tool_context("modal") class ModalTool(Tool): @Tool.handle_run async def async_run(self, args: Dict): diff --git a/eve/tools/replicate_tool.py b/eve/tools/replicate_tool.py index c3f847f..6b58334 100644 --- a/eve/tools/replicate_tool.py +++ b/eve/tools/replicate_tool.py @@ -11,41 +11,40 @@ from .. import s3 from .. import eden_utils -from ..tool import Tool +from ..tool import Tool, tool_context from ..models import Model from ..task import Task, Creation from ..mongo import get_collection - + +@tool_context("replicate") class ReplicateTool(Tool): replicate_model: str replicate_model_substitutions: Optional[Dict[str, str]] = None version: Optional[str] = Field(None, description="Replicate version to use") output_handler: str = "normal" - + @Tool.handle_run async def async_run(self, args: Dict): check_replicate_api_token() if self.version: args = self._format_args_for_replicate(args) - prediction = await self._create_prediction(args, webhook=False) + prediction = await self._create_prediction(args, webhook=False) prediction.wait() if self.output_handler == "eden": result = {"output": prediction.output[-1]["files"][0]} elif self.output_handler == "trainer": result = { "output": prediction.output[-1]["files"][0], - "thumbnail": prediction.output[-1]["thumbnails"][0] + "thumbnail": prediction.output[-1]["thumbnails"][0], } else: result = {"output": prediction.output} else: replicate_model = self._get_replicate_model(args) args = self._format_args_for_replicate(args) - result = { - "output": replicate.run(replicate_model, input=args) - } - + result = {"output": replicate.run(replicate_model, input=args)} + result = eden_utils.upload_result(result) return result @@ -63,13 +62,11 @@ async def async_start_task(self, task: Task, webhook: bool = True): # So we spawn a remote task on Modal which awaits the Replicate task db = os.getenv("DB", "STAGE").upper() func = modal.Function.lookup( - f"remote-replicate-{db}", - "run_task", - environment_name="main" + f"remote-replicate-{db}", "run_task", environment_name="main" ) job = func.spawn(task) return job.object_id - + @Tool.handle_wait async def async_wait(self, task: Task): if self.version is None: @@ -85,10 +82,10 @@ async def async_wait(self, task: Task): status = prediction.status result = replicate_update_task( task, - status, - prediction.error, - prediction.output, - self.output_handler + status, + prediction.error, + prediction.output, + self.output_handler, ) if result["status"] in ["failed", "cancelled", "completed"]: return result @@ -105,15 +102,19 @@ def _format_args_for_replicate(self, args: dict): new_args = {k: v for k, v in new_args.items() if v is not None} for field in self.model.model_fields.keys(): parameter = self.parameters[field] - is_array = parameter.get('type') == 'array' - is_number = parameter.get('type') in ['integer', 'float'] - alias = parameter.get('alias') - lora = parameter.get('type') == 'lora' - + is_array = parameter.get("type") == "array" + is_number = parameter.get("type") in ["integer", "float"] + alias = parameter.get("alias") + lora = parameter.get("type") == "lora" + if field in new_args: if lora: loras = get_collection(Model.collection_name) - lora_doc = loras.find_one({"_id": ObjectId(args[field])}) if args[field] else None + lora_doc = ( + loras.find_one({"_id": ObjectId(args[field])}) + if args[field] + else None + ) if lora_doc: lora_url = s3.get_full_url(lora_doc.get("checkpoint")) lora_name = lora_doc.get("name") @@ -122,7 +123,9 @@ def _format_args_for_replicate(self, args: dict): if "prompt" in new_args: name_pattern = f"(\\b{re.escape(lora_name)}\\b|<{re.escape(lora_name)}>|\\)" pattern = re.compile(name_pattern, re.IGNORECASE) - new_args["prompt"] = pattern.sub(lora_trigger_text, new_args['prompt']) + new_args["prompt"] = pattern.sub( + lora_trigger_text, new_args["prompt"] + ) if is_number: new_args[field] = float(args[field]) elif is_array: @@ -145,16 +148,16 @@ def _get_replicate_model(self, args: dict): async def _create_prediction(self, args: dict, webhook=True): replicate_model = self._get_replicate_model(args) - user, model = replicate_model.split('/', 1) + user, model = replicate_model.split("/", 1) webhook_url = get_webhook_url() if webhook else None webhook_events_filter = ["start", "completed"] if webhook else None - + if self.version == "deployment": deployment = await replicate.deployments.async_get(f"{user}/{model}") prediction = await deployment.predictions.async_create( input=args, webhook=webhook_url, - webhook_events_filter=webhook_events_filter + webhook_events_filter=webhook_events_filter, ) else: model = await replicate.models.async_get(f"{user}/{model}") @@ -163,18 +166,24 @@ async def _create_prediction(self, args: dict, webhook=True): version=version, input=args, webhook=webhook_url, - webhook_events_filter=webhook_events_filter + webhook_events_filter=webhook_events_filter, ) return prediction + def get_webhook_url(): env = { - "PROD": "api-prod", - "STAGE": "api-stage", - "WEB3-PROD": "api-web3-prod", - "WEB3-STAGE": "api-web3-stage" + "PROD": "api-prod", + "STAGE": "api-stage", + "WEB3-PROD": "api-web3-prod", + "WEB3-STAGE": "api-web3-stage", }.get(os.getenv("DB"), "api-web3-stage") - dev = "-dev" if os.getenv("DB") in ["WEB3-STAGE", "STAGE"] and os.getenv("MODAL_SERVE") == "1" else "" + dev = ( + "-dev" + if os.getenv("DB") in ["WEB3-STAGE", "STAGE"] + and os.getenv("MODAL_SERVE") == "1" + else "" + ) webhook_url = f"https://edenartlab--{env}-fastapi-app{dev}.modal.run/update" return webhook_url @@ -186,7 +195,7 @@ def replicate_update_task(task: Task, status, error, output, output_handler): if output and isinstance(output[0], replicate.helpers.FileOutput): output_files = [] for out in output: - with tempfile.NamedTemporaryFile(suffix='.webp', delete=False) as temp_file: + with tempfile.NamedTemporaryFile(suffix=".webp", delete=False) as temp_file: temp_file.write(out.read()) output_files.append(temp_file.name) output = output_files @@ -195,12 +204,12 @@ def replicate_update_task(task: Task, status, error, output, output_handler): task.update(status="failed", error=error) task.refund_manna() return {"status": "failed", "error": error} - + elif status == "canceled": task.update(status="cancelled") task.refund_manna() return {"status": "cancelled"} - + elif status == "processing": task.performance["waitTime"] = ( datetime.now(timezone.utc) - task.createdAt.replace(tzinfo=timezone.utc) @@ -208,26 +217,32 @@ def replicate_update_task(task: Task, status, error, output, output_handler): task.status = "running" task.save() return {"status": "running"} - + elif status == "succeeded": if output_handler in ["eden", "trainer"]: thumbnails = output[-1]["thumbnails"] output = output[-1]["files"] - output = eden_utils.upload_result(output, save_thumbnails=True, save_blurhash=True) + output = eden_utils.upload_result( + output, save_thumbnails=True, save_blurhash=True + ) result = [{"output": [out]} for out in output] else: - output = eden_utils.upload_result(output, save_thumbnails=True, save_blurhash=True) + output = eden_utils.upload_result( + output, save_thumbnails=True, save_blurhash=True + ) result = [{"output": [out]} for out in output] - + for r, res in enumerate(result): for o, output in enumerate(res["output"]): if output_handler == "trainer": filename = output["filename"] - thumbnail = eden_utils.upload_media( - thumbnails[0], - save_thumbnails=False, - save_blurhash=False - ) if thumbnails else None + thumbnail = ( + eden_utils.upload_media( + thumbnails[0], save_thumbnails=False, save_blurhash=False + ) + if thumbnails + else None + ) url = s3.get_full_url(filename) checkpoint_filename = url.split("/")[-1] model = Model( @@ -237,20 +252,24 @@ def replicate_update_task(task: Task, status, error, output, output_handler): task=task.id, thumbnail=thumbnail.get("filename"), args=task.args, - checkpoint=checkpoint_filename, + checkpoint=checkpoint_filename, base_model="sdxl", ) - model.save(upsert_filter={"task": ObjectId(task.id)}) # upsert_filter prevents duplicates + model.save( + upsert_filter={"task": ObjectId(task.id)} + ) # upsert_filter prevents duplicates output["model"] = model.id - + # This is a hack to support legacy models for private endpoints. # Change filename to url and copy record to the old models collection if str(task.user) == os.getenv("LEGACY_USER_ID"): model_copy = model.model_dump(by_alias=True) - model_copy["checkpoint"] = s3.get_full_url(model_copy["checkpoint"]) + model_copy["checkpoint"] = s3.get_full_url( + model_copy["checkpoint"] + ) model_copy["slug"] = f"legacy/{str(model_copy['_id'])}" get_collection("models").insert_one(model_copy) - + else: name = task.args.get("prompt") creation = Creation( @@ -258,31 +277,28 @@ def replicate_update_task(task: Task, status, error, output, output_handler): requester=task.requester, task=task.id, tool=task.tool, - filename=output['filename'], + filename=output["filename"], mediaAttributes=output["mediaAttributes"], - name=name + name=name, ) creation.save() result[r]["output"][o]["creation"] = creation.id - + run_time = ( datetime.now(timezone.utc) - task.createdAt.replace(tzinfo=timezone.utc) ).total_seconds() if task.performance.get("waitTime"): run_time -= task.performance["waitTime"] task.performance["runTime"] = run_time - + result = result if isinstance(result, list) else [result] task.status = "completed" task.result = result task.save() - return { - "status": "completed", - "result": result - } + return {"status": "completed", "result": result} def check_replicate_api_token(): if not os.getenv("REPLICATE_API_TOKEN"): - raise Exception("REPLICATE_API_TOKEN is not set") \ No newline at end of file + raise Exception("REPLICATE_API_TOKEN is not set") From 637edd21c7bcdbeeeca0766217f50829b3563192 Mon Sep 17 00:00:00 2001 From: Jonathan Miller Date: Tue, 14 Jan 2025 09:49:10 -0500 Subject: [PATCH 2/3] Add scope to sentry client --- eve/clients/telegram/client.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/eve/clients/telegram/client.py b/eve/clients/telegram/client.py index b7a48b0..de141a2 100644 --- a/eve/clients/telegram/client.py +++ b/eve/clients/telegram/client.py @@ -4,6 +4,7 @@ from ably import AblyRealtime import aiohttp from dotenv import load_dotenv +import sentry_sdk from telegram import Update from telegram.constants import ChatAction from telegram.ext import ( @@ -350,6 +351,10 @@ def start(env: str, local: bool = False) -> None: agent_name = os.getenv("EDEN_AGENT_USERNAME") agent = Agent.load(agent_name) + with sentry_sdk.configure_scope() as scope: + scope.set_tag("client_platform", "telegram") + scope.set_tag("client_agent", agent_name) + bot_token = os.getenv("CLIENT_TELEGRAM_TOKEN") if not bot_token: raise ValueError("CLIENT_TELEGRAM_TOKEN not found in environment variables") From 09f999ace24c500bdffb3642e45d2f6c150c5562 Mon Sep 17 00:00:00 2001 From: Jonathan Miller Date: Tue, 14 Jan 2025 09:50:32 -0500 Subject: [PATCH 3/3] more sentry context client init --- eve/clients/discord/client.py | 4 +++- eve/clients/telegram/client.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/eve/clients/discord/client.py b/eve/clients/discord/client.py index 9c4cf12..66f8284 100644 --- a/eve/clients/discord/client.py +++ b/eve/clients/discord/client.py @@ -400,7 +400,9 @@ def start( agent_name = os.getenv("EDEN_AGENT_USERNAME") agent = Agent.load(agent_name) with sentry_sdk.configure_scope() as scope: - scope.set_tag("client", "discord") + scope.set_tag("package", "eve-clients") + scope.set_tag("client_platform", "discord") + scope.set_tag("client_agent", agent_name) scope.set_context("discord", {"agent": agent_name, "local": local}) logger.info(f"Launching Discord bot {agent.username}...") diff --git a/eve/clients/telegram/client.py b/eve/clients/telegram/client.py index de141a2..ab47028 100644 --- a/eve/clients/telegram/client.py +++ b/eve/clients/telegram/client.py @@ -352,8 +352,10 @@ def start(env: str, local: bool = False) -> None: agent = Agent.load(agent_name) with sentry_sdk.configure_scope() as scope: + scope.set_tag("package", "eve-clients") scope.set_tag("client_platform", "telegram") scope.set_tag("client_agent", agent_name) + scope.set_context("telegram", {"agent": agent_name, "local": local}) bot_token = os.getenv("CLIENT_TELEGRAM_TOKEN") if not bot_token: