Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jmill/work #55

Merged
merged 4 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions eve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,29 @@
)


def setup_sentry():
sentry_dsn = os.getenv("SENTRY_DSN")
if not sentry_dsn:
return

sentry_env = "production" if db == "PROD" else "staging"
if db == "PROD":
traces_sample_rate = 0.1
profiles_sample_rate = 0.05

else:
traces_sample_rate = 1.0
profiles_sample_rate = 1.0

if sentry_dsn:
sentry_sdk.init(
dsn=sentry_dsn,
traces_sample_rate=traces_sample_rate,
profiles_sample_rate=profiles_sample_rate,
environment=sentry_env,
)


def load_env(db):
global EDEN_API_KEY

Expand All @@ -40,15 +63,7 @@ def load_env(db):
load_dotenv(env_file, override=True)

# start sentry
sentry_dsn = os.getenv("SENTRY_DSN")
sentry_env = "production" if db == "PROD" else "staging"
if sentry_dsn:
sentry_sdk.init(
dsn=sentry_dsn,
traces_sample_rate=1.0,
profiles_sample_rate=1.0,
environment=sentry_env,
)
setup_sentry()

# load api keys
EDEN_API_KEY = SecretStr(os.getenv("EDEN_API_KEY", ""))
Expand Down
51 changes: 14 additions & 37 deletions eve/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,20 @@
import os
import threading
import json
import asyncio
import modal
from fastapi import FastAPI, Depends, BackgroundTasks, Request
from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader, HTTPBearer
from ably import AblyRealtime
from apscheduler.schedulers.background import BackgroundScheduler
from pathlib import Path
from contextlib import asynccontextmanager

from eve import auth
from eve.api.helpers import load_existing_triggers
from eve.postprocessing import (
generate_lora_thumbnails,
cancel_stuck_tasks,
download_nsfw_models,
run_nsfw_detection
run_nsfw_detection,
)
from eve.api.handlers import (
handle_create,
Expand Down Expand Up @@ -63,28 +59,17 @@
# FastAPI setup
@asynccontextmanager
async def lifespan(app: FastAPI):
from eve.api.handlers import handle_chat

# Startup
watch_thread = threading.Thread(target=convert_tasks2_to_tasks3, daemon=True)
watch_thread.start()
app.state.watch_thread = watch_thread

app.state.ably_client = AblyRealtime(
key=os.getenv("ABLY_PUBLISHER_KEY"),
)

# Load existing triggers
await load_existing_triggers(scheduler, app.state.ably_client, handle_chat)

yield
# Shutdown
if hasattr(app.state, "watch_thread"):
app.state.watch_thread.join(timeout=5)
if hasattr(app.state, "scheduler"):
app.state.scheduler.shutdown(wait=True)
if hasattr(app.state, "ably_client"):
await app.state.ably_client.close()
try:
yield
finally:
if hasattr(app.state, "watch_thread"):
app.state.watch_thread.join(timeout=5)
if hasattr(app.state, "scheduler"):
app.state.scheduler.shutdown(wait=True)


web_app = FastAPI(lifespan=lifespan)
Expand Down Expand Up @@ -118,7 +103,7 @@ async def replicate_webhook(request: Request):
# Get raw body for signature verification
body = await request.body()
print(body)

# Parse JSON body
try:
data = json.loads(body)
Expand Down Expand Up @@ -147,7 +132,7 @@ async def chat(
background_tasks: BackgroundTasks,
_: dict = Depends(auth.authenticate_admin),
):
return await handle_chat(request, background_tasks, web_app.state.ably_client)
return await handle_chat(request, background_tasks)


@web_app.post("/chat/stream")
Expand Down Expand Up @@ -177,7 +162,7 @@ async def deployment_delete(
async def trigger_create(
request: CreateTriggerRequest, _: dict = Depends(auth.authenticate_admin)
):
return await handle_trigger_create(request, scheduler, web_app.state.ably_client)
return await handle_trigger_create(request, scheduler)


@web_app.post("/triggers/delete")
Expand All @@ -204,18 +189,13 @@ async def trigger_delete(
.env({"DB": db, "MODAL_SERVE": os.getenv("MODAL_SERVE")})
.apt_install("git", "libmagic1", "ffmpeg", "wget")
.pip_install_from_pyproject(str(root_dir / "pyproject.toml"))
.pip_install(
"numpy<2.0",
"torch==2.0.1",
"torchvision",
"transformers",
"Pillow"
)
.pip_install("numpy<2.0", "torch==2.0.1", "torchvision", "transformers", "Pillow")
.run_commands(["playwright install"])
.run_function(download_nsfw_models)
.copy_local_dir(str(workflows_dir), "/workflows")
)


@app.function(
image=image,
keep_warm=1,
Expand All @@ -232,10 +212,7 @@ def fastapi_app():


@app.function(
image=image,
concurrency_limit=1,
schedule=modal.Period(minutes=15),
timeout=3600
image=image, concurrency_limit=1, schedule=modal.Period(minutes=15), timeout=3600
)
async def postprocessing():
try:
Expand Down
25 changes: 6 additions & 19 deletions eve/api/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import time
from bson import ObjectId
from fastapi import BackgroundTasks, Request
from fastapi import BackgroundTasks
from fastapi.responses import StreamingResponse
from ably import AblyRealtime
from apscheduler.schedulers.background import BackgroundScheduler
Expand All @@ -20,9 +20,9 @@
)
from eve.api.helpers import (
emit_update,
get_update_channel,
serialize_for_json,
setup_chat,
AblyConnectionPool,
)
from eve.deploy import (
create_modal_secrets,
Expand Down Expand Up @@ -70,25 +70,15 @@ async def handle_replicate_webhook(body: dict):
task = Task.from_handler_id(body["id"])
tool = Tool.load(task.tool)
_ = replicate_update_task(
task,
body["status"],
body["error"],
body["output"],
tool.output_handler
task, body["status"], body["error"], body["output"], tool.output_handler
)


async def handle_chat(
request: ChatRequest,
background_tasks: BackgroundTasks,
ably_client: AblyRealtime
request: ChatRequest,
background_tasks: BackgroundTasks,
):
user, agent, thread, tools = await setup_chat(request, background_tasks)
update_channel = (
await get_update_channel(request.update_config, ably_client)
if request.update_config and request.update_config.sub_channel_name
else None
)

async def run_prompt():
try:
Expand Down Expand Up @@ -117,12 +107,11 @@ async def run_prompt():
elif update.type == UpdateType.ERROR:
data["error"] = update.error if hasattr(update, "error") else None

await emit_update(request.update_config, update_channel, data)
await emit_update(request.update_config, data)
except Exception as e:
logger.error("Error in run_prompt", exc_info=True)
await emit_update(
request.update_config,
update_channel,
{"type": "error", "error": str(e)},
)

Expand Down Expand Up @@ -202,7 +191,6 @@ async def handle_deployment_delete(request: DeleteDeploymentRequest):
async def handle_trigger_create(
request: CreateTriggerRequest,
scheduler: BackgroundScheduler,
ably_client: AblyRealtime,
):
from eve.trigger import create_chat_trigger

Expand All @@ -215,7 +203,6 @@ async def handle_trigger_create(
schedule=request.schedule.to_cron_dict(),
update_config=request.update_config,
scheduler=scheduler,
ably_client=ably_client,
trigger_id=trigger_id,
handle_chat_fn=handle_chat,
)
Expand Down
53 changes: 22 additions & 31 deletions eve/api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import aiohttp
from bson import ObjectId
from fastapi import BackgroundTasks
from ably import AblyRealtime
from ably import AblyRest
from apscheduler.schedulers.background import BackgroundScheduler
import traceback
import asyncio
from contextlib import asynccontextmanager

from eve.tool import Tool
from eve.user import User
Expand All @@ -21,8 +23,8 @@


async def get_update_channel(
update_config: UpdateConfig, ably_client: AblyRealtime
) -> Optional[AblyRealtime]:
update_config: UpdateConfig, ably_client: AblyRest
) -> Optional[AblyRest]:
return ably_client.channels.get(str(update_config.sub_channel_name))


Expand Down Expand Up @@ -53,20 +55,19 @@ def serialize_for_json(obj):
return obj


async def emit_update(
update_config: UpdateConfig,
update_channel: AblyRealtime,
data: dict
):
if update_config:
if update_config.update_endpoint and update_config.sub_channel_name:
raise ValueError(
"update_endpoint and sub_channel_name cannot be used together"
)
elif update_config.update_endpoint:
await emit_http_update(update_config, data)
elif update_config.sub_channel_name:
await emit_channel_update(update_channel, data)
async def emit_update(update_config: UpdateConfig, data: dict):
if not update_config:
return

if update_config.update_endpoint:
await emit_http_update(update_config, data)
elif update_config.sub_channel_name:
try:
client = AblyRest(os.getenv("ABLY_PUBLISHER_KEY"))
channel = client.channels.get(update_config.sub_channel_name)
await channel.publish(update_config.sub_channel_name, data)
except Exception as e:
logger.error(f"Failed to publish to Ably: {str(e)}")


async def emit_http_update(update_config: UpdateConfig, data: dict):
Expand All @@ -85,20 +86,8 @@ async def emit_http_update(update_config: UpdateConfig, data: dict):
logger.error(f"Error sending update to endpoint: {str(e)}")


async def emit_channel_update(
update_channel: AblyRealtime,
data: dict
):
try:
await update_channel.publish("update", data)
except Exception as e:
logger.error(f"Failed to publish to Ably: {str(e)}")


async def load_existing_triggers(
scheduler: BackgroundScheduler,
ably_client: AblyRealtime,
handle_chat_fn
scheduler: BackgroundScheduler, ably_client: AblyRest, handle_chat_fn
):
"""Load all existing triggers from the database and add them to the scheduler"""
from ..trigger import create_chat_trigger
Expand All @@ -116,7 +105,9 @@ async def load_existing_triggers(
agent_id=str(trigger.agent),
message=trigger.message,
schedule=trigger.schedule,
update_config=UpdateConfig(**trigger.update_config) if trigger.update_config else None,
update_config=UpdateConfig(**trigger.update_config)
if trigger.update_config
else None,
scheduler=scheduler,
ably_client=ably_client,
trigger_id=trigger.trigger_id,
Expand Down
Loading