diff --git a/comfyui.py b/comfyui.py index e947a13..425fb55 100644 --- a/comfyui.py +++ b/comfyui.py @@ -65,6 +65,14 @@ test_all = True if os.getenv("TEST_ALL") else False skip_tests = os.getenv("SKIP_TESTS") +print("========================================") +print(f"db: {db}") +print(f"workspace: {workspace_name}") +print(f"test_workflows: {test_workflows}") +print(f"test_all: {test_all}") +print(f"skip_tests: {skip_tests}") +print("========================================") + def install_comfyui(): snapshot = json.load(open("/root/workspace/snapshot.json", 'r')) comfyui_commit_sha = snapshot["comfyui"] @@ -218,6 +226,7 @@ def download_files(force_redownload=False): .env({"TEST_ALL": os.getenv("TEST_ALL")}) .apt_install("git", "git-lfs", "libgl1-mesa-glx", "libglib2.0-0", "libmagic1", "ffmpeg", "libegl1") .pip_install_from_pyproject(str(root_dir / "pyproject.toml")) + .pip_install("diffusers==0.31.0") .env({"WORKSPACE": workspace_name}) .copy_local_file(f"{root_workflows_folder}/workspaces/{workspace_name}/snapshot.json", "/root/workspace/snapshot.json") .copy_local_file(f"{root_workflows_folder}/workspaces/{workspace_name}/downloads.json", "/root/workspace/downloads.json") @@ -328,6 +337,9 @@ def test_workflows(self): if not all([w in workflow_names for w in test_workflows]): raise Exception(f"One or more invalid workflows found: {', '.join(test_workflows)}") workflow_names = test_workflows + print(f"====> Running tests for subset of workflows: {' | '.join(workflow_names)}") + else: + print(f"====> Running tests for all workflows: {' | '.join(workflow_names)}") if not workflow_names: raise Exception("No workflows found!") @@ -338,7 +350,7 @@ def test_workflows(self): tests = glob.glob(f"/root/workspace/workflows/{workflow}/test*.json") else: tests = [f"/root/workspace/workflows/{workflow}/test.json"] - print("Running tests: ", tests) + print(f"====> Running tests for {workflow}: ", tests) for test in tests: tool = Tool.from_yaml(f"/root/workspace/workflows/{workflow}/api.yaml") if tool.status == "inactive": @@ -347,10 +359,12 @@ def test_workflows(self): test_args = json.loads(open(test, "r").read()) test_args = tool.prepare_args(test_args) test_name = f"{workflow}_{os.path.basename(test)}" - print(f"Running test: {test_name}") + print(f"====> Running test: {test_name}") t1 = time.time() result = self._execute(workflow, test_args) result = eden_utils.upload_result(result) + result = eden_utils.prepare_result(result) + print(f"====> Final media url: {result}") t2 = time.time() results[test_name] = result results["_performance"][test_name] = t2 - t1 @@ -460,20 +474,22 @@ def _inject_embedding_mentions_sdxl(self, text, embedding_trigger, embeddings_fi lora_prompt = f"{reference}, {lora_prompt}" return user_prompt, lora_prompt - + def _inject_embedding_mentions_flux(self, text, embedding_trigger, lora_trigger_text): - pattern = r'(<{0}>|<{1}>|{0}|{1})'.format( - re.escape(embedding_trigger), - re.escape(embedding_trigger.lower()) - ) - text = re.sub(pattern, lora_trigger_text, text, flags=re.IGNORECASE) - text = re.sub(r'()', lora_trigger_text, text, flags=re.IGNORECASE) - - if lora_trigger_text not in text: # Make sure the concept is always triggered: + if not embedding_trigger: # Handles both None and empty string + text = re.sub(r'()', lora_trigger_text, text, flags=re.IGNORECASE) + else: + pattern = r'(<{0}>|<{1}>|{0}|{1})'.format( + re.escape(embedding_trigger), + re.escape(embedding_trigger.lower()) + ) + text = re.sub(pattern, lora_trigger_text, text, flags=re.IGNORECASE) + text = re.sub(r'()', lora_trigger_text, text, flags=re.IGNORECASE) + + if lora_trigger_text not in text: text = f"{lora_trigger_text}, {text}" return text - def _transport_lora_flux(self, lora_url: str): loras_folder = "/root/models/loras" @@ -613,7 +629,7 @@ def _validate_comfyui_args(self, workflow, tool): raise Exception(f"Remap parameter {key} is missing original choices: {choices}") def _inject_args_into_workflow(self, workflow, tool, args): - + base_model = "unknown" # Helper function to validate and normalize URLs def validate_url(url): if not isinstance(url, str): @@ -622,16 +638,18 @@ def validate_url(url): url = 'https://' + url return url - pprint(args) + print("===== Injecting comfyui args into workflow =====") + pprint(args) - embedding_trigger = None - lora_trigger_text = None + embedding_triggers = {"lora": None, "lora2": None} + lora_trigger_texts = {"lora": None, "lora2": None} # download and transport files for key, param in tool.model.model_fields.items(): metadata = param.json_schema_extra or {} file_type = metadata.get('file_type') is_array = metadata.get('is_array') + print(f"Parsing {key}, param: {param}") if file_type and any(t in ["image", "video", "audio"] for t in file_type.split("|")): if not args.get(key): @@ -648,22 +666,21 @@ def validate_url(url): elif file_type == "lora": lora_id = args.get(key) - print("LORA ID", lora_id) + if not lora_id: args[key] = None - args["lora_strength"] = 0 - print("REMOVE LORA") + args[f"{key}_strength"] = 0 + print(f"DISABLING {key}") continue - - print("LORA ID", lora_id) - print(type(lora_id)) + + print(f"Found {key} LORA ID: ", lora_id) models = get_collection("models3") lora = models.find_one({"_id": ObjectId(lora_id)}) - print("found lora", lora) + #print("found lora:\n", lora) if not lora: - raise Exception(f"Lora {lora_id} not found") + raise Exception(f"Lora {key} with id: {lora_id} not found!") base_model = lora.get("base_model") lora_url = lora.get("checkpoint") @@ -683,18 +700,13 @@ def validate_url(url): lora_filename, embeddings_filename, embedding_trigger, lora_mode = self._transport_lora_sdxl(lora_url) elif base_model == "flux-dev": lora_filename = self._transport_lora_flux(lora_url) - embedding_trigger = lora.get("args", {}).get("name") - lora_trigger_text = lora.get("lora_trigger_text") + embedding_triggers[key] = lora.get("args", {}).get("name") + try: + lora_trigger_texts[key] = lora.get("lora_trigger_text") + except: # old flux LoRA's: + lora_trigger_texts[key] = lora.get("args", {}).get("caption_prefix") args[key] = lora_filename - args["use_lora"] = True - print("lora filename", lora_filename) - - # inject args - # comfyui_map = { - # param.name: param.comfyui - # for param in tool_.parameters if param.comfyui - # } for key, comfyui in tool.comfyui_map.items(): @@ -706,22 +718,29 @@ def validate_url(url): continue # if there's a lora, replace mentions with embedding name - if key == "prompt" and embedding_trigger: - lora_strength = args.get("lora_strength", 0.5) - if base_model == "flux-dev": - print("INJECTING LORA TRIGGER TEXT", lora_trigger_text) - value = self._inject_embedding_mentions_flux(value, embedding_trigger, lora_trigger_text) - print("INJECTED LORA TRIGGER TEXT", value) + if key == "prompt": + if "flux" in base_model: + for lora_key in ["lora", "lora2"]: + if args.get(f"use_{lora_key}", False): + lora_strength = args.get(f"{lora_key}_strength", 0.7) + value = self._inject_embedding_mentions_flux( + value, + embedding_triggers[lora_key], + lora_trigger_texts[lora_key] + ) + print(f"====> INJECTED {lora_key} TRIGGER TEXT", value) elif base_model == "sdxl": - no_token_prompt, value = self._inject_embedding_mentions_sdxl(value, embedding_trigger, embeddings_filename, lora_mode, lora_strength) - - if "no_token_prompt" in args: - no_token_mapping = next((comfy_param for key, comfy_param in tool.comfyui_map.items() if key == "no_token_prompt"), None) - if no_token_mapping: - print("Updating no_token_prompt for SDXL: ", no_token_prompt) - workflow[str(no_token_mapping.node_id)][no_token_mapping.field][no_token_mapping.subfield] = no_token_prompt - - print("prompt updated:", value) + if embedding_trigger: + lora_strength = args.get("lora_strength", 0.7) + no_token_prompt, value = self._inject_embedding_mentions_sdxl(value, embedding_trigger, embeddings_filename, lora_mode, lora_strength) + + if "no_token_prompt" in args: + no_token_mapping = next((comfy_param for key, comfy_param in tool.comfyui_map.items() if key == "no_token_prompt"), None) + if no_token_mapping: + print("Updating no_token_prompt for SDXL: ", no_token_prompt) + workflow[str(no_token_mapping.node_id)][no_token_mapping.field][no_token_mapping.subfield] = no_token_prompt + + print("====> Final updated prompt for workflow: ", value) if comfyui.preprocessing is not None: if comfyui.preprocessing == "csv": diff --git a/eve/agent.py b/eve/agent.py index 34f283d..f7de278 100644 --- a/eve/agent.py +++ b/eve/agent.py @@ -1,10 +1,11 @@ import os import yaml +import time import json import traceback from pathlib import Path from bson import ObjectId -from typing import Optional, Literal, Any, Dict, List +from typing import Optional, Literal, Any, Dict, List, ClassVar from dotenv import dotenv_values from pydantic import SecretStr, Field from pydantic.json_schema import SkipJsonSchema @@ -15,6 +16,8 @@ from .user import User, Manna from .models import Model +CHECK_INTERVAL = 30 # how often to check cached agents for updates + default_presets_flux = { "flux_dev_lora": {}, "runway": {}, @@ -44,24 +47,23 @@ class Agent(User): name: str description: str instructions: str - # models: Optional[Dict[str, ObjectId]] = None model: Optional[ObjectId] = None test_args: Optional[List[Dict[str, Any]]] = None tools: Optional[Dict[str, Dict]] = None tools_cache: SkipJsonSchema[Optional[Dict[str, Tool]]] = Field(None, exclude=True) - + last_check: ClassVar[Dict[str, float]] = {} # seconds + def __init__(self, **data): if isinstance(data.get('owner'), str): data['owner'] = ObjectId(data['owner']) - # if data.get('models'): - # data['models'] = {k: ObjectId(v) if isinstance(v, str) else v for k, v in data['models'].items()} # Load environment variables into secrets dictionary + db = os.getenv("DB") env_dir = Path(__file__).parent / "agents" - env_vars = dotenv_values(f"{str(env_dir)}/{data['username']}/.env") - data['secrets'] = {key: SecretStr(value) for key, value in env_vars.items()} + env_vars = dotenv_values(f"{str(env_dir)}/{db.lower()}/{data['username']}/.env") + data['secrets'] = {key: SecretStr(value) for key, value in env_vars.items()} super().__init__(**data) - + @classmethod def convert_from_yaml(cls, schema: dict, file_path: str = None) -> dict: """ @@ -109,9 +111,11 @@ def from_yaml(cls, file_path, cache=False): @classmethod def from_mongo(cls, document_id, cache=False): if cache: - if document_id not in _agent_cache: - _agent_cache[str(document_id)] = super().from_mongo(document_id) - return _agent_cache[str(document_id)] + id = str(document_id) + if id not in _agent_cache: + _agent_cache[id] = super().from_mongo(document_id) + cls._check_for_updates(id, document_id) + return _agent_cache[id] else: return super().from_mongo(document_id) @@ -120,6 +124,7 @@ def load(cls, username, cache=False): if cache: if username not in _agent_cache: _agent_cache[username] = super().load(username=username) + cls._check_for_updates(username, _agent_cache[username].id) return _agent_cache[username] else: return super().load(username=username) @@ -184,7 +189,7 @@ def _setup_tools(cls, schema: dict) -> dict: return schema - def get_tools(self,cache=False): + def get_tools(self, cache=False): if not hasattr(self, "tools") or not self.tools: self.tools = {} @@ -204,23 +209,18 @@ def get_tools(self,cache=False): def get_tool(self, tool_name, cache=False): return self.get_tools(cache=cache)[tool_name] + @classmethod + def _check_for_updates(cls, cache_key: str, agent_id: ObjectId): + """Check if agent needs to be updated based on updatedAt field""" + current_time = time.time() + last_check = cls.last_check.get(cache_key, 0) -def get_agents_from_api_files(root_dir: str = None, agents: List[str] = None, include_inactive: bool = False) -> Dict[str, Agent]: - """Get all agents inside a directory""" - - api_files = get_api_files(root_dir, include_inactive) - - all_agents = { - key: Agent.from_yaml(api_file) - for key, api_file in api_files.items() - } - - if agents: - agents = {k: v for k, v in all_agents.items() if k in agents} - else: - agents = all_agents - - return agents + if current_time - last_check >= CHECK_INTERVAL: + cls.last_check[cache_key] = current_time + collection = get_collection(cls.collection_name) + db_agent = collection.find_one({"_id": agent_id}) + if db_agent and db_agent.get("updatedAt") != _agent_cache[cache_key].updatedAt: + _agent_cache[cache_key].reload() def get_agents_from_mongo(agents: List[str] = None, include_inactive: bool = False) -> Dict[str, Agent]: @@ -243,10 +243,11 @@ def get_agents_from_mongo(agents: List[str] = None, include_inactive: bool = Fal return agents -def get_api_files(root_dir: str = None, include_inactive: bool = False) -> List[str]: + +def get_api_files(root_dir: str = None) -> List[str]: """Get all agent directories inside a directory""" - env = os.getenv("DB") + db = os.getenv("DB").lower() if root_dir: root_dirs = [root_dir] @@ -254,22 +255,16 @@ def get_api_files(root_dir: str = None, include_inactive: bool = False) -> List[ eve_root = os.path.dirname(os.path.abspath(__file__)) root_dirs = [ os.path.join(eve_root, agents_dir) - for agents_dir in [f"agents/{env}"] + for agents_dir in [f"agents/{db}"] ] api_files = {} for root_dir in root_dirs: for root, _, files in os.walk(root_dir): if "api.yaml" in files and "test.json" in files: - api_file = os.path.join(root, "api.yaml") - with open(api_file, 'r') as f: - schema = yaml.safe_load(f) - if schema.get("status") == "inactive" and not include_inactive: - continue - key = schema.get("key", os.path.relpath(root).split("/")[-1]) - if key in api_files: - raise ValueError(f"Duplicate agent {key} found.") - api_files[key] = os.path.join(os.path.relpath(root), "api.yaml") + api_path = os.path.join(root, "api.yaml") + key = os.path.relpath(root).split("/")[-1] + api_files[key] = api_path return api_files diff --git a/eve/agents/prod/abraham/api.yaml b/eve/agents/prod/abraham/api.yaml index 828f1db..0aa13de 100644 --- a/eve/agents/prod/abraham/api.yaml +++ b/eve/agents/prod/abraham/api.yaml @@ -45,3 +45,7 @@ clients: enabled: true telegram: enabled: true + +deployments: + - discord + - telegram \ No newline at end of file diff --git a/eve/agents/prod/eve/api.yaml b/eve/agents/prod/eve/api.yaml index 2e3b2cb..8425573 100644 --- a/eve/agents/prod/eve/api.yaml +++ b/eve/agents/prod/eve/api.yaml @@ -44,6 +44,8 @@ tools: lora_trainer: flux_trainer: news: + websearch: + weather: stable_audio: musicgen: audio_split_stems: diff --git a/eve/agents/stage/eve/api.yaml b/eve/agents/stage/eve/api.yaml index 8d2e5b5..2b5f9a7 100644 --- a/eve/agents/stage/eve/api.yaml +++ b/eve/agents/stage/eve/api.yaml @@ -44,6 +44,7 @@ tools: lora_trainer: flux_trainer: news: + websearch: stable_audio: musicgen: audio_split_stems: diff --git a/eve/api.py b/eve/api.py index 909d4c9..a06002b 100644 --- a/eve/api.py +++ b/eve/api.py @@ -31,6 +31,7 @@ from eve.mongo import serialize_document from eve.agent import Agent from eve.user import User +from eve.task import Task from eve import deploy # Config and logging setup @@ -68,6 +69,9 @@ class TaskRequest(BaseModel): args: dict user_id: str +class CancelRequest(BaseModel): + task_id: str + user_id: str class UpdateConfig(BaseModel): sub_channel_name: Optional[str] = None @@ -133,10 +137,32 @@ async def setup_chat( @web_app.post("/create") -async def task_admin(request: TaskRequest, _: dict = Depends(auth.authenticate_admin)): +async def task_admin( + request: TaskRequest, + _: dict = Depends(auth.authenticate_admin) +): result = await handle_task(request.tool, request.user_id, request.args) - return serialize_document(result.model_dump()) - + return serialize_document(result.model_dump(by_alias=True)) + + +async def handle_cancel(task_id: str, user_id: str): + task = Task.from_mongo(task_id) + assert str(task.user) == user_id, "Task user does not match user_id" + if task.status in ["completed", "failed", "cancelled"]: + return {"status": task.status} + tool = Tool.load(key=task.tool) + tool.cancel(task) + return {"status": task.status} + + +@web_app.post("/cancel") +async def cancel( + request: CancelRequest, + _: dict = Depends(auth.authenticate_admin) +): + result = await handle_cancel(request.task_id, request.user_id) + return result + @web_app.post("/chat") async def handle_chat( @@ -273,7 +299,7 @@ async def deploy_handler( if request.credentials: create_modal_secrets( request.credentials, - f"{request.agent_key}-client-secrets", + f"{request.agent_key}-secrets", ) if request.command == DeployCommand.DEPLOY: @@ -283,7 +309,7 @@ async def deploy_handler( "message": f"Deployed {request.platform.value} client", } elif request.command == DeployCommand.STOP: - stop_client(request.agent_key, request.platform.value) + stop_client(request.agent_key, request.platform.value, db) return { "status": "success", "message": f"Stopped {request.platform.value} client", @@ -312,6 +338,7 @@ async def deploy_handler( .env({"DB": db, "MODAL_SERVE": os.getenv("MODAL_SERVE")}) .apt_install("git", "libmagic1", "ffmpeg", "wget") .pip_install_from_pyproject(str(root_dir / "pyproject.toml")) + .run_commands(["playwright install"]) .copy_local_dir(str(workflows_dir), "/workflows") ) diff --git a/eve/cli/agent_cli.py b/eve/cli/agent_cli.py index ef1494c..3e8f987 100644 --- a/eve/cli/agent_cli.py +++ b/eve/cli/agent_cli.py @@ -25,7 +25,7 @@ def update(db: str, names: tuple): load_env(db) - api_files = get_api_files(include_inactive=True) + api_files = get_api_files() agents_order = {agent: index for index, agent in enumerate(api_agents_order)} if names: diff --git a/eve/cli/chat_cli.py b/eve/cli/chat_cli.py index 1b505a7..42b8502 100644 --- a/eve/cli/chat_cli.py +++ b/eve/cli/chat_cli.py @@ -16,15 +16,6 @@ from ..agent import Agent from ..auth import get_my_eden_user -# def preprocess_message(message): -# metadata_pattern = r"\{.*?\}" -# attachments_pattern = r"\[.*?\]" -# attachments_match = re.search(attachments_pattern, message) -# attachments = json.loads(attachments_match.group(0)) if attachments_match else [] -# clean_message = re.sub(metadata_pattern, "", message) -# clean_message = re.sub(attachments_pattern, "", clean_message).strip() -# return clean_message, attachments - async def async_chat(agent_name, new_thread=True, debug=False): if not debug: diff --git a/eve/cli/deploy_cli.py b/eve/cli/deploy_cli.py index d9112ea..bbf5ec7 100644 --- a/eve/cli/deploy_cli.py +++ b/eve/cli/deploy_cli.py @@ -46,7 +46,7 @@ def prepare_client_file(file_path: str, agent_key: str, env: str) -> str: # Replace the static secret name with the dynamic one modified_content = content.replace( 'modal.Secret.from_name("client-secrets")', - f'modal.Secret.from_name("{agent_key}-client-secrets-{env}")', + f'modal.Secret.from_name("{agent_key}-secrets-{env}")', ) # Fix pyproject.toml path to use absolute path @@ -75,7 +75,7 @@ def create_secrets(agent_key: str, secrets_dict: dict, env: str): "modal", "secret", "create", - f"{agent_key}-client-secrets-{env}", + f"{agent_key}-secrets-{env}", ] for key, value in secrets_dict.items(): if value is not None: @@ -92,7 +92,7 @@ def deploy_client(agent_key: str, client_name: str, env: str): try: # Create a temporary modified version of the client file temp_file = prepare_client_file(str(client_path), agent_key, env) - app_name = f"{agent_key}-client-{client_name}-{env}" + app_name = f"{agent_key}-{client_name}-{env}" # Deploy using the temporary file subprocess.run( diff --git a/eve/cli/tool_cli.py b/eve/cli/tool_cli.py index 9ff8aef..e3c6c2f 100644 --- a/eve/cli/tool_cli.py +++ b/eve/cli/tool_cli.py @@ -75,7 +75,7 @@ def update(db: str, names: tuple): load_env(db) - api_files = get_api_files(include_inactive=True) + api_files = get_api_files() tools_order = {t: index for index, t in enumerate(api_tools_order)} if names: diff --git a/eve/clients/discord/client.py b/eve/clients/discord/client.py index 88c0583..7c2cd45 100644 --- a/eve/clients/discord/client.py +++ b/eve/clients/discord/client.py @@ -4,6 +4,7 @@ import aiohttp import argparse import discord +import traceback from discord.ext import commands from dotenv import load_dotenv from ably import AblyRealtime @@ -62,7 +63,7 @@ def __init__( self.channel = None # Track message IDs - self.pending_messages = {} + # self.pending_messages = {} self.typing_tasks = {} # {channel_id: asyncio.Task} async def setup_ably(self): @@ -108,8 +109,9 @@ async def async_callback(message): try: original_message = await channel.fetch_message(int(message_id)) reference = original_message.to_reference() - except: + except Exception as e: print(f"Could not fetch original message {message_id}") + traceback.print_exc() if update_type == UpdateType.START_PROMPT: await self.start_typing(channel) @@ -136,6 +138,7 @@ async def async_callback(message): except Exception as e: print(f"Error processing update: {e}") + traceback.print_exc() await self.channel.subscribe(async_callback) print(f"Subscribed to Ably channel: {self.channel_name}") diff --git a/eve/clients/telegram/client.py b/eve/clients/telegram/client.py index ea63e18..b448fd9 100644 --- a/eve/clients/telegram/client.py +++ b/eve/clients/telegram/client.py @@ -3,8 +3,6 @@ import re from ably import AblyRealtime import aiohttp - -# import logging from dotenv import load_dotenv from telegram import Update from telegram.constants import ChatAction @@ -16,12 +14,14 @@ filters, Application, ) +import asyncio +from ... import load_env from ...clients import common -from ...llm import UpdateType -from ...eden_utils import prepare_result from ...agent import Agent +from ...llm import UpdateType from ...user import User +from ...eden_utils import prepare_result from ...models import ClientType @@ -94,7 +94,10 @@ def replace_bot_mentions(message_text: str, bot_username: str, replacement: str) async def send_response( - message_type: str, chat_id: int, response: list, context: ContextTypes.DEFAULT_TYPE + message_type: str, + chat_id: int, + response: list, + context: ContextTypes.DEFAULT_TYPE ): """ Send messages, photos, or videos based on the type of response. @@ -115,13 +118,21 @@ async def send_response( class EdenTG: - def __init__(self, token: str, agent: Agent, db: str = "STAGE"): + def __init__( + self, + token: str, + agent: Agent, + local: bool = False + ): self.token = token self.agent = agent - self.db = db - self.tools = agent.get_tools() # get_tools_from_mongo(db=self.db) + self.tools = agent.get_tools() self.known_users = {} self.known_threads = {} + if local: + self.api_url = "http://localhost:8000" + else: + self.api_url = os.getenv(f"EDEN_API_URL") self.channel_name = common.get_ably_channel_name( agent.name, ClientType.TELEGRAM ) @@ -129,6 +140,8 @@ def __init__(self, token: str, agent: Agent, db: str = "STAGE"): # Don't initialize Ably here - we'll do it in setup_ably self.ably_client = None self.channel = None + + self.typing_tasks = {} async def initialize(self, application): """Initialize the bot including Ably setup""" @@ -144,6 +157,15 @@ def __del__(self): if hasattr(self, "ably_client") and self.ably_client: self.ably_client.close() + async def _typing_loop(self, chat_id: int, application: Application): + """Keep sending typing action until stopped""" + try: + while True: + await application.bot.send_chat_action(chat_id=chat_id, action=ChatAction.TYPING) + await asyncio.sleep(5) # Telegram typing status expires after ~5 seconds + except asyncio.CancelledError: + pass + async def setup_ably(self, application): """Initialize Ably client and subscribe to updates""" @@ -166,7 +188,11 @@ async def async_callback(message): print(f"Processing update type: {update_type} for chat: {telegram_chat_id}") if update_type == UpdateType.START_PROMPT: - pass + # Start continuous typing + if telegram_chat_id not in self.typing_tasks: + self.typing_tasks[telegram_chat_id] = asyncio.create_task( + self._typing_loop(telegram_chat_id, application) + ) elif update_type == UpdateType.ERROR: error_msg = data.get("error", "Unknown error occurred") @@ -182,10 +208,10 @@ async def async_callback(message): ) elif update_type == UpdateType.TOOL_COMPLETE: + print(f"Tool complete: {data}") result = data.get("result", {}) - result["result"] = prepare_result(result["result"], db=self.db) + result["result"] = prepare_result(result["result"]) url = result["result"][0]["output"][0]["url"] - # Determine if it's a video or image video_extensions = (".mp4", ".avi", ".mov", ".mkv", ".webm") if any(url.lower().endswith(ext) for ext in video_extensions): @@ -197,6 +223,12 @@ async def async_callback(message): chat_id=telegram_chat_id, photo=url ) + elif update_type == UpdateType.END_PROMPT: + # Stop typing + if telegram_chat_id in self.typing_tasks: + self.typing_tasks[telegram_chat_id].cancel() + del self.typing_tasks[telegram_chat_id] + # Subscribe using the async callback await self.channel.subscribe(async_callback) print(f"Subscribed to Ably channel: {self.channel_name}") @@ -205,7 +237,7 @@ async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ Handler for the /start command. """ - await update.message.reply_text("Hello! I am your asynchronous bot.") + await update.message.reply_text(f"Hello! I am {self.agent.name}.") async def echo(self, update: Update, context: ContextTypes.DEFAULT_TYPE): """ @@ -231,19 +263,23 @@ async def echo(self, update: Update, context: ContextTypes.DEFAULT_TYPE): else None ) + force_reply = False + if is_bot_mentioned or is_replied_to_bot or is_direct_message: + force_reply = True + # Lookup thread thread_key = f"telegram-{chat_id}" if thread_key not in self.known_threads: self.known_threads[thread_key] = self.agent.request_thread( - key=thread_key, - db=self.db, + key=thread_key ) thread = self.known_threads[thread_key] # Lookup user if user_id not in self.known_users: self.known_users[user_id] = User.from_telegram( - user_id, username, db=self.db + user_id, + username ) user = self.known_users[user_id] @@ -270,14 +306,13 @@ async def echo(self, update: Update, context: ContextTypes.DEFAULT_TYPE): message_text, me_bot.username, self.agent.name ) - await context.bot.send_chat_action(chat_id=chat_id, action=ChatAction.TYPING) - # Make API request api_url = os.getenv("EDEN_API_URL") request_data = { "user_id": str(user.id), "agent_id": str(self.agent.id), "thread_id": str(thread.id), + "force_reply": force_reply, "user_message": { "content": cleaned_text, "name": username, @@ -309,19 +344,22 @@ async def echo(self, update: Update, context: ContextTypes.DEFAULT_TYPE): return -def start(env: str, db: str = "STAGE") -> None: +def start( + env: str, + local: bool = False +) -> None: print("Starting Telegram client...") load_dotenv(env) agent_name = os.getenv("EDEN_AGENT_USERNAME") - agent = Agent.load(agent_name, db=db) + agent = Agent.load(agent_name) bot_token = os.getenv("CLIENT_TELEGRAM_TOKEN") if not bot_token: raise ValueError("CLIENT_TELEGRAM_TOKEN not found in environment variables") application = ApplicationBuilder().token(bot_token).build() - bot = EdenTG(bot_token, agent, db=db) + bot = EdenTG(bot_token, agent, local=local) # Setup handlers application.add_handler(CommandHandler("start", bot.start)) @@ -341,6 +379,6 @@ async def post_init(application: Application) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser(description="Eden Telegram Bot") parser.add_argument("--env", help="Path to the .env file to load", default=".env") - parser.add_argument("--db", help="Database to use", default="STAGE") + parser.add_argument("--local", help="Run locally", action="store_true") args = parser.parse_args() - start(args.env, args.db) + start(args.env, args.local) diff --git a/eve/deploy.py b/eve/deploy.py index dc27f63..1fef10a 100644 --- a/eve/deploy.py +++ b/eve/deploy.py @@ -91,7 +91,7 @@ def modify_client_file(file_path: str, agent_key: str) -> None: # Replace the static secret name with the dynamic one modified_content = content.replace( 'modal.Secret.from_name("client-secrets")', - f'modal.Secret.from_name("{agent_key}-client-secrets")', + f'modal.Secret.from_name("{agent_key}-secrets")', ) # Fix pyproject.toml path to use absolute path @@ -123,13 +123,13 @@ def deploy_client(agent_key: str, client_name: str): raise Exception(f"Client modal file not found: {client_path}") -def stop_client(agent_key: str, client_name: str): +def stop_client(agent_key: str, client_name: str, db: str): subprocess.run( [ "modal", "app", "stop", - f"{agent_key}-client-{client_name}", + f"{agent_key}-{client_name}-{db}", "-e", DEPLOYMENT_ENV_NAME, ], diff --git a/eve/mongo.py b/eve/mongo.py index 0c15245..236589b 100644 --- a/eve/mongo.py +++ b/eve/mongo.py @@ -151,22 +151,27 @@ def save(self, upsert_filter=None, **kwargs): """ Save the current state of the model to the database. """ + self.updatedAt = datetime.now(timezone.utc) schema = self.model_dump(by_alias=True) self.model_validate(schema) schema = self.convert_to_mongo(schema) schema.update(kwargs) - self.updatedAt = datetime.now(timezone.utc) - + filter = upsert_filter or {"_id": self.id or ObjectId()} collection = self.get_collection() + if self.id or filter: if filter: schema.pop("_id", None) - result = collection.find_one_and_replace( + created_at = schema.pop("createdAt", None) + result = collection.find_one_and_update( filter, - schema, + { + "$set": schema, + "$setOnInsert": {"createdAt": created_at}, + }, upsert=True, - return_document=True, # Returns the document after changes + return_document=True, ) self.id = result["_id"] else: diff --git a/eve/thread.py b/eve/thread.py index 7e9b268..39f096c 100644 --- a/eve/thread.py +++ b/eve/thread.py @@ -162,14 +162,15 @@ def get_result(self, schema, truncate_images=False): if self.status == "completed": result["result"] = prepare_result(self.result) - outputs = [ - o.get("url") + file_outputs = [ + o["url"] for r in result.get("result", []) for o in r.get("output", []) + if isinstance(o, dict) and o.get("url") ] - outputs = [ + file_outputs = [ o - for o in outputs + for o in file_outputs if o and o.endswith((".jpg", ".png", ".webp", ".mp4", ".webm")) ] try: @@ -184,7 +185,7 @@ def get_result(self, schema, truncate_images=False): os.path.join("/tmp/eden_file_cache/", url.split("/")[-1]), overwrite=False, ) - for url in outputs + for url in file_outputs ] if schema == "anthropic": diff --git a/eve/tool.py b/eve/tool.py index cc6f38e..366d466 100644 --- a/eve/tool.py +++ b/eve/tool.py @@ -88,10 +88,10 @@ class Tool(Document, ABC): @classmethod def _get_schema(cls, key, from_yaml=False) -> dict: """Get schema for a tool, with detailed performance logging.""" - + if from_yaml: # YAML path - api_files = get_api_files(include_inactive=True) + api_files = get_api_files() if key not in api_files: raise ValueError(f"Tool {key} not found") @@ -123,7 +123,7 @@ def get_sub_class( parent_tool = schema.get("parent_tool") if parent_tool: - parent_schema = cls._get_schema(parent_tool, from_yaml) + parent_schema = cls._get_schema(parent_tool, from_yaml=from_yaml) handler = parent_schema.get("handler") else: handler = schema.get("handler") @@ -475,13 +475,16 @@ def get_tools_from_api_files( ) -> Dict[str, Tool]: """Get all tools inside a directory""" - api_files = get_api_files(root_dir, include_inactive) + api_files = get_api_files(root_dir) tools = { key: _tool_cache.get(api_file) or Tool.from_yaml(api_file, cache=cache) for key, api_file in api_files.items() if tools is None or key in tools } + if not include_inactive: + tools = {k: v for k, v in tools.items() if v.status != "inactive"} + return tools @@ -518,8 +521,8 @@ def get_tools_from_mongo( return found_tools -def get_api_files(root_dir: str = None, include_inactive: bool = False) -> List[str]: - """Get all tool directories inside a directory""" +def get_api_files(root_dir: str = None) -> List[str]: + """Get all api.yaml files inside a directory""" if root_dir: root_dirs = [root_dir] @@ -535,17 +538,9 @@ def get_api_files(root_dir: str = None, include_inactive: bool = False) -> List[ for root, _, files in os.walk(root_dir): if "api.yaml" in files and "test.json" in files: api_file = os.path.join(root, "api.yaml") - with open(api_file, "r") as f: - schema = yaml.safe_load(f) - if schema.get("status") == "inactive" and not include_inactive: - continue - key = schema.get("key", os.path.relpath(root).split("/")[-1]) - if key in api_files: - raise ValueError(f"Duplicate tool {key} found.") - api_files[key] = os.path.join(os.path.relpath(root), "api.yaml") + api_files[os.path.relpath(root).split("/")[-1]] = api_file return api_files - # Tool cache for fetching commonly used tools _tool_cache: Dict[str, Dict[str, Tool]] = {} diff --git a/eve/tools/comfyui_tool.py b/eve/tools/comfyui_tool.py index dc4a42f..d3433c5 100644 --- a/eve/tools/comfyui_tool.py +++ b/eve/tools/comfyui_tool.py @@ -50,7 +50,6 @@ async def async_run(self, args: Dict): async def async_start_task(self, task: Task): db = os.getenv("DB") cls = modal.Cls.lookup( - # f"comfyui-{self.workspace}-{task.db}", f"comfyui-{self.workspace}-{db}", "ComfyUI", environment_name="main" diff --git a/eve/tools/flux_dev_lora/test.json b/eve/tools/flux_dev_lora/test.json index 922b35a..7eac964 100644 --- a/eve/tools/flux_dev_lora/test.json +++ b/eve/tools/flux_dev_lora/test.json @@ -1,6 +1,6 @@ { - "prompt": "hey Banny, a yellow cartoon bananaman at the podium giving a speech, as president of the United States. In the background are military generals adorned with medals.", - "lora": "67611d1943808b38016c62c3", + "prompt": "Verdelis driving a racecar in the desert with beautiful mountain sunset in the background", + "lora": "6778ac7a54e12f2b03fd7abb", "lora_strength": 1.0, "aspect_ratio": "16:9" } \ No newline at end of file diff --git a/eve/tools/local_tool.py b/eve/tools/local_tool.py index f1e0ddb..2279fbd 100644 --- a/eve/tools/local_tool.py +++ b/eve/tools/local_tool.py @@ -14,6 +14,7 @@ def __init__(self, *args, **kwargs): @Tool.handle_run async def async_run(self, args: Dict): + print("running", self.parent_tool or self.key, args) result = await handlers[self.parent_tool or self.key](args) return result diff --git a/eve/tools/media_utils/audio_video_combine/handler.py b/eve/tools/media_utils/audio_video_combine/handler.py index 1e96028..56b17e8 100644 --- a/eve/tools/media_utils/audio_video_combine/handler.py +++ b/eve/tools/media_utils/audio_video_combine/handler.py @@ -3,7 +3,7 @@ # from ... import eden_utils -async def handler(args: dict, db: str): +async def handler(args: dict): from .... import eden_utils video_url = args.get("video") diff --git a/eve/tools/media_utils/image_concat/handler.py b/eve/tools/media_utils/image_concat/handler.py index 5e5316a..8bbad05 100644 --- a/eve/tools/media_utils/image_concat/handler.py +++ b/eve/tools/media_utils/image_concat/handler.py @@ -2,7 +2,7 @@ # from ... import eden_utils -async def handler(args: dict, db: str): +async def handler(args: dict): from .... import eden_utils image_urls = args.get("images") diff --git a/eve/tools/media_utils/image_crop/handler.py b/eve/tools/media_utils/image_crop/handler.py index 1a684e7..369c9f1 100644 --- a/eve/tools/media_utils/image_crop/handler.py +++ b/eve/tools/media_utils/image_crop/handler.py @@ -3,7 +3,7 @@ # from ... import eden_utils -async def handler(args: dict, db: str): +async def handler(args: dict): from .... import eden_utils image_url = args.get("image") diff --git a/eve/tools/media_utils/time_remapping/handler.py b/eve/tools/media_utils/time_remapping/handler.py index 83fa41a..b9b6636 100644 --- a/eve/tools/media_utils/time_remapping/handler.py +++ b/eve/tools/media_utils/time_remapping/handler.py @@ -46,7 +46,7 @@ def smart_frame_selection(orig_size, target_size): middle_indices = np.linspace(1, orig_size-2, target_size-2).round().astype(int) return np.concatenate([[0], middle_indices, [orig_size-1]]) -async def handler(args: dict, db: str): +async def handler(args: dict): # Get parameters video_url = args["video"] target_fps = args.get("target_fps") diff --git a/eve/tools/media_utils/video_concat/handler.py b/eve/tools/media_utils/video_concat/handler.py index 21d0e4b..43e9c5c 100644 --- a/eve/tools/media_utils/video_concat/handler.py +++ b/eve/tools/media_utils/video_concat/handler.py @@ -6,7 +6,7 @@ # bug: if some videos are silent but others have sound, the concatenated video will have no sound -async def handler(args: dict, db: str): +async def handler(args: dict): from .... import eden_utils video_urls = args.get("videos") diff --git a/eve/tools/modal_tool.py b/eve/tools/modal_tool.py index e847f67..f2f6a7c 100644 --- a/eve/tools/modal_tool.py +++ b/eve/tools/modal_tool.py @@ -1,4 +1,5 @@ import modal +import os from typing import Dict from ..task import Task @@ -8,8 +9,9 @@ class ModalTool(Tool): @Tool.handle_run async def async_run(self, args: Dict): + db = os.getenv("DB", "STAGE").upper() func = modal.Function.lookup( - "modal_tools", + f"modal-tools-{db}", "run", environment_name="main" ) @@ -18,8 +20,9 @@ async def async_run(self, args: Dict): @Tool.handle_start_task async def async_start_task(self, task: Task): + db = os.getenv("DB", "STAGE").upper() func = modal.Function.lookup( - "modal_tools", + f"modal-tools-{db}", "run_task", environment_name="main" ) diff --git a/eve/tools/runway/handler.py b/eve/tools/runway/handler.py index 03290c3..ef075a0 100644 --- a/eve/tools/runway/handler.py +++ b/eve/tools/runway/handler.py @@ -1,6 +1,6 @@ -import time +import asyncio import runwayml -from runwayml import RunwayML +from runwayml import AsyncRunwayML """ Todo: @@ -10,15 +10,12 @@ async def handler(args: dict): - client = RunwayML() - - - + client = AsyncRunwayML() try: ratio = "1280:768" if args["ratio"] == "16:9" else "768:1280" - task = client.image_to_video.create( + task = await client.image_to_video.create( model='gen3a_turbo', prompt_image=args["prompt_image"], prompt_text=args["prompt_text"][:512], @@ -51,12 +48,14 @@ async def handler(args: dict): task_id = task.id print(task_id) - time.sleep(10) - task = client.tasks.retrieve(task_id) + # time.sleep(10) + await asyncio.sleep(10) + task = await client.tasks.retrieve(task_id) while task.status not in ['SUCCEEDED', 'FAILED']: print("status", task.status) - time.sleep(10) - task = client.tasks.retrieve(task_id) + # time.sleep(10) + await asyncio.sleep(10) + task = await client.tasks.retrieve(task_id) # TODO: callback for running state diff --git a/eve/tools/tool_handlers.py b/eve/tools/tool_handlers.py index d8c6fd1..a8def99 100644 --- a/eve/tools/tool_handlers.py +++ b/eve/tools/tool_handlers.py @@ -6,6 +6,8 @@ from .media_utils.video_concat.handler import handler as video_concat from .wallet.send_eth.handler import handler as send_eth +from .websearch.handler import handler as websearch +from .weather.handler import handler as weather from .media_utils.time_remapping.handler import handler as time_remapping from .twitter.get_tweets.handler import handler as get_tweets @@ -38,6 +40,7 @@ "hedra": hedra, "elevenlabs": elevenlabs, "memegen": memegen, - + "websearch": websearch, "send_eth": send_eth, + "weather": weather, } diff --git a/eve/tools/weather/api.yaml b/eve/tools/weather/api.yaml new file mode 100644 index 0000000..7b9c83a --- /dev/null +++ b/eve/tools/weather/api.yaml @@ -0,0 +1,17 @@ +name: Weather +description: Get the weather for a given location +cost_estimate: 1 +output_type: string +status: prod +visible: true +parameters: + lat: + type: float + label: Latitude + description: The latitude of the location to get the weather for + required: true + lon: + type: float + label: Longitude + description: The longitude of the location to get the weather for + required: true \ No newline at end of file diff --git a/eve/tools/weather/handler.py b/eve/tools/weather/handler.py new file mode 100644 index 0000000..688610d --- /dev/null +++ b/eve/tools/weather/handler.py @@ -0,0 +1,26 @@ +import json +import requests + +async def handler(args: dict): + lat = args["lat"] + lon = args["lon"] + + points_url = f"https://api.weather.gov/points/{lat},{lon}" + # Provide a descriptive User-Agent per NOAA policy + headers = {'User-Agent': 'MyForecastApp (contact@example.com)'} + + # Step 1: Get the forecast endpoint + points_resp = requests.get(points_url, headers=headers) + points_data = points_resp.json() + + # Step 2: Use the "forecast" or "forecastHourly" property to get actual data + forecast_url = points_data["properties"]["forecast"] + forecast_resp = requests.get(forecast_url, headers=headers) + forecast_data = forecast_resp.json() + + output = forecast_data["properties"] + + return { + "output": json.dumps(output) + } + diff --git a/eve/tools/weather/test.json b/eve/tools/weather/test.json new file mode 100644 index 0000000..fca3628 --- /dev/null +++ b/eve/tools/weather/test.json @@ -0,0 +1,4 @@ +{ + "lat": 33.355415, + "lon": -115.723984 +} \ No newline at end of file diff --git a/eve/tools/websearch/handler.py b/eve/tools/websearch/handler.py index 2939f54..5777efb 100644 --- a/eve/tools/websearch/handler.py +++ b/eve/tools/websearch/handler.py @@ -14,7 +14,7 @@ async def safe_evaluate(page, script: str, default_value: Any) -> Tuple[Any, str except Exception as e: return default_value, str(e) -async def handler(args: dict, env: str = None) -> Dict[str, str]: +async def handler(args: dict) -> Dict[str, str]: """ Handler function for the websearch tool that scrapes content from specified URLs. @@ -28,9 +28,7 @@ async def handler(args: dict, env: str = None) -> Dict[str, str]: Returns: Dict[str, str]: Dictionary containing the scraped content and any errors """ - url = args.get('url') - if not url: - raise ValueError("URL parameter is required") + url = args["url"] # Get configurable limits from args with defaults max_links = int(args.get('max_links', 15)) @@ -147,8 +145,7 @@ async def handler(args: dict, env: str = None) -> Dict[str, str]: await browser.close() # Format output with error reporting - output = f""" -# Page Analysis: {url} + output = f"""# Page Analysis: {url} ## Title {page_content['title']} diff --git a/eve/tools/websearch/test_locally.py b/eve/tools/websearch/test_locally.py deleted file mode 100644 index a0537c6..0000000 --- a/eve/tools/websearch/test_locally.py +++ /dev/null @@ -1,58 +0,0 @@ -import asyncio -from handler import handler -import json -from datetime import datetime -import os - -async def test_website(url: str) -> None: - """Test the handler with a single website and print results.""" - print(f"\n{'='*80}") - print(f"Testing URL: {url}") - print(f"{'='*80}") - - try: - result = await handler({"url": url}) - print(result["output"]) - except Exception as e: - print(f"Error testing {url}: {str(e)}") - -async def run_tests(): - """Run tests on various types of websites.""" - results_dir = "websearch_results" - os.makedirs(results_dir, exist_ok=True) - - test_urls = [ - "https://news.ycombinator.com", - "https://huggingface.co/papers", - "https://www.reddit.com/r/StableDiffusion/", - ] - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - results = {} - - for url in test_urls: - try: - result = await handler({"url": url}) - results[url] = { - "status": "success", - "output": result["output"] - } - except Exception as e: - results[url] = { - "status": "error", - "error": str(e) - } - - output_file = os.path.join(results_dir, f"websearch_test_{timestamp}.json") - with open(output_file, 'w', encoding='utf-8') as f: - json.dump(results, f, indent=2) - - await test_website(url) - print("\nWaiting 2 seconds before next test...") - await asyncio.sleep(2) - - print(f"\nTest results saved to: {output_file}") - -if __name__ == "__main__": - print("Starting websearch tool tests...") - asyncio.run(run_tests()) \ No newline at end of file diff --git a/modal_tool.py b/modal_tool.py index f4d185e..e5b9c23 100644 --- a/modal_tool.py +++ b/modal_tool.py @@ -1,3 +1,4 @@ +import os from pathlib import Path import modal @@ -5,21 +6,17 @@ from eve import eden_utils from eve.tools.tool_handlers import handlers +db = os.getenv("DB", "STAGE").upper() +if db not in ["PROD", "STAGE"]: + raise Exception(f"Invalid environment: {db}. Must be PROD or STAGE") +app_name = "modal-tools-prod" if db == "PROD" else "modal-tools-stage" app = modal.App( - name="modal_tools", + name=app_name, secrets=[ - modal.Secret.from_name("s3-credentials"), - modal.Secret.from_name("mongo-credentials"), - modal.Secret.from_name("replicate"), - modal.Secret.from_name("openai"), - modal.Secret.from_name("anthropic"), - modal.Secret.from_name("elevenlabs"), - modal.Secret.from_name("hedra"), - modal.Secret.from_name("newsapi"), - modal.Secret.from_name("runway"), - modal.Secret.from_name("sentry"), - ], + modal.Secret.from_name("eve-secrets"), + modal.Secret.from_name(f"eve-secrets-{db}"), + ], ) root_dir = Path(__file__).parent diff --git a/pyproject.toml b/pyproject.toml index 438810e..affcebb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "ably>=2.0.7", "colorama>=0.4.6", "web3<7.6.1", + "playwright<1.49", ] [build-system] @@ -63,3 +64,6 @@ dev-dependencies = [ [project.scripts] eve = "eve.cli:cli" + +[tool.rye.scripts] +post-install = "playwright install" diff --git a/requirements-dev.lock b/requirements-dev.lock index ccdb25d..1301343 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -173,6 +173,8 @@ googleapis-common-protos==1.66.0 # via google-api-core # via grpc-google-iam-v1 # via grpcio-status +greenlet==3.1.1 + # via playwright grpc-google-iam-v1==0.13.1 # via google-cloud-resource-manager grpcio==1.68.1 @@ -290,6 +292,8 @@ pillow==11.0.0 # via imageio platformdirs==4.3.6 # via virtualenv +playwright==1.18.1 + # via eve pluggy==1.5.0 # via pytest pre-commit==4.0.1 @@ -346,6 +350,7 @@ pydub==0.25.1 # via eve pyee==11.1.1 # via ably + # via playwright pygments==2.18.0 # via rich pyhumps==3.8.0 @@ -490,6 +495,7 @@ websockets==12.0 # via ably # via elevenlabs # via eve + # via playwright # via web3 yarl==1.18.3 # via aiohttp diff --git a/requirements.lock b/requirements.lock index d675832..8165085 100644 --- a/requirements.lock +++ b/requirements.lock @@ -166,6 +166,8 @@ googleapis-common-protos==1.66.0 # via google-api-core # via grpc-google-iam-v1 # via grpcio-status +greenlet==3.1.1 + # via playwright grpc-google-iam-v1==0.13.1 # via google-cloud-resource-manager grpcio==1.68.1 @@ -272,6 +274,8 @@ parsimonious==0.9.0 pillow==11.0.0 # via eve # via imageio +playwright==1.18.1 + # via eve proglog==0.1.10 # via moviepy propcache==0.2.1 @@ -323,6 +327,7 @@ pydub==0.25.1 # via eve pyee==11.1.1 # via ably + # via playwright pygments==2.18.0 # via rich pyhumps==3.8.0 @@ -458,6 +463,7 @@ websockets==12.0 # via ably # via elevenlabs # via eve + # via playwright # via web3 yarl==1.18.3 # via aiohttp diff --git a/tests/test_client.py b/tests/test_client.py index 36081b1..4d9b8ff 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -21,9 +21,11 @@ def run_create(server_url): } } response = requests.post(server_url+"/create", json=request, headers=headers) + print("GO!!!") print(response) print("Status Code:", response.status_code) print(json.dumps(response.json(), indent=2)) + print("done...") def run_chat(server_url): @@ -49,7 +51,7 @@ def test_client(): stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - time.sleep(3) + time.sleep(5) server_url = "http://localhost:8000" print("server_url", server_url) @@ -58,7 +60,7 @@ def test_client(): run_create(server_url) print("\nRunning chat test...") - run_chat(server_url) + # run_chat(server_url) except KeyboardInterrupt: print("\nShutting down...") @@ -69,3 +71,6 @@ def test_client(): server.terminate() server.wait() + +if __name__ == "__main__": + test_client() \ No newline at end of file diff --git a/tests/test_tools.py b/tests/test_tools.py index b577672..1ec36ac 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -22,7 +22,7 @@ async def async_run_all_tools( ): """Test multiple tools with their test args""" # Get tools from either yaml files or mongo - tool_dict = get_tools_from_api_files(tools=tools, include_inactive=True) if yaml else get_tools_from_mongo(tools=tools) + tool_dict = get_tools_from_api_files(tools=tools) if yaml else get_tools_from_mongo(tools=tools) # Create and run tasks tasks = [async_run_tool(tool, api, mock) for tool in tool_dict.values()]