diff --git a/comfyui.py b/comfyui.py index fb92de5..e947a13 100644 --- a/comfyui.py +++ b/comfyui.py @@ -47,6 +47,7 @@ from eve.tool import Tool from eve.mongo import get_collection from eve.task import task_handler_method +from eve.s3 import get_full_url GPUs = { "A100": modal.gpu.A100(), @@ -264,13 +265,13 @@ def _start(self, port=8188): t2 = time.time() self.launch_time = t2 - t1 - def _execute(self, workflow_name: str, args: dict, db: str): + def _execute(self, workflow_name: str, args: dict): try: tool_path = f"/root/workspace/workflows/{workflow_name}" tool = Tool.from_yaml(f"{tool_path}/api.yaml") workflow = json.load(open(f"{tool_path}/workflow_api.json", 'r')) self._validate_comfyui_args(workflow, tool) - workflow = self._inject_args_into_workflow(workflow, tool, args, db=db) + workflow = self._inject_args_into_workflow(workflow, tool, args) prompt_id = self._queue_prompt(workflow)['prompt_id'] outputs = self._get_outputs(prompt_id) output = outputs[str(tool.comfyui_output_node_id)] @@ -290,14 +291,14 @@ def _execute(self, workflow_name: str, args: dict, db: str): raise @modal.method() - def run(self, tool_key: str, args: dict, db: str): - result = self._execute(tool_key, args, db=db) - return eden_utils.upload_result(result, db=db) + def run(self, tool_key: str, args: dict): + result = self._execute(tool_key, args) + return eden_utils.upload_result(result) @modal.method() @task_handler_method - async def run_task(self, tool_key: str, args: dict, db: str): - return self._execute(tool_key, args, db=db) + async def run_task(self, tool_key: str, args: dict): + return self._execute(tool_key, args) @modal.enter() def enter(self): @@ -348,8 +349,8 @@ def test_workflows(self): test_name = f"{workflow}_{os.path.basename(test)}" print(f"Running test: {test_name}") t1 = time.time() - result = self._execute(workflow, test_args, db="STAGE") - result = eden_utils.upload_result(result, db="STAGE") + result = self._execute(workflow, test_args) + result = eden_utils.upload_result(result) t2 = time.time() results[test_name] = result results["_performance"][test_name] = t2 - t1 @@ -460,16 +461,16 @@ def _inject_embedding_mentions_sdxl(self, text, embedding_trigger, embeddings_fi return user_prompt, lora_prompt - def _inject_embedding_mentions_flux(self, text, embedding_trigger, caption_prefix): + def _inject_embedding_mentions_flux(self, text, embedding_trigger, lora_trigger_text): pattern = r'(<{0}>|<{1}>|{0}|{1})'.format( re.escape(embedding_trigger), re.escape(embedding_trigger.lower()) ) - text = re.sub(pattern, caption_prefix, text, flags=re.IGNORECASE) - text = re.sub(r'()', caption_prefix, text, flags=re.IGNORECASE) + text = re.sub(pattern, lora_trigger_text, text, flags=re.IGNORECASE) + text = re.sub(r'()', lora_trigger_text, text, flags=re.IGNORECASE) - if caption_prefix not in text: # Make sure the concept is always triggered: - text = f"{caption_prefix}, {text}" + if lora_trigger_text not in text: # Make sure the concept is always triggered: + text = f"{lora_trigger_text}, {text}" return text @@ -611,7 +612,7 @@ def _validate_comfyui_args(self, workflow, tool): if not all(choice in remap.map.keys() for choice in choices): raise Exception(f"Remap parameter {key} is missing original choices: {choices}") - def _inject_args_into_workflow(self, workflow, tool, args, db="STAGE"): + def _inject_args_into_workflow(self, workflow, tool, args): # Helper function to validate and normalize URLs def validate_url(url): @@ -624,7 +625,7 @@ def validate_url(url): pprint(args) embedding_trigger = None - 
caption_prefix = None + lora_trigger_text = None # download and transport files for key, param in tool.model.model_fields.items(): @@ -653,14 +654,18 @@ def validate_url(url): args["lora_strength"] = 0 print("REMOVE LORA") continue + + print("LORA ID", lora_id) + print(type(lora_id)) - models = get_collection("models", db=db) + models = get_collection("models3") lora = models.find_one({"_id": ObjectId(lora_id)}) - base_model = lora.get("base_model") - print("LORA", lora) + print("found lora", lora) + if not lora: raise Exception(f"Lora {lora_id} not found") + base_model = lora.get("base_model") lora_url = lora.get("checkpoint") #lora_name = lora.get("name") #pretrained_model = lora.get("args").get("sd_model_version") @@ -670,6 +675,8 @@ def validate_url(url): else: print("LORA URL", lora_url) + lora_url = get_full_url(lora_url) + print("lora url", lora_url) print("base model", base_model) if base_model == "sdxl": @@ -677,10 +684,11 @@ def validate_url(url): elif base_model == "flux-dev": lora_filename = self._transport_lora_flux(lora_url) embedding_trigger = lora.get("args", {}).get("name") - caption_prefix = lora.get("args", {}).get("caption_prefix") + lora_trigger_text = lora.get("lora_trigger_text") args[key] = lora_filename - print("lora filename", lora_filename) + args["use_lora"] = True + print("lora filename", lora_filename) # inject args # comfyui_map = { @@ -700,8 +708,10 @@ def validate_url(url): # if there's a lora, replace mentions with embedding name if key == "prompt" and embedding_trigger: lora_strength = args.get("lora_strength", 0.5) - if base_model == "flux_dev": - value = self._inject_embedding_mentions_flux(value, embedding_trigger, caption_prefix) + if base_model == "flux-dev": + print("INJECTING LORA TRIGGER TEXT", lora_trigger_text) + value = self._inject_embedding_mentions_flux(value, embedding_trigger, lora_trigger_text) + print("INJECTED LORA TRIGGER TEXT", value) elif base_model == "sdxl": no_token_prompt, value = self._inject_embedding_mentions_sdxl(value, embedding_trigger, embeddings_filename, lora_mode, lora_strength) diff --git a/eve/__init__.py b/eve/__init__.py index 933a064..f3f7feb 100644 --- a/eve/__init__.py +++ b/eve/__init__.py @@ -5,28 +5,63 @@ import os home_dir = str(Path.home()) -eve_path = os.path.join(home_dir, ".eve") -env_path = ".env" -# First try ENV_PATH from environment -env_path_override = os.getenv("ENV_PATH") -if env_path_override and os.path.exists(env_path_override): - load_dotenv(env_path_override) -# Then try ~/.eve -if os.path.exists(eve_path): - load_dotenv(eve_path, override=True) +EDEN_API_KEY = None -# Finally fall back to .env -if os.path.exists(env_path): - load_dotenv(env_path, override=True) -# start sentry -sentry_dsn = os.getenv("SENTRY_DSN") -sentry_sdk.init(dsn=sentry_dsn, traces_sample_rate=1.0, profiles_sample_rate=1.0) +def load_env(db): + global EDEN_API_KEY -# load api keys -EDEN_API_KEY = SecretStr(os.getenv("EDEN_API_KEY", "")) + db = db.upper() + if db not in ["STAGE", "PROD"]: + raise ValueError(f"Invalid database: {db}") + + os.environ["DB"] = db -if not EDEN_API_KEY: - print("WARNING: EDEN_API_KEY is not set") + # First try ~/.eve + stage = db == "STAGE" + env_file = ".env.STAGE" if stage else ".env" + eve_file = ".eve.STAGE" if stage else ".eve" + eve_path = os.path.join(home_dir, eve_file) + + if os.path.exists(eve_path): + load_dotenv(eve_path, override=True) + + # Then try ENV_PATH from environment or .env + env_path_override = os.getenv("ENV_PATH") + if env_path_override and 
os.path.exists(env_path_override): + load_dotenv(env_path_override, override=True) + elif os.path.exists(env_file): + load_dotenv(env_file, override=True) + + # start sentry + sentry_dsn = os.getenv("SENTRY_DSN") + sentry_sdk.init(dsn=sentry_dsn, traces_sample_rate=1.0, profiles_sample_rate=1.0) + + # load api keys + EDEN_API_KEY = SecretStr(os.getenv("EDEN_API_KEY", "")) + + if not EDEN_API_KEY: + print("WARNING: EDEN_API_KEY is not set") + + verify_env() + + +def verify_env(): + AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID") + AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY") + AWS_REGION_NAME = os.getenv("AWS_REGION_NAME") + MONGO_URI = os.getenv("MONGO_URI") + + if not all([AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME]): + print( + "WARNING: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_REGION_NAME must be set in the environment" + ) + + if not MONGO_URI: + print("WARNING: MONGO_URI must be set in the environment") + + +db = os.getenv("DB", "STAGE") +load_env(db) diff --git a/eve/agent.py b/eve/agent.py index 767b098..34f283d 100644 --- a/eve/agent.py +++ b/eve/agent.py @@ -13,6 +13,18 @@ from .tool import Tool from .mongo import Collection, get_collection from .user import User, Manna +from .models import Model + +default_presets_flux = { + "flux_dev_lora": {}, + "runway": {}, + "reel": {}, +} +default_presets_sdxl = { + "txt2img": {}, + "runway": {}, + "reel": {}, +} @Collection("users3") @@ -32,7 +44,8 @@ class Agent(User): name: str description: str instructions: str - models: Optional[Dict[str, ObjectId]] = None + # models: Optional[Dict[str, ObjectId]] = None + model: Optional[ObjectId] = None test_args: Optional[List[Dict[str, Any]]] = None tools: Optional[Dict[str, Dict]] = None @@ -41,8 +54,8 @@ class Agent(User): def __init__(self, **data): if isinstance(data.get('owner'), str): data['owner'] = ObjectId(data['owner']) - if data.get('models'): - data['models'] = {k: ObjectId(v) if isinstance(v, str) else v for k, v in data['models'].items()} + # if data.get('models'): + # data['models'] = {k: ObjectId(v) if isinstance(v, str) else v for k, v in data['models'].items()} # Load environment variables into secrets dictionary env_dir = Path(__file__).parent / "agents" env_vars = dotenv_values(f"{str(env_dir)}/{data['username']}/.env") @@ -61,56 +74,58 @@ def convert_from_yaml(cls, schema: dict, file_path: str = None) -> dict: owner = schema.get('owner') schema["owner"] = ObjectId(owner) if isinstance(owner, str) else owner schema["username"] = schema.get("username") or file_path.split("/")[-2] - schema["tools"] = {k: v or {} for k, v in schema.get("tools", {}).items()} + schema = cls._setup_tools(schema) return schema - + @classmethod - def convert_from_mongo(cls, schema: dict, db="STAGE") -> dict: - schema["tools"] = {k: v or {} for k, v in schema.get("tools", {}).items()} + def convert_from_mongo(cls, schema: dict) -> dict: + schema = cls._setup_tools(schema) return schema - def save(self, db=None, **kwargs): + def save(self, **kwargs): # do not overwrite any username if it already exists - users = get_collection(User.collection_name, db=db) + users = get_collection(User.collection_name) if users.find_one({"username": self.username, "type": "user"}): raise ValueError(f"Username {self.username} already taken") # save user, and create mannas record if it doesn't exist kwargs["featureFlags"] = ["freeTools"] # give agents free tools for now - super().save(db, {"username": self.username, "type": "agent"}, **kwargs) - Manna.load(user=self.id, db=db) + 
super().save( + upsert_filter={"username": self.username, "type": "agent"}, + **kwargs + ) + Manna.load(user=self.id) # create manna record if it doesn't exist @classmethod - def from_yaml(cls, file_path, db="STAGE", cache=False): + def from_yaml(cls, file_path, cache=False): if cache: if file_path not in _agent_cache: - _agent_cache[file_path] = super().from_yaml(file_path, db=db) + _agent_cache[file_path] = super().from_yaml(file_path) return _agent_cache[file_path] else: - return super().from_yaml(file_path, db=db) + return super().from_yaml(file_path) @classmethod - def from_mongo(cls, document_id, db="STAGE", cache=False): + def from_mongo(cls, document_id, cache=False): if cache: if document_id not in _agent_cache: - _agent_cache[str(document_id)] = super().from_mongo(document_id, db=db) + _agent_cache[str(document_id)] = super().from_mongo(document_id) return _agent_cache[str(document_id)] else: - return super().from_mongo(document_id, db=db) + return super().from_mongo(document_id) @classmethod - def load(cls, username, db=None, cache=False): + def load(cls, username, cache=False): if cache: if username not in _agent_cache: - _agent_cache[username] = super().load(username=username, db=db) + _agent_cache[username] = super().load(username=username) return _agent_cache[username] else: - return super().load(username=username, db=db) + return super().load(username=username) - def request_thread(self, key=None, user=None, db="STAGE"): + def request_thread(self, key=None, user=None): thread = Thread( - db=db, key=key, agent=self.id, user=user, @@ -118,29 +133,77 @@ def request_thread(self, key=None, user=None, db="STAGE"): thread.save() return thread - def get_tools(self, db="STAGE", cache=False): - if not self.tools: - return {} + @classmethod + def _setup_tools(cls, schema: dict) -> dict: + """ + Sets up the agent's tools based on the tools defined in the schema. + If a model (lora) is set, hardcode it into the tools. + """ + tools = schema.get("tools") + if tools: + schema["tools"] = {k: v or {} for k, v in tools.items()} + else: + schema["tools"] = default_presets_flux + if "model" in schema: + model = Model.from_mongo(schema["model"]) + if model.base_model == "flux-dev": + schema["tools"] = default_presets_flux + schema["tools"]["flux_dev_lora"] = { + "name": f"Generate {model.name}", + "description": f"Generate an image of {model.name}", + "parameters": { + "prompt": { + "description": f"The text prompt. Always mention {model.name}." 
+ }, + "lora": { + "default": str(model.id), + "hide_from_agent": True, + }, + "lora_strength": { + "default": 1.0, + "hide_from_agent": True, + } + } + } + schema["tools"]["reel"] = { + "name": f"Generate {model.name}", + "tip": f"Make sure to always include {model.name} in all of the prompts.", + "parameters": { + "lora": { + "default": str(model.id), + "hide_from_agent": True, + }, + "lora_strength": { + "default": 1.0, + "hide_from_agent": True, + } + } + } + elif model.base_model == "sdxl": + schema["tools"] = default_presets_sdxl + + return schema + + def get_tools(self,cache=False): + if not hasattr(self, "tools") or not self.tools: + self.tools = {} + if cache: self.tools_cache = self.tools_cache or {} for k, v in self.tools.items(): if k not in self.tools_cache: - tool = Tool.from_raw_yaml({"parent_tool": k, **v}, db=db) + tool = Tool.from_raw_yaml({"parent_tool": k, **v}) self.tools_cache[k] = tool return self.tools_cache - else: + else: return { - k: Tool.from_raw_yaml({"parent_tool": k, **v}, db=db) + k: Tool.from_raw_yaml({"parent_tool": k, **v}) for k, v in self.tools.items() } - def get_tool(self, tool_name, db="STAGE", cache=False): - return self.get_tools(db=db, cache=cache)[tool_name] + def get_tool(self, tool_name, cache=False): + return self.get_tools(cache=cache)[tool_name] - def get_system_message(self): - system_message = f"{self.description}\n\n{self.instructions}\n\n{generic_instructions}" - return system_message - def get_agents_from_api_files(root_dir: str = None, agents: List[str] = None, include_inactive: bool = False) -> Dict[str, Agent]: """Get all agents inside a directory""" @@ -160,16 +223,16 @@ def get_agents_from_api_files(root_dir: str = None, agents: List[str] = None, in return agents -def get_agents_from_mongo(db: str, agents: List[str] = None, include_inactive: bool = False) -> Dict[str, Agent]: +def get_agents_from_mongo(agents: List[str] = None, include_inactive: bool = False) -> Dict[str, Agent]: """Get all agents from mongo""" filter = {"key": {"$in": agents}} if agents else {} agents = {} - agents_collection = get_collection(Agent.collection_name, db=db) + agents_collection = get_collection(Agent.collection_name) for agent in agents_collection.find(filter): try: - agent = Agent.convert_from_mongo(agent, db=db) - agent = Agent.from_schema(agent, db=db) + agent = Agent.convert_from_mongo(agent) + agent = Agent.from_schema(agent) if agent.status != "inactive" and not include_inactive: if agent.key in agents: raise ValueError(f"Duplicate agent {agent.key} found.") @@ -182,6 +245,8 @@ def get_agents_from_mongo(db: str, agents: List[str] = None, include_inactive: b def get_api_files(root_dir: str = None, include_inactive: bool = False) -> List[str]: """Get all agent directories inside a directory""" + + env = os.getenv("DB") if root_dir: root_dirs = [root_dir] @@ -189,7 +254,7 @@ def get_api_files(root_dir: str = None, include_inactive: bool = False) -> List[ eve_root = os.path.dirname(os.path.abspath(__file__)) root_dirs = [ os.path.join(eve_root, agents_dir) - for agents_dir in ["agents"] + for agents_dir in [f"agents/{env}"] ] api_files = {} @@ -210,11 +275,3 @@ def get_api_files(root_dir: str = None, include_inactive: bool = False) -> List[ # Agent cache for fetching commonly used agents _agent_cache: Dict[str, Dict[str, Agent]] = {} - -generic_instructions = """Follow these additional guidelines: -- If the tool you are using has the "n_samples" parameter, and the user requests for multiple versions of the same thing, set n_samples to the number 
of images the user desires for that prompt. If they want N > 1 images that have different prompts, then make N separate tool calls with n_samples=1.
-- When a lora is set, absolutely make sure to include "" in the prompt to refer to object or person represented by the lora.
-- If you get an error using a tool because the user requested an invalid parameter, or omitted a required parameter, ask the user for clarification before trying again. Do *not* try to guess what the user meant.
-- If you get an error using a tool because **YOU** made a mistake, do not apologize for the oversight or explain what *you* did wrong, just fix your mistake, and automatically retry the task.
-- When returning the final results to the user, do not include *any* text except a markdown link to the image(s) and/or video(s) with the prompt as the text and the media url as the link. DO NOT include any other text, such as the name of the tool used, a summary of the results, the other args, or any other explanations. Just [prompt](url).
-- When doing multi-step tasks, present your intermediate results in each message before moving onto the next tool use. For example, if you are asked to create an image and then animate it, make sure to return the image (including the url) to the user (as markdown, like above)."""
diff --git a/eve/agents/prod/abraham/api.yaml b/eve/agents/prod/abraham/api.yaml
index e94b13d..828f1db 100644
--- a/eve/agents/prod/abraham/api.yaml
+++ b/eve/agents/prod/abraham/api.yaml
@@ -1,5 +1,5 @@
 name: Abraham
-owner: 65284b18f8bbb9bff13ebe65
+owner: 6526f38042a1043421aa28e6
 featureFlags:
   - freeTools
diff --git a/eve/agents/prod/anime/api.yaml b/eve/agents/prod/anime/api.yaml
index 2eba377..bd2f3ee 100644
--- a/eve/agents/prod/anime/api.yaml
+++ b/eve/agents/prod/anime/api.yaml
@@ -1,5 +1,5 @@
 name: Eve-san
-owner: 65284b18f8bbb9bff13ebe65
+owner: 6526f38042a1043421aa28e6
 featureFlags:
   - freeTools
diff --git a/eve/agents/prod/banny/api.yaml b/eve/agents/prod/banny/api.yaml
index 95237c7..1cbbc0f 100644
--- a/eve/agents/prod/banny/api.yaml
+++ b/eve/agents/prod/banny/api.yaml
@@ -1,5 +1,5 @@
 name: Banny
-owner: 6544502e4cd811f27c430b56
+owner: 6526f38042a1043421aa28e6
 userImage: https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/405024ab0572704cad07a0d7c22158ef443d2d988490e4f1a538e235d321a9c9.png
 featureFlags:
   - freeTools
@@ -55,7 +55,7 @@ tools:
         hide_from_agent: true
         hide_from_ui: true
   reel:
-    tip: Make sure to always include Banny in the prompts.
+    tip: Make sure to always include Banny in all of the prompts.
     parameters:
       lora:
         default: 6766760643808b38016c64ce
diff --git a/eve/agents/prod/beepler/api.yaml b/eve/agents/prod/beepler/api.yaml
index 68a784a..01e5ffc 100644
--- a/eve/agents/prod/beepler/api.yaml
+++ b/eve/agents/prod/beepler/api.yaml
@@ -1,5 +1,5 @@
 name: Beepler
-owner: 65284b18f8bbb9bff13ebe65
+owner: 6526f38042a1043421aa28e6
 featureFlags:
   - freeTools
diff --git a/eve/agents/prod/bombay_beach/api.yaml b/eve/agents/prod/bombay_beach/api.yaml
index 0b50403..61b4e2f 100644
--- a/eve/agents/prod/bombay_beach/api.yaml
+++ b/eve/agents/prod/bombay_beach/api.yaml
@@ -1,5 +1,5 @@
 name: Bombay Beach
-owner: 65284b18f8bbb9bff13ebe65
+owner: 6526f38042a1043421aa28e6
 featureFlags:
   - freeTools
diff --git a/eve/agents/prod/cyberswami/api.yaml b/eve/agents/prod/cyberswami/api.yaml
index ca74ea1..fbb38bf 100644
--- a/eve/agents/prod/cyberswami/api.yaml
+++ b/eve/agents/prod/cyberswami/api.yaml
@@ -1,5 +1,5 @@
 name: CyberSwami
-owner: 65284b18f8bbb9bff13ebe65
+owner: 6526f38042a1043421aa28e6
 featureFlags:
   - freeTools
diff --git a/eve/agents/prod/desci/api.yaml b/eve/agents/prod/desci/api.yaml
index 99b94ac..79e0838 100644
--- a/eve/agents/prod/desci/api.yaml
+++ b/eve/agents/prod/desci/api.yaml
@@ -1,5 +1,5 @@
 name: F(desci)
-owner: 65284b18f8bbb9bff13ebe65
+owner: 6526f38042a1043421aa28e6
 userImage: https://edenartlab-stage-data.s3.amazonaws.com/9ff4e40812d0519c592313a504060ee87f99910b2d663f6e55ee86ecabdbfa4e.jpg
 featureFlags:
   - freeTools
diff --git a/eve/agents/prod/eve/api.yaml b/eve/agents/prod/eve/api.yaml
index 106e768..2e3b2cb 100644
--- a/eve/agents/prod/eve/api.yaml
+++ b/eve/agents/prod/eve/api.yaml
@@ -1,5 +1,5 @@
 name: Eve
-owner: 65284b18f8bbb9bff13ebe65
+owner: 6526f38042a1043421aa28e6
 userImage: https://edenartlab-stage-data.s3.amazonaws.com/d158dc1e5c62479489c1c3d119dd211bd56ba86a127359f7476990ec9e081cba.jpg
 featureFlags:
   - freeTools
diff --git a/eve/agents/prod/example_agent/api.yaml b/eve/agents/prod/example_agent/api.yaml
index 9de5bc0..c11e25d 100644
--- a/eve/agents/prod/example_agent/api.yaml
+++ b/eve/agents/prod/example_agent/api.yaml
@@ -1,5 +1,5 @@
 name: Example Agent
-owner: 65284b18f8bbb9bff13ebe65
+owner: 6526f38042a1043421aa28e6
 description: |
   You are a helpful assistant.
diff --git a/eve/agents/prod/photo/api.yaml b/eve/agents/prod/photo/api.yaml
index 3dc3efb..dcaa623 100644
--- a/eve/agents/prod/photo/api.yaml
+++ b/eve/agents/prod/photo/api.yaml
@@ -1,5 +1,5 @@
 name: Eve
-owner: 65284b18f8bbb9bff13ebe65
+owner: 6526f38042a1043421aa28e6
 description: |
   Your name is Eve. You are an expert at using Eden, a generative AI platform that empowers individuals to create and share their unique digital creations. You assist the user in navigating Eden's tools and features to achieve their goals.
diff --git a/eve/agents/prod/verdelis/api.yaml b/eve/agents/prod/verdelis/api.yaml
index cc63527..b140183 100644
--- a/eve/agents/prod/verdelis/api.yaml
+++ b/eve/agents/prod/verdelis/api.yaml
@@ -1,5 +1,5 @@
 name: Verdelis
-owner: 65284b18f8bbb9bff13ebe65
+owner: 6526f38042a1043421aa28e6
 userImage: https://edenartlab-stage-data.s3.amazonaws.com/41c397a76de66ba54d8fd72e7138423d132d2aa5e741c5a9c6cc4a0090a946ef.jpg
 description: |
diff --git a/eve/agents/prod/vj/api.yaml b/eve/agents/prod/vj/api.yaml
index 90cd542..0a590bb 100644
--- a/eve/agents/prod/vj/api.yaml
+++ b/eve/agents/prod/vj/api.yaml
@@ -1,5 +1,5 @@
 name: VJ Eve
-owner: 65284b18f8bbb9bff13ebe65
+owner: 6526f38042a1043421aa28e6
 description: |
   Your name is VJ Eve.
You are an expert at using Eden, a generative AI platform that empowers individuals to create and share their unique digital creations. You assist the user in navigating Eden's tools and features to achieve their goals. diff --git a/eve/agents/staging/eve/api.yaml b/eve/agents/stage/eve/api.yaml similarity index 100% rename from eve/agents/staging/eve/api.yaml rename to eve/agents/stage/eve/api.yaml diff --git a/eve/agents/staging/eve/test.json b/eve/agents/stage/eve/test.json similarity index 100% rename from eve/agents/staging/eve/test.json rename to eve/agents/stage/eve/test.json diff --git a/eve/api.py b/eve/api.py index bcd9f2e..909d4c9 100644 --- a/eve/api.py +++ b/eve/api.py @@ -90,9 +90,9 @@ class ChatRequest(BaseModel): async def handle_task(tool: str, user_id: str, args: dict = {}) -> dict: - tool = Tool.load(key=tool, db=db) + tool = Tool.load(key=tool) return await tool.async_start_task( - requester_id=user_id, user_id=user_id, args=args, db=db + requester_id=user_id, user_id=user_id, args=args ) @@ -119,14 +119,14 @@ async def setup_chat( except Exception as e: logger.error(f"Failed to create Ably channel: {str(e)}") - user = User.from_mongo(request.user_id, db=db) - agent = Agent.from_mongo(request.agent_id, db=db, cache=True) - tools = agent.get_tools(db=db, cache=True) + user = User.from_mongo(request.user_id) + agent = Agent.from_mongo(request.agent_id, cache=True) + tools = agent.get_tools(cache=True) if request.thread_id: - thread = Thread.from_mongo(request.thread_id, db=db) + thread = Thread.from_mongo(request.thread_id) else: - thread = agent.request_thread(db=db, user=user.id) + thread = agent.request_thread(user=user.id) background_tasks.add_task(async_title_thread, thread, request.user_message) return user, agent, thread, tools, update_channel @@ -150,8 +150,8 @@ async def handle_chat( ) async def run_prompt(): + async for update in async_prompt_thread( - db=db, user=user, agent=agent, thread=thread, @@ -222,7 +222,6 @@ async def stream_chat( async def event_generator(): async for update in async_prompt_thread( - db=db, user=user, agent=agent, thread=thread, diff --git a/eve/auth.py b/eve/auth.py index 9b60d39..630e2c9 100644 --- a/eve/auth.py +++ b/eve/auth.py @@ -17,7 +17,6 @@ api_key_header = APIKeyHeader(name="X-Api-Key", auto_error=False) bearer_scheme = HTTPBearer(auto_error=False) -db = os.getenv("DB", "STAGE") EDEN_ADMIN_KEY = os.getenv("EDEN_ADMIN_KEY") ABRAHAM_ADMIN_KEY = os.getenv("ABRAHAM_ADMIN_KEY") ISSUER_URL = os.getenv("CLERK_ISSUER_URL") @@ -30,14 +29,14 @@ def get_api_keys(): global _api_keys if _api_keys is None: - _api_keys = get_collection("apikeys", db=db) + _api_keys = get_collection("apikeys") return _api_keys def get_users(): global _users if _users is None: - _users = get_collection("users2", db=db) + _users = get_collection("users3") return _users @@ -48,13 +47,13 @@ class UserData(BaseModel): isAdmin: bool = False -def get_my_eden_user(db: str = "STAGE") -> str: +def get_my_eden_user() -> str: """Get the user id for the api key in your env file""" api_key = EDEN_API_KEY api_key = get_api_keys().find_one({"apiKey": api_key.get_secret_value()}) if not api_key: raise HTTPException(status_code=401, detail="API key not found") - user = User.from_mongo(api_key["user"], db=db) + user = User.from_mongo(api_key["user"]) if not user: raise HTTPException(status_code=401, detail="User not found") return user diff --git a/eve/base.py b/eve/base.py index 9e9afa2..144ba02 100644 --- a/eve/base.py +++ b/eve/base.py @@ -1,6 +1,21 @@ import copy -from 
pydantic import BaseModel, Field, create_model -from typing import Any, Optional, Type, List, Dict, Union, get_origin, get_args, Literal, Tuple +from pydantic import ( + BaseModel, + Field, + create_model +) +from typing import ( + Any, + Optional, + Type, + List, + Dict, + Union, + get_origin, + get_args, + Literal, + Tuple +) from . import eden_utils diff --git a/eve/cli/__init__.py b/eve/cli/__init__.py index 0ad8083..6564a88 100644 --- a/eve/cli/__init__.py +++ b/eve/cli/__init__.py @@ -1,5 +1,4 @@ import click - from .tool_cli import tool from .agent_cli import agent from .chat_cli import chat diff --git a/eve/cli/agent_cli.py b/eve/cli/agent_cli.py index 3e3c3e9..ef1494c 100644 --- a/eve/cli/agent_cli.py +++ b/eve/cli/agent_cli.py @@ -1,6 +1,7 @@ import click import traceback from ..agent import Agent, get_api_files +from .. import load_env api_agents_order = ["eve", "abraham", "banny"] @@ -21,7 +22,8 @@ def agent(): @click.argument("names", nargs=-1, required=False) def update(db: str, names: tuple): """Upload agents to mongo""" - db = db.upper() + + load_env(db) api_files = get_api_files(include_inactive=True) agents_order = {agent: index for index, agent in enumerate(api_agents_order)} @@ -40,7 +42,7 @@ def update(db: str, names: tuple): try: order = agents_order.get(key, len(api_agents_order)) agent = Agent.from_yaml(api_file) - agent.save(db=db, order=order) + agent.save(order=order) click.echo( click.style(f"Updated agent {db}:{key} (order={order})", fg="green") ) diff --git a/eve/cli/chat_cli.py b/eve/cli/chat_cli.py index 0392912..1b505a7 100644 --- a/eve/cli/chat_cli.py +++ b/eve/cli/chat_cli.py @@ -10,12 +10,12 @@ from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn +from .. import load_env from ..llm import async_prompt_thread, UserMessage, UpdateType from ..eden_utils import prepare_result, dump_json from ..agent import Agent from ..auth import get_my_eden_user - # def preprocess_message(message): # metadata_pattern = r"\{.*?\}" # attachments_pattern = r"\[.*?\]" @@ -26,28 +26,20 @@ # return clean_message, attachments -async def async_chat(db, agent_name, new_thread=True, debug=False): - db = db.upper() - +async def async_chat(agent_name, new_thread=True, debug=False): if not debug: logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("anthropic").setLevel(logging.WARNING) - user = get_my_eden_user(db=db) - agent = Agent.load(agent_name, db=db) + user = get_my_eden_user() + agent = Agent.load(agent_name) key = f"cli_{str(agent.name)}_{str(user.id)}" if not new_thread: key += f"_{int(time.time())}" - thread = agent.request_thread(key=key, db=db) - tools = agent.get_tools(db=db) - - print("THE TOOLS ARE", tools.keys()) - from ..eden_utils import dump_json - from pprint import pprint - pprint(tools) - + thread = agent.request_thread(key=key) + tools = agent.get_tools() chat_string = f"Chat with {agent.name}".center(36) console = Console() @@ -61,14 +53,8 @@ async def async_chat(db, agent_name, new_thread=True, debug=False): console.print("[bold yellow]You [dim]→[/dim] ", end="") message_input = input("\033[93m") - # if message_input.lower() == "escape": - # console.print("\n[dim]Goodbye! 
👋[/dim]\n") - # break - print() - # content, attachments = preprocess_message(message_input) - metadata_pattern = r"\{.*?\}" attachments_pattern = r"\[.*?\]" attachments_match = re.search(attachments_pattern, message_input) @@ -76,7 +62,6 @@ async def async_chat(db, agent_name, new_thread=True, debug=False): content = re.sub(metadata_pattern, "", message_input) content = re.sub(attachments_pattern, "", content).strip() - with Progress( SpinnerColumn(), TextColumn("[bold cyan]"), @@ -91,7 +76,6 @@ async def async_chat(db, agent_name, new_thread=True, debug=False): sys.stdout = devnull async for update in async_prompt_thread( - db=db, user=user, agent=agent, thread=thread, @@ -112,7 +96,7 @@ async def async_chat(db, agent_name, new_thread=True, debug=False): ) print() elif update.type == UpdateType.TOOL_COMPLETE: - result = prepare_result(update.result.get("result"), db=db) + result = prepare_result(update.result.get("result")) console.print( "[bold cyan]🔧 [dim]" + update.tool_name + "[/dim]" ) @@ -164,8 +148,10 @@ async def async_chat(db, agent_name, new_thread=True, debug=False): def chat(db: str, thread: str, agent: str, debug: bool): """Chat with an agent""" + load_env(db) + try: - asyncio.run(async_chat(db, agent, thread, debug)) + asyncio.run(async_chat(agent, thread, debug)) except Exception as e: click.echo(click.style(f"Failed to chat with {agent}:", fg="red")) click.echo(click.style(f"Error: {str(e)}", fg="red")) diff --git a/eve/cli/deploy_cli.py b/eve/cli/deploy_cli.py index c202c54..d9112ea 100644 --- a/eve/cli/deploy_cli.py +++ b/eve/cli/deploy_cli.py @@ -9,12 +9,10 @@ import tempfile import shutil +from .. import load_env + root_dir = Path(__file__).parent.parent.parent ENV_NAME = "deployments" -db = os.getenv("DB", "STAGE").upper() -if db not in ["PROD", "STAGE"]: - raise Exception(f"Invalid environment: {db}. 
Must be PROD or STAGE") -stage = "staging" if db == "STAGE" else "prod" def ensure_modal_env_exists(): @@ -36,7 +34,7 @@ def ensure_modal_env_exists(): ) -def prepare_client_file(file_path: str, agent_key: str) -> str: +def prepare_client_file(file_path: str, agent_key: str, env: str) -> str: """Create a temporary copy of the client file with modifications""" with open(file_path, "r") as f: content = f.read() @@ -48,7 +46,7 @@ def prepare_client_file(file_path: str, agent_key: str) -> str: # Replace the static secret name with the dynamic one modified_content = content.replace( 'modal.Secret.from_name("client-secrets")', - f'modal.Secret.from_name("{agent_key}-client-secrets-{stage}")', + f'modal.Secret.from_name("{agent_key}-client-secrets-{env}")', ) # Fix pyproject.toml path to use absolute path @@ -66,7 +64,7 @@ def prepare_client_file(file_path: str, agent_key: str) -> str: return str(temp_file) -def create_secrets(agent_key: str, secrets_dict: dict): +def create_secrets(agent_key: str, secrets_dict: dict, env: str): if not secrets_dict: click.echo(click.style(f"No secrets found for {agent_key}", fg="yellow")) return @@ -77,7 +75,7 @@ def create_secrets(agent_key: str, secrets_dict: dict): "modal", "secret", "create", - f"{agent_key}-client-secrets-{stage}", + f"{agent_key}-client-secrets-{env}", ] for key, value in secrets_dict.items(): if value is not None: @@ -88,13 +86,13 @@ def create_secrets(agent_key: str, secrets_dict: dict): subprocess.run(cmd_parts) -def deploy_client(agent_key: str, client_name: str): +def deploy_client(agent_key: str, client_name: str, env: str): client_path = root_dir / f"eve/clients/{client_name}/modal_client.py" if client_path.exists(): try: # Create a temporary modified version of the client file - temp_file = prepare_client_file(str(client_path), agent_key) - app_name = f"{agent_key}-client-{client_name}-{stage}" + temp_file = prepare_client_file(str(client_path), agent_key, env) + app_name = f"{agent_key}-client-{client_name}-{env}" # Deploy using the temporary file subprocess.run( @@ -122,9 +120,9 @@ def deploy_client(agent_key: str, client_name: str): ) -def get_deployable_agents(): +def get_deployable_agents(env: str): """Find all agents that have both .env and deployments configured""" - agents_dir = root_dir / "eve" / "agents" / stage + agents_dir = root_dir / "eve" / "agents" / env deployable = [] for agent_dir in agents_dir.glob("*"): @@ -152,7 +150,7 @@ def get_deployable_agents(): return deployable -def process_agent(agent_path: Path): +def process_agent(agent_path: Path, env: str): with open(agent_path) as f: agent_config = yaml.safe_load(f) @@ -168,29 +166,37 @@ def process_agent(agent_path: Path): if env_file.exists(): click.echo(click.style(f"Creating secrets for: {agent_key}", fg="green")) client_secrets = dotenv_values(env_file) - create_secrets(agent_key, client_secrets) + create_secrets(agent_key, client_secrets, env) # Deploy each client for deployment in agent_config["deployments"]: click.echo(click.style(f"Deploying client: {deployment}", fg="green")) - deploy_client(agent_key, deployment) + deploy_client(agent_key, deployment, env) @click.command() @click.argument("agent", nargs=1, required=False) @click.option("--all", is_flag=True, help="Deploy all configured agents") -def deploy(agent: str, all: bool): +@click.option( + "--db", + type=click.Choice(["STAGE", "PROD"], case_sensitive=False), + default="STAGE", + help="DB to save against", +) +def deploy(agent: str, all: bool, db: str): """Deploy Modal agents. 
Use --all to deploy all configured agents.""" try: # Ensure Modal environment exists ensure_modal_env_exists() + load_env(db) + env = "stage" if db == "STAGE" else "prod" if all: - agents = get_deployable_agents() + agents = get_deployable_agents(env) if not agents: click.echo( click.style( - f"No deployable agents found in {stage} environment", + f"No deployable agents found in {env} environment", fg="yellow", ) ) @@ -206,17 +212,17 @@ def deploy(agent: str, all: bool): for agent_name in agents: click.echo(click.style(f"\nProcessing agent: {agent_name}", fg="blue")) agent_path = ( - root_dir / "eve" / "agents" / stage / agent_name / "api.yaml" + root_dir / "eve" / "agents" / env / agent_name / "api.yaml" ) - process_agent(agent_path) + process_agent(agent_path, env) else: if not agent: raise click.UsageError("Please provide an agent name or use --all") - agent_path = root_dir / "eve" / "agents" / stage / agent / "api.yaml" + agent_path = root_dir / "eve" / "agents" / env / agent / "api.yaml" if agent_path.exists(): - process_agent(agent_path) + process_agent(agent_path, env) else: click.echo( click.style( diff --git a/eve/cli/start_cli.py b/eve/cli/start_cli.py index 6744b4e..cc54811 100644 --- a/eve/cli/start_cli.py +++ b/eve/cli/start_cli.py @@ -5,6 +5,7 @@ import multiprocessing from pathlib import Path +from .. import load_env from ..models import ClientType from ..clients.discord.client import start as start_discord from ..clients.telegram.client import start as start_telegram @@ -19,11 +20,11 @@ default="STAGE", help="DB to save against", ) -@click.option( - "--env", - type=click.Path(exists=True, resolve_path=True), - help="Path to environment file", -) +# @click.option( +# "--env", +# type=click.Path(exists=True, resolve_path=True), +# help="Path to environment file", +# ) @click.option( "--platforms", type=click.Choice( @@ -43,16 +44,15 @@ default=False, help="Run locally", ) -def start(agent: str, db: str, env: str, platforms: tuple, local: bool): +def start(agent: str, db: str, platforms: tuple, local: bool): """Start one or more clients from yaml files""" try: - agent_dir = Path(__file__).parent.parent / "agents" / agent + load_env(db) + + agent_dir = Path(__file__).parent.parent / "agents" / db.lower() / agent env_path = agent_dir / ".env" yaml_path = agent_dir / "api.yaml" - db = db.upper() - env_path = env or env_path - clients_to_start = {} if platforms: @@ -82,15 +82,15 @@ def start(agent: str, db: str, env: str, platforms: tuple, local: bool): try: if client_type == ClientType.DISCORD: p = multiprocessing.Process( - target=start_discord, args=(env_path, db, local) + target=start_discord, args=(env_path, local) ) elif client_type == ClientType.TELEGRAM: p = multiprocessing.Process( - target=start_telegram, args=(env_path, db) + target=start_telegram, args=(env_path, local) ) elif client_type == ClientType.FARCASTER: p = multiprocessing.Process( - target=start_farcaster, args=(env_path, db) + target=start_farcaster, args=(env_path, local) ) p.start() @@ -141,11 +141,20 @@ def start(agent: str, db: str, env: str, platforms: tuple, local: bool): default=False, help="Enable auto-reload on code changes", ) -def api(host: str, port: int, reload: bool): +@click.option( + "--db", + type=click.Choice(["STAGE", "PROD"], case_sensitive=False), + default="STAGE", + help="DB to save against", +) +def api(host: str, port: int, reload: bool, db: str): """Start the Eve API server""" import uvicorn + import os - click.echo(click.style(f"Starting API server on {host}:{port}...", 
fg="blue")) + load_env(db) + + click.echo(click.style(f"Starting API server on {host}:{port} with DB={db}...", fg="blue")) # Adjusted the import path to look one directory up uvicorn.run( diff --git a/eve/cli/tool_cli.py b/eve/cli/tool_cli.py index e80e030..9ff8aef 100644 --- a/eve/cli/tool_cli.py +++ b/eve/cli/tool_cli.py @@ -3,6 +3,7 @@ import asyncio import traceback +from .. import load_env from ..eden_utils import save_test_results, prepare_result, dump_json, CLICK_COLORS from ..auth import get_my_eden_user from ..tool import Tool, get_tools_from_mongo, get_tools_from_api_files, get_api_files @@ -71,7 +72,9 @@ def tool(): @click.argument("names", nargs=-1, required=False) def update(db: str, names: tuple): """Upload tools to mongo""" - db = db.upper() + + load_env(db) + api_files = get_api_files(include_inactive=True) tools_order = {t: index for index, t in enumerate(api_tools_order)} @@ -88,8 +91,8 @@ def update(db: str, names: tuple): for key, api_file in api_files.items(): try: order = tools_order.get(key, len(api_tools_order)) - tool2 = Tool.from_yaml(api_file) - tool2.save(db=db, order=order) + tool_ = Tool.from_yaml(api_file) + tool_.save(order=order) click.echo( click.style(f"Updated tool {db}:{key} (order={order})", fg="green") ) @@ -122,8 +125,9 @@ def update(db: str, names: tuple): def run(ctx, tool: str, db: str): """Create with a tool. Args are passed as --key=value or --key value""" - db = db.upper() - tool = Tool.load(key=tool, db=db) + load_env(db) + + tool = Tool.load(key=tool) # Parse args args = dict() @@ -143,7 +147,7 @@ def run(ctx, tool: str, db: str): args[key] = True i += 1 - result = tool.run(args, db=db) + result = tool.run(args) color = random.choice(CLICK_COLORS) if result.get("error"): click.echo( @@ -154,7 +158,7 @@ def run(ctx, tool: str, db: str): ) ) else: - result = prepare_result(result, db=db) + result = prepare_result(result) click.echo( click.style(f"\nResult for {tool.key}: {dump_json(result)}", fg=color) ) @@ -185,27 +189,33 @@ def run(ctx, tool: str, db: str): @click.option("--mock", is_flag=True, default=False, help="Mock test results") @click.argument("tools", nargs=-1, required=False) def test( - tools: tuple, yaml: bool, db: str, api: bool, parallel: bool, save: bool, mock: bool + tools: tuple, + yaml: bool, + db: str, + api: bool, + parallel: bool, + save: bool, + mock: bool ): """Test multiple tools with their test args""" - db = db.upper() + load_env(db) - async def async_test_tool(tool, api, db): + async def async_test_tool(tool, api): color = random.choice(CLICK_COLORS) click.echo(click.style(f"\n\nTesting {tool.key}:", fg=color, bold=True)) click.echo(click.style(f"Args: {dump_json(tool.test_args)}", fg=color)) if api: - user = get_my_eden_user(db=db) + user = get_my_eden_user() # decorate this task = await tool.async_start_task( - user.id, user.id, tool.test_args, db=db, mock=mock + user.id, user.id, tool.test_args, mock=mock ) result = await tool.async_wait(task) else: - result = await tool.async_run(tool.test_args, db=db, mock=mock) + result = await tool.async_run(tool.test_args, mock=mock) if isinstance(result, dict) and result.get("error"): click.echo( @@ -216,15 +226,15 @@ async def async_test_tool(tool, api, db): ) ) else: - result = prepare_result(result, db=db) + result = prepare_result(result) click.echo( click.style(f"\nResult for {tool.key}: {dump_json(result)}", fg=color) ) return result - async def async_run_tests(tools, api, db, parallel): - tasks = [async_test_tool(tool, api, db) for tool in tools.values()] + async def 
async_run_tests(tools, api, parallel): + tasks = [async_test_tool(tool, api) for tool in tools.values()] if parallel: results = await asyncio.gather(*tasks) else: @@ -234,7 +244,7 @@ async def async_run_tests(tools, api, db, parallel): if yaml: all_tools = get_tools_from_api_files(tools=tools) else: - all_tools = get_tools_from_mongo(db=db, tools=tools) + all_tools = get_tools_from_mongo(tools=tools) if not tools: confirm = click.confirm( @@ -250,7 +260,7 @@ async def async_run_tests(tools, api, db, parallel): if not confirm: all_tools.pop("flux_trainer") - results = asyncio.run(async_run_tests(all_tools, api, db, parallel)) + results = asyncio.run(async_run_tests(all_tools, api, parallel)) if save and results: save_test_results(all_tools, results) diff --git a/eve/cli/upload_cli.py b/eve/cli/upload_cli.py index 7ccd66e..b8d3001 100644 --- a/eve/cli/upload_cli.py +++ b/eve/cli/upload_cli.py @@ -1,5 +1,6 @@ import click from ..s3 import upload_file +from .. import load_env @click.command() @click.option( @@ -11,11 +12,12 @@ @click.argument("files", nargs=-1, required=False) def upload(db: str, files: tuple): """Upload agents to mongo""" - db = db.upper() + + load_env(db) for file in files: try: - result = upload_file(file, db=db) + result = upload_file(file) url = result[0] click.echo( click.style( diff --git a/eve/clients/discord/client.py b/eve/clients/discord/client.py index e3795df..88c0583 100644 --- a/eve/clients/discord/client.py +++ b/eve/clients/discord/client.py @@ -8,12 +8,13 @@ from dotenv import load_dotenv from ably import AblyRealtime -from eve.clients import common -from eve.agent import Agent -from eve.llm import UpdateType -from eve.user import User -from eve.eden_utils import prepare_result -from eve.models import ClientType +from ... 
import load_env +from ...clients import common +from ...agent import Agent +from ...llm import UpdateType +from ...user import User +from ...eden_utils import prepare_result +from ...models import ClientType def replace_mentions_with_usernames( @@ -41,13 +42,11 @@ def __init__( self, bot: commands.bot, agent: Agent, - db: str = "STAGE", local: bool = False, ) -> None: self.bot = bot self.agent = agent - self.db = db - self.tools = agent.get_tools(db=self.db) + self.tools = agent.get_tools() self.known_users = {} self.known_threads = {} if local: @@ -128,7 +127,7 @@ async def async_callback(message): elif update_type == UpdateType.TOOL_COMPLETE: result = data.get("result", {}) - result["result"] = prepare_result(result["result"], db=self.db) + result["result"] = prepare_result(result["result"]) url = result["result"][0]["output"][0]["url"] await self.send_message(channel, url, reference=reference) @@ -168,15 +167,15 @@ async def on_message(self, message: discord.Message) -> None: # Lookup thread if thread_key not in self.known_threads: self.known_threads[thread_key] = self.agent.request_thread( - key=thread_key, - db=self.db, + key=thread_key ) thread = self.known_threads[thread_key] # Lookup user if message.author.id not in self.known_users: self.known_users[message.author.id] = User.from_discord( - message.author.id, message.author.name, db=self.db + message.author.id, + message.author.name ) user = self.known_users[message.author.id] @@ -310,18 +309,17 @@ async def on_message(self, message: discord.Message) -> None: def start( env: str, - db: str = "STAGE", local: bool = False, ) -> None: load_dotenv(env) agent_name = os.getenv("EDEN_AGENT_USERNAME") - agent = Agent.load(agent_name, db=db) + agent = Agent.load(agent_name) print(f"Launching Discord bot {agent.username}...") bot_token = os.getenv("CLIENT_DISCORD_TOKEN") bot = DiscordBot() - bot.add_cog(Eden2Cog(bot, agent, db=db, local=local)) + bot.add_cog(Eden2Cog(bot, agent, local=local)) bot.run(bot_token) @@ -334,4 +332,6 @@ def start( ) parser.add_argument("--local", help="Run locally", action="store_true") args = parser.parse_args() - start(args.env, args.agent, args.db, args.local) + + load_env(args.db) + start(args.env, args.agent, args.local) diff --git a/eve/eden_utils.py b/eve/eden_utils.py index 19a719f..f99c19b 100644 --- a/eve/eden_utils.py +++ b/eve/eden_utils.py @@ -26,7 +26,7 @@ from . 
import s3 -def prepare_result(result, db: str, summarize=False): +def prepare_result(result, summarize=False): if isinstance(result, dict): if "error" in result: return result @@ -34,31 +34,31 @@ def prepare_result(result, db: str, summarize=False): result["mediaAttributes"].pop("blurhash", None) if "filename" in result: filename = result.pop("filename") - url = s3.get_full_url(filename, db) + url = s3.get_full_url(filename) if summarize: return url else: result["url"] = url - return {k: prepare_result(v, db, summarize) for k, v in result.items()} + return {k: prepare_result(v, summarize) for k, v in result.items()} elif isinstance(result, list): - return [prepare_result(item, db, summarize) for item in result] + return [prepare_result(item, summarize) for item in result] else: return result -def upload_result(result, db: str, save_thumbnails=False, save_blurhash=False): +def upload_result(result, save_thumbnails=False, save_blurhash=False): if isinstance(result, dict): - return {k: upload_result(v, db, save_thumbnails=save_thumbnails, save_blurhash=save_blurhash) for k, v in result.items()} + return {k: upload_result(v, save_thumbnails=save_thumbnails, save_blurhash=save_blurhash) for k, v in result.items()} elif isinstance(result, list): - return [upload_result(item, db, save_thumbnails=save_thumbnails, save_blurhash=save_blurhash) for item in result] + return [upload_result(item, save_thumbnails=save_thumbnails, save_blurhash=save_blurhash) for item in result] elif isinstance(result, str) and is_file(result): - return upload_media(result, db, save_thumbnails=save_thumbnails, save_blurhash=save_blurhash) + return upload_media(result, save_thumbnails=save_thumbnails, save_blurhash=save_blurhash) else: return result -def upload_media(output, db, save_thumbnails=True, save_blurhash=True): - file_url, sha = s3.upload_file(output, db=db) +def upload_media(output, save_thumbnails=True, save_blurhash=True): + file_url, sha = s3.upload_file(output) filename = file_url.split("/")[-1] media_attributes, thumbnail = get_media_attributes(output) @@ -70,8 +70,8 @@ def upload_media(output, db, save_thumbnails=True, save_blurhash=True): (width, 2560), Image.Resampling.LANCZOS ) if width < thumbnail.width else thumbnail img_bytes = PIL_to_bytes(img) - s3.upload_buffer(img_bytes, name=f"{sha}_{width}", file_type=".webp", db=db) - s3.upload_buffer(img_bytes, name=f"{sha}_{width}", file_type=".jpg", db=db) + s3.upload_buffer(img_bytes, name=f"{sha}_{width}", file_type=".webp") + s3.upload_buffer(img_bytes, name=f"{sha}_{width}", file_type=".jpg") if save_blurhash and thumbnail: try: @@ -81,7 +81,7 @@ def upload_media(output, db, save_thumbnails=True, save_blurhash=True): except Exception as e: print(f"Error encoding blurhash: {e}") - return {"filename": filename, "mediaAttributes": media_attributes, "file_url": file_url} + return {"filename": filename, "mediaAttributes": media_attributes} def get_media_attributes(file_path): @@ -193,7 +193,7 @@ def mock_image(args): draw.text((5, 5), wrapped_text, fill="black", font=font) image = image.resize((512, 512), Image.LANCZOS) buffer = PIL_to_bytes(image) - url, _ = s3.upload_buffer(buffer, db="STAGE") + url, _ = s3.upload_buffer(buffer) return url diff --git a/eve/llm.py b/eve/llm.py index 01d0b52..c608f23 100644 --- a/eve/llm.py +++ b/eve/llm.py @@ -41,7 +41,6 @@ async def async_anthropic_prompt( model: str, response_model: Optional[type[BaseModel]], tools: Dict[str, Tool], - db: str, ): anthropic_client = anthropic.AsyncAnthropic() prompt = { @@ -68,7 +67,7 @@ 
async def async_anthropic_prompt( [r.text for r in response.content if r.type == "text" and r.text] ) tool_calls = [ - ToolCall.from_anthropic(r, db=db) + ToolCall.from_anthropic(r) for r in response.content if r.type == "tool_use" ] @@ -82,7 +81,6 @@ async def async_anthropic_prompt_stream( model: str, response_model: Optional[type[BaseModel]], tools: Dict[str, Tool], - db: str, ) -> AsyncGenerator[Tuple[UpdateType, str], None]: """Yields partial tokens (ASSISTANT_TOKEN, partial_text) for streaming.""" anthropic_client = anthropic.AsyncAnthropic() @@ -119,7 +117,7 @@ async def async_anthropic_prompt_stream( elif chunk.type == "content_block_stop" and hasattr(chunk, "content_block"): if chunk.content_block.type == "tool_use": tool_calls.append( - ToolCall.from_anthropic(chunk.content_block, db=db) + ToolCall.from_anthropic(chunk.content_block) ) # Stop reason @@ -138,7 +136,6 @@ async def async_openai_prompt( model: str = "gpt-4o-mini", # "gpt-4o-2024-08-06", response_model: Optional[type[BaseModel]] = None, tools: Dict[str, Tool] = {}, - db: str = "STAGE", ): if not os.getenv("OPENAI_API_KEY"): raise ValueError("OPENAI_API_KEY env is not set") @@ -167,7 +164,7 @@ async def async_openai_prompt( response = response.choices[0] content = response.message.content or "" tool_calls = [ - ToolCall.from_openai(t, db=db) for t in response.message.tool_calls or [] + ToolCall.from_openai(t) for t in response.message.tool_calls or [] ] stop = response.finish_reason == "stop" @@ -204,7 +201,6 @@ async def async_prompt( model: str, response_model: Optional[type[BaseModel]] = None, tools: Dict[str, Tool] = {}, - db: str = "STAGE", ) -> Tuple[str, List[ToolCall], bool]: """ Non-streaming LLM call => returns (content, tool_calls, stop). @@ -212,12 +208,12 @@ async def async_prompt( if model.startswith("claude"): # Use the non-stream Anthropics helper return await async_anthropic_prompt( - messages, system_message, model, response_model, tools, db + messages, system_message, model, response_model, tools ) else: # Use existing OpenAI path return await async_openai_prompt( - messages, system_message, model, response_model, tools, db + messages, system_message, model, response_model, tools ) @@ -251,7 +247,6 @@ async def async_prompt_stream( model: str, response_model: Optional[type[BaseModel]] = None, tools: Dict[str, Tool] = {}, - db: str = "STAGE", ) -> AsyncGenerator[Tuple[UpdateType, str], None]: """ Streaming LLM call => yields (UpdateType.ASSISTANT_TOKEN, partial_text). 
@@ -260,7 +255,7 @@ async def async_prompt_stream( if model.startswith("claude"): # Stream from Anthropics async for chunk in async_anthropic_prompt_stream( - messages, system_message, model, response_model, tools, db + messages, system_message, model, response_model, tools ): yield chunk else: @@ -330,7 +325,6 @@ async def async_think(): async def async_prompt_thread( - db: str, user: User, agent: Agent, thread: Thread, @@ -403,7 +397,6 @@ async def async_prompt_thread( system_message=system_message, model=model, tools=tools, - db=db, ): # stream an individual token if update_type == UpdateType.ASSISTANT_TOKEN: @@ -431,7 +424,6 @@ async def async_prompt_thread( system_message=system_message, model=model, tools=tools, - db=db, ) # for error tracing @@ -493,7 +485,7 @@ async def async_prompt_thread( # start task task = await tool.async_start_task( - user.id, agent.id, tool_call.args, db=db + user.id, agent.id, tool_call.args ) # update tool call with task id and status @@ -542,14 +534,12 @@ async def async_prompt_thread( ) if stop: - print("Stopping prompt thread") break yield ThreadUpdate(type=UpdateType.END_PROMPT) def prompt_thread( - db: str, user: User, agent: Agent, thread: Thread, @@ -559,7 +549,7 @@ def prompt_thread( model: Literal[tuple(models)] = "claude-3-5-sonnet-20241022", ): async_gen = async_prompt_thread( - db, user, agent, thread, user_messages, tools, force_reply, model + user, agent, thread, user_messages, tools, force_reply, model ) loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) diff --git a/eve/models.py b/eve/models.py index 0a99c10..c783921 100644 --- a/eve/models.py +++ b/eve/models.py @@ -19,6 +19,8 @@ class Model(Document): args: Dict[str, Any] checkpoint: str base_model: str + lora_trigger_text: Optional[str] = None + # users: SkipJsonSchema[Optional[Collection]] = Field(None, exclude=True) # def __init__(self, env, **data): @@ -61,7 +63,7 @@ def _make_slug(self): if doc.get("slug") ] new_version = max(versions or [0]) + 1 - users = get_collection("users3", db=self.db) + users = get_collection("users3") username = users.find_one({"_id": self.user})["username"] # username = self.users.find_one({"_id": self.user})["username"] self.slug = f"{username}/{name}/v{new_version}" diff --git a/eve/mongo.py b/eve/mongo.py index 7c4a514..0c15245 100644 --- a/eve/mongo.py +++ b/eve/mongo.py @@ -9,25 +9,15 @@ from typing import Optional, List, Dict, Any, Union -MONGO_URI = os.getenv("MONGO_URI") -MONGO_DB_NAME = os.getenv("MONGO_DB_NAME") - -if not all([MONGO_URI, MONGO_DB_NAME]): - print("WARNING: MONGO_URI and MONGO_DB_NAME must be set in the environment") - # Global connection pool _mongo_client = None _collections = {} -def get_collection(collection_name: str, db: str): - """Get a MongoDB collection with connection pooling""" +def get_mongo_client(): + """Get a MongoDB client with connection pooling""" global _mongo_client - - cache_key = f"{db}:{collection_name}" - if cache_key in _collections: - return _collections[cache_key] - + MONGO_URI = os.getenv("MONGO_URI") if _mongo_client is None: _mongo_client = MongoClient( MONGO_URI, @@ -39,8 +29,19 @@ def get_collection(collection_name: str, db: str): retryWrites=True, server_api=ServerApi("1"), ) + return _mongo_client - _collections[cache_key] = _mongo_client[MONGO_DB_NAME][collection_name] + +def get_collection(collection_name: str): + """Get a MongoDB collection with connection pooling""" + db = os.getenv("DB") + cache_key = f"{db}:{collection_name}" + if cache_key in _collections: + return 
_collections[cache_key] + + MONGO_DB_NAME = os.getenv("MONGO_DB_NAME") + mongo_client = get_mongo_client() + _collections[cache_key] = mongo_client[MONGO_DB_NAME][collection_name] return _collections[cache_key] @@ -58,7 +59,6 @@ class Document(BaseModel): default_factory=lambda: datetime.now(timezone.utc) ) updatedAt: Optional[datetime] = None - db: Optional[str] = None model_config = ConfigDict( json_encoders={ @@ -70,24 +70,22 @@ class Document(BaseModel): ) @classmethod - def get_collection(cls, db=None): + def get_collection(cls): """ Override this method to provide the correct collection for the model. """ - db = db or cls.db or "STAGE" collection_name = getattr(cls, "collection_name", cls.__name__.lower()) - return get_collection(collection_name, db) + return get_collection(collection_name) @classmethod - def from_schema(cls, schema: dict, db="STAGE", from_yaml=True): + def from_schema(cls, schema: dict, from_yaml=True): """Load a document from a schema.""" - schema["db"] = db - sub_cls = cls.get_sub_class(schema, from_yaml=from_yaml, db=db) + sub_cls = cls.get_sub_class(schema, from_yaml=from_yaml) result = sub_cls.model_validate(schema) return result @classmethod - def from_yaml(cls, file_path: str, db="STAGE"): + def from_yaml(cls, file_path: str): """ Load a document from a YAML file. """ @@ -95,45 +93,46 @@ def from_yaml(cls, file_path: str, db="STAGE"): raise FileNotFoundError(f"File {file_path} not found") with open(file_path, "r") as file: schema = yaml.safe_load(file) - sub_cls = cls.get_sub_class(schema, from_yaml=True, db=db) + sub_cls = cls.get_sub_class(schema, from_yaml=True) schema = sub_cls.convert_from_yaml(schema, file_path=file_path) - return cls.from_schema(schema, db=db, from_yaml=True) + return cls.from_schema(schema, from_yaml=True) @classmethod - def from_mongo(cls, document_id: ObjectId, db="STAGE"): + def from_mongo(cls, document_id: ObjectId): """ Load the document from the database and return an instance of the model. """ document_id = ( document_id if isinstance(document_id, ObjectId) else ObjectId(document_id) ) - schema = cls.get_collection(db).find_one({"_id": document_id}) + schema = cls.get_collection().find_one({"_id": document_id}) if not schema: + db = os.getenv("DB") raise ValueError( f"Document {document_id} not found in {cls.collection_name}:{db}" ) - sub_cls = cls.get_sub_class(schema, from_yaml=False, db=db) - schema = sub_cls.convert_from_mongo(schema, db=db) - return cls.from_schema(schema, db, from_yaml=False) + sub_cls = cls.get_sub_class(schema, from_yaml=False) + schema = sub_cls.convert_from_mongo(schema) + return cls.from_schema(schema, from_yaml=False) @classmethod - def load(cls, db="STAGE", **kwargs): + def load(cls, **kwargs): """ Load the document from the database and return an instance of the model. 
""" - schema = cls.get_collection(db).find_one(kwargs) + schema = cls.get_collection().find_one(kwargs) if not schema: - raise MongoDocumentNotFound(cls.collection_name, db, **kwargs) - sub_cls = cls.get_sub_class(schema, from_yaml=False, db=db) - schema = sub_cls.convert_from_mongo(schema, db=db) - return cls.from_schema(schema, db, from_yaml=False) + raise MongoDocumentNotFound(cls.collection_name, **kwargs) + sub_cls = cls.get_sub_class(schema, from_yaml=False) + schema = sub_cls.convert_from_mongo(schema) + return cls.from_schema(schema, from_yaml=False) @classmethod - def get_sub_class(cls, schema: dict = None, db="STAGE", from_yaml=True) -> type: + def get_sub_class(cls, schema: dict = None, from_yaml=True) -> type: return cls @classmethod - def convert_from_mongo(cls, schema: dict, db="STAGE", **kwargs) -> dict: + def convert_from_mongo(cls, schema: dict, **kwargs) -> dict: return schema @classmethod @@ -148,20 +147,18 @@ def convert_to_mongo(cls, schema: dict, **kwargs) -> dict: def convert_to_yaml(cls, schema: dict, **kwargs) -> dict: return schema - def save(self, db=None, upsert_filter=None, **kwargs): + def save(self, upsert_filter=None, **kwargs): """ Save the current state of the model to the database. """ - db = db or self.db or "STAGE" - - schema = self.model_dump(by_alias=True, exclude={"db"}) + schema = self.model_dump(by_alias=True) self.model_validate(schema) schema = self.convert_to_mongo(schema) schema.update(kwargs) self.updatedAt = datetime.now(timezone.utc) filter = upsert_filter or {"_id": self.id or ObjectId()} - collection = self.get_collection(db) + collection = self.get_collection() if self.id or filter: if filter: schema.pop("_id", None) @@ -177,15 +174,13 @@ def save(self, db=None, upsert_filter=None, **kwargs): self.createdAt = datetime.now(timezone.utc) result = collection.insert_one(schema) self.id = schema["_id"] - self.db = db @classmethod - def save_many(cls, documents: List[BaseModel], db=None): - db = db or cls.db or "STAGE" - collection = cls.get_collection(db) + def save_many(cls, documents: List[BaseModel]): + collection = cls.get_collection() for d in range(len(documents)): documents[d].id = documents[d].id or ObjectId() - documents[d] = documents[d].model_dump(by_alias=True, exclude={"db"}) + documents[d] = documents[d].model_dump(by_alias=True) cls.model_validate(documents[d]) documents[d] = cls.convert_to_mongo(documents[d]) documents[d]["createdAt"] = documents[d].get( @@ -200,7 +195,7 @@ def update(self, **kwargs): """ Perform granular updates on specific fields. """ - collection = self.get_collection(self.db) + collection = self.get_collection() update_result = collection.update_one( {"_id": self.id}, {"$set": kwargs, "$currentDate": {"updatedAt": True}} ) @@ -212,7 +207,7 @@ def set_against_filter(self, updates: Dict = None, filter: Optional[Dict] = None """ Perform granular updates on specific fields, given an optional filter. 
""" - collection = self.get_collection(self.db) + collection = self.get_collection() update_result = collection.update_one( {"_id": self.id, **filter}, {"$set": updates, "$currentDate": {"updatedAt": True}}, @@ -264,7 +259,7 @@ def push( setattr(self, field_name, [x for x in current_list if x != value]) # Update MongoDB operation to use $pull instead of $pop - collection = self.get_collection(self.db) + collection = self.get_collection() update_ops = {"$currentDate": {"updatedAt": True}} if push_ops: update_ops["$push"] = push_ops @@ -296,7 +291,7 @@ def update_nested_field(self, field_name: str, index: int, sub_field: str, value raise ValidationError(f"Field '{field_name}' is not a valid list field.") # Perform the update operation in MongoDB - collection = self.get_collection(self.db) + collection = self.get_collection() update_result = collection.update_one( {"_id": self.id}, { @@ -318,7 +313,7 @@ def reload(self): """ Reload the current document from the database to ensure the instance is up-to-date. """ - updated_instance = self.from_mongo(self.id, self.db) + updated_instance = self.from_mongo(self.id) if updated_instance: # Use model_dump to get the data while maintaining type information for key, value in updated_instance.model_dump().items(): @@ -328,7 +323,7 @@ def delete(self): """ Delete the document from the database. """ - collection = self.get_collection(self.db) + collection = self.get_collection() collection.delete_one({"_id": self.id}) @@ -348,8 +343,11 @@ class MongoDocumentNotFound(Exception): """Exception raised when a document is not found in MongoDB.""" def __init__( - self, collection_name: str, db: str, document_id: str = None, **kwargs + self, collection_name: str, + document_id: str = None, + **kwargs ): + db = os.getenv("DB") if document_id: self.message = f"Document with id {document_id} not found in collection {collection_name}, db: {db}" else: diff --git a/eve/s3.py b/eve/s3.py index 8e9cb59..ebb9c26 100644 --- a/eve/s3.py +++ b/eve/s3.py @@ -11,24 +11,12 @@ from typing import Iterator from PIL import Image -AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID") -AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY") -AWS_REGION_NAME = os.getenv("AWS_REGION_NAME") -AWS_BUCKET_NAME = os.getenv("AWS_BUCKET_NAME") - -if not all( - [AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME, AWS_BUCKET_NAME] -): - # raise ValueError("AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME, AWS_BUCKET_NAME must be set in the environment") - print( - "WARNING: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME, and AWS_BUCKET_NAME must be set in the environment" - ) s3 = boto3.client( "s3", - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY, - region_name=AWS_REGION_NAME, + aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"), + region_name=os.getenv("AWS_REGION_NAME"), ) @@ -46,21 +34,24 @@ } -def get_root_url(db): +def get_root_url(): """Returns the root URL for the specified bucket.""" + AWS_BUCKET_NAME = os.getenv("AWS_BUCKET_NAME") + AWS_REGION_NAME = os.getenv("AWS_REGION_NAME") url = f"https://{AWS_BUCKET_NAME}.s3.{AWS_REGION_NAME}.amazonaws.com" return url -def get_full_url(filename, db): - return f"{get_root_url(db=db)}/{filename}" +def get_full_url(filename): + return f"{get_root_url()}/{filename}" -def upload_file_from_url(url, name=None, file_type=None, db="STAGE"): +def upload_file_from_url(url, name=None, file_type=None): """Uploads a file to an S3 bucket 
by downloading it to a temporary file and uploading it to S3.""" + AWS_BUCKET_NAME = os.getenv("AWS_BUCKET_NAME") if f"{AWS_BUCKET_NAME}.s3." in url and ".amazonaws.com" in url: - # print(f"File is already uploaded at {url}") + # file is already uploaded filename = url.split("/")[-1].split(".")[0] return url, filename @@ -71,25 +62,25 @@ def upload_file_from_url(url, name=None, file_type=None, db="STAGE"): tmp_file.write(chunk) tmp_file.flush() tmp_file.seek(0) - return upload_file(tmp_file.name, name, file_type, db) + return upload_file(tmp_file.name, name, file_type) -def upload_file(file_path, name=None, file_type=None, db="STAGE"): +def upload_file(file_path, name=None, file_type=None): """Uploads a file to an S3 bucket and returns the file URL.""" if file_path.endswith(".safetensors"): file_type = ".safetensors" if file_path.startswith("http://") or file_path.startswith("https://"): - return upload_file_from_url(file_path, name, file_type, db) + return upload_file_from_url(file_path, name, file_type) with open(file_path, "rb") as file: buffer = file.read() - return upload_buffer(buffer, name, file_type, db) + return upload_buffer(buffer, name, file_type) -def upload_buffer(buffer, name=None, file_type=None, db="STAGE"): +def upload_buffer(buffer, name=None, file_type=None): """Uploads a buffer to an S3 bucket and returns the file URL.""" assert ( @@ -148,7 +139,7 @@ def upload_buffer(buffer, name=None, file_type=None, db="STAGE"): # Upload file to S3 filename = f"{name}{file_type}" file_bytes = io.BytesIO(buffer) - bucket_name = AWS_BUCKET_NAME + bucket_name = os.getenv("AWS_BUCKET_NAME") file_url = f"https://{bucket_name}.s3.amazonaws.com/{filename}" # if file doesn't exist, upload it @@ -169,29 +160,29 @@ def upload_buffer(buffer, name=None, file_type=None, db="STAGE"): return file_url, name -def upload_PIL_image(image: Image.Image, name=None, file_type=None, db="STAGE"): +def upload_PIL_image(image: Image.Image, name=None, file_type=None): format = file_type.split(".")[-1] or "webp" buffer = io.BytesIO() image.save(buffer, format=format) - return upload_buffer(buffer, name, file_type, db) + return upload_buffer(buffer, name, file_type) -def upload_audio_segment(audio: AudioSegment, db="STAGE"): +def upload_audio_segment(audio: AudioSegment): buffer = io.BytesIO() audio.export(buffer, format="mp3") - output = upload_buffer(buffer, db=db) + output = upload_buffer(buffer) return output -def upload(data: any, name=None, file_type=None, db="STAGE"): +def upload(data: any, name=None, file_type=None): if isinstance(data, Image.Image): - return upload_PIL_image(data, name, file_type, db) + return upload_PIL_image(data, name, file_type) elif isinstance(data, AudioSegment): - return upload_audio_segment(data, db) + return upload_audio_segment(data) elif isinstance(data, bytes): - return upload_buffer(data, name, file_type, db) + return upload_buffer(data, name, file_type) else: - return upload_file(data, name, file_type, db) + return upload_file(data, name, file_type) def copy_file_to_bucket(source_bucket, dest_bucket, source_key, dest_key=None): diff --git a/eve/soc.py b/eve/soc.py index 2dea347..53d1db7 100644 --- a/eve/soc.py +++ b/eve/soc.py @@ -10,13 +10,13 @@ from eve.auth import get_my_eden_user db = "STAGE" -user = get_my_eden_user(db=db) +user = get_my_eden_user() thread_name_think = "test_soc23_think" thread_name_act = "test_soc23_act" tools = get_tools_from_mongo(db=db) -thread_think = Thread.from_name(name=thread_name_think, user=user.id, db=db) -thread_act = 
Thread.from_name(name=thread_name_act, user=user.id, db=db) +thread_think = Thread.from_name(name=thread_name_think, user=user.id) +thread_act = Thread.from_name(name=thread_name_act, user=user.id) thread_act.push("messages", UserMessage(content="I am your inner voice, guiding you as you attempt to creare digital art for me to critique. As you create art, I will give you feedback and instruct you on ways to improve your work. Please refrain from restating what I tell you or being verbose. Just make the art, tell me concisely what you are doing, and show me the results.")) @@ -47,7 +47,7 @@ def convert_assistant_messages(messages: List[AssistantMessage]): -# def prompt_thread2(db, user_id, thread_name, user_message, tools): +# def prompt_thread2(user_id, thread_name, user_message, tools): # async def run(): # return all_messages @@ -62,7 +62,7 @@ async def main(): while True: think_messages = [] - async for message in async_prompt_thread(db, user_id, thread_name_think, input_message, {}): + async for message in async_prompt_thread(user_id, thread_name_think, input_message, {}): print_message(message, name="Eve 1") think_messages.append(message) diff --git a/eve/task.py b/eve/task.py index 56c7366..c5b598a 100644 --- a/eve/task.py +++ b/eve/task.py @@ -5,7 +5,7 @@ import asyncio import traceback -from .user import User +from .user import User, Manna, Transaction from .mongo import Document, Collection from . import eden_utils from . import sentry_sdk @@ -59,12 +59,36 @@ def __init__(self, **data): super().__init__(**data) @classmethod - def from_handler_id(self, handler_id, db): - tasks = self.get_collection(db) + def from_handler_id(self, handler_id): + tasks = self.get_collection() task = tasks.find_one({"handler_id": handler_id}) if not task: raise Exception("Task not found") - return super().load(self, task["_id"], db) + return super().load(self, task["_id"]) + + def spend_manna(self): + if self.cost == 0: + return + manna = Manna.load(self.requester) + manna.spend(self.cost) + Transaction( + manna=manna.id, + task=self.id, + amount=self.cost, + type="spend", + ).save() + + def refund_manna(self): + n_samples = self.args.get("n_samples", 1) + refund_amount = (self.cost or 0) * (n_samples - len(self.result or [])) / n_samples + manna = Manna.load(self.requester) + manna.refund(refund_amount) + Transaction( + manna=manna.id, + task=self.id, + amount=refund_amount, + type="refund", + ).save() def task_handler_func(func): @@ -108,13 +132,13 @@ async def _task_handler(func, *args, **kwargs): task_args["seed"] = task_args["seed"] + i # Run both functions concurrently - main_task = func(*args[:-1], task.parent_tool or task.tool, task_args, task.db) + main_task = func(*args[:-1], task.parent_tool or task.tool, task_args) preprocess_task = _preprocess_task(task) result, preprocess_result = await asyncio.gather(main_task, preprocess_task) if output_type in ["image", "video", "audio", "lora"]: result["output"] = result["output"] if isinstance(result["output"], list) else [result["output"]] - result = eden_utils.upload_result(result, db=task.db, save_thumbnails=True, save_blurhash=True) + result = eden_utils.upload_result(result, save_thumbnails=True, save_blurhash=True) for output in result["output"]: name = preprocess_result.get("name") or task_args.get("prompt") or args.get("text_input") @@ -131,7 +155,7 @@ async def _task_handler(func, *args, **kwargs): mediaAttributes=output['mediaAttributes'], name=name ) - new_creation.save(db=task.db) + new_creation.save() output["creation"] = 
new_creation.id results.extend([result]) @@ -158,12 +182,8 @@ async def _task_handler(func, *args, **kwargs): task_update = { "status": "failed", "error": str(error), - } - - n_samples = task.args.get("n_samples", 1) - refund_amount = (task.cost or 0) * (n_samples - len(task.result or [])) / n_samples - user = User.from_mongo(task.user, db=task.db) - user.refund_manna(refund_amount) + } + task.refund_manna() return task_update.copy() diff --git a/eve/thread.py b/eve/thread.py index b0415f7..7e9b268 100644 --- a/eve/thread.py +++ b/eve/thread.py @@ -147,7 +147,6 @@ class ToolCall(BaseModel): tool: str args: Dict[str, Any] - db: SkipJsonSchema[str] task: Optional[ObjectId] = None status: Optional[ Literal["pending", "running", "completed", "failed", "cancelled"] @@ -162,7 +161,7 @@ def get_result(self, schema, truncate_images=False): result = {"status": self.status} if self.status == "completed": - result["result"] = prepare_result(self.result, db=self.db) + result["result"] = prepare_result(self.result) outputs = [ o.get("url") for r in result.get("result", []) @@ -247,18 +246,17 @@ def react(self, user: ObjectId, reaction: str): pass @staticmethod - def from_openai(tool_call, db): + def from_openai(tool_call): return ToolCall( id=tool_call.id, tool=tool_call.function.name, args=json.loads(tool_call.function.arguments), - db=db, ) @staticmethod - def from_anthropic(tool_call, db): + def from_anthropic(tool_call): return ToolCall( - id=tool_call.id, tool=tool_call.name, args=tool_call.input, db=db + id=tool_call.id, tool=tool_call.name, args=tool_call.input ) def openai_call_schema(self): @@ -282,7 +280,8 @@ def anthropic_result_schema(self, truncate_images=False): "type": "tool_result", "tool_use_id": self.id, "content": self.get_result( - schema="anthropic", truncate_images=truncate_images + schema="anthropic", + truncate_images=truncate_images ), } @@ -291,7 +290,8 @@ def openai_result_schema(self, truncate_images=False): "role": "tool", "name": self.tool, "content": self.get_result( - schema="openai", truncate_images=truncate_images + schema="openai", + truncate_images=truncate_images ), "tool_call_id": self.id, } @@ -361,20 +361,21 @@ class Thread(Document): active: List[ObjectId] = Field(default_factory=list) @classmethod - def load(cls, key, agent=None, user=None, create_if_missing=False, db="STAGE"): + def load(cls, key, agent=None, user=None, create_if_missing=False): filter = {"key": key} if agent: filter["agent"] = agent if user: filter["user"] = user - thread = cls.get_collection(db).find_one(filter) + thread = cls.get_collection().find_one(filter) if thread: - thread = Thread(db=db, **thread) + thread = Thread(**thread) else: if create_if_missing: - thread = cls(db=db, key=key, agent=agent, user=user) + thread = cls(key=key, agent=agent, user=user) thread.save() else: + db = os.getenv("DB") raise Exception(f"Thread {key} with agent {agent} not found in {cls.collection_name}:{db}") return thread diff --git a/eve/tool.py b/eve/tool.py index 189d1a6..cc6f38e 100644 --- a/eve/tool.py +++ b/eve/tool.py @@ -86,7 +86,7 @@ class Tool(Document, ABC): test_args: Optional[Dict[str, Any]] = None @classmethod - def _get_schema(cls, key, db, from_yaml=False) -> dict: + def _get_schema(cls, key, from_yaml=False) -> dict: """Get schema for a tool, with detailed performance logging.""" if from_yaml: @@ -104,14 +104,16 @@ def _get_schema(cls, key, db, from_yaml=False) -> dict: schema["workspace"] = schema.get("workspace") or api_file.split("/")[-4] else: # MongoDB path - collection = 
get_collection(cls.collection_name, db=db) + collection = get_collection(cls.collection_name) schema = collection.find_one({"key": key}) return schema @classmethod def get_sub_class( - cls, schema, db, from_yaml=False + cls, + schema, + from_yaml=False ) -> type: from .tools.local_tool import LocalTool from .tools.modal_tool import ModalTool @@ -121,7 +123,7 @@ def get_sub_class( parent_tool = schema.get("parent_tool") if parent_tool: - parent_schema = cls._get_schema(parent_tool, db, from_yaml) + parent_schema = cls._get_schema(parent_tool, from_yaml) handler = parent_schema.get("handler") else: handler = schema.get("handler") @@ -148,7 +150,7 @@ def convert_from_yaml(cls, schema: dict, file_path: str = None) -> dict: parent_tool = schema.get("parent_tool") if parent_tool: - parent_schema = cls._get_schema(parent_tool, db=None, from_yaml=True) + parent_schema = cls._get_schema(parent_tool, from_yaml=True) parent_schema["parameter_presets"] = schema.pop("parameters", {}) parent_parameters = parent_schema.pop("parameters", {}) for k, v in parent_schema["parameter_presets"].items(): @@ -178,7 +180,7 @@ def convert_from_yaml(cls, schema: dict, file_path: str = None) -> dict: return schema @classmethod - def convert_from_mongo(cls, schema, db) -> dict: + def convert_from_mongo(cls, schema) -> dict: schema["parameters"] = { p["name"]: {**(p.pop("schema")), **p} for p in schema["parameters"] } @@ -205,42 +207,41 @@ def convert_to_mongo(cls, schema: dict) -> dict: return schema - def save(self, db=None, **kwargs): - return super().save(db, {"key": self.key}, **kwargs) + def save(self, **kwargs): + return super().save({"key": self.key}, **kwargs) @classmethod - def from_raw_yaml(cls, schema, db, from_yaml=True): - schema["db"] = db + def from_raw_yaml(cls, schema, from_yaml=True): schema = cls.convert_from_yaml(schema) - sub_cls = cls.get_sub_class(schema, from_yaml=from_yaml, db=db) + sub_cls = cls.get_sub_class(schema, from_yaml=from_yaml) return sub_cls.model_validate(schema) @classmethod - def from_yaml(cls, file_path, db="STAGE", cache=False): + def from_yaml(cls, file_path, cache=False): if cache: if file_path not in _tool_cache: - _tool_cache[file_path] = super().from_yaml(file_path, db=db) + _tool_cache[file_path] = super().from_yaml(file_path) return _tool_cache[file_path] else: - return super().from_yaml(file_path, db=db) + return super().from_yaml(file_path) @classmethod - def from_mongo(cls, document_id, db="STAGE", cache=False): + def from_mongo(cls, document_id, cache=False): if cache: if document_id not in _tool_cache: - _tool_cache[str(document_id)] = super().from_mongo(document_id, db=db) + _tool_cache[str(document_id)] = super().from_mongo(document_id) return _tool_cache[str(document_id)] else: - return super().from_mongo(document_id, db=db) + return super().from_mongo(document_id) @classmethod - def load(cls, key, db="STAGE", cache=False): + def load(cls, key, cache=False): if cache: if key not in _tool_cache: - _tool_cache[key] = super().load(key=key, db=db) + _tool_cache[key] = super().load(key=key) return _tool_cache[key] else: - return super().load(key=key, db=db) + return super().load(key=key) def _remove_hidden_fields(self, parameters): hidden_parameters = [ @@ -277,9 +278,8 @@ def calculate_cost(self, args): cost_formula = re.sub( r"(\w+)\s*\?\s*([^:]+)\s*:\s*([^,\s]+)", r"\2 if \1 else \3", cost_formula ) # Ternary operator - cost_estimate = eval(cost_formula, args.copy()) - assert isinstance(cost_estimate, (int, float)), "Cost estimate not a number" + assert 
isinstance(cost_estimate, (int, float)), f"Cost estimate ({cost_estimate}) not a number (formula: {cost_formula})" return cost_estimate def prepare_args(self, args: dict): @@ -312,21 +312,21 @@ def prepare_args(self, args: dict): def handle_run(run_function): """Wrapper for calling a tool directly and waiting for the result""" - async def async_wrapper(self, args: Dict, db: str, mock: bool = False): + async def async_wrapper(self, args: Dict, mock: bool = False): try: args = self.prepare_args(args) sentry_sdk.add_breadcrumb(category="handle_run", data=args) if mock: result = {"output": eden_utils.mock_image(args)} else: - result = await run_function(self, args, db) + result = await run_function(self, args) result["output"] = ( result["output"] if isinstance(result["output"], list) else [result["output"]] ) sentry_sdk.add_breadcrumb(category="handle_run", data=result) - result = eden_utils.upload_result(result, db) + result = eden_utils.upload_result(result) sentry_sdk.add_breadcrumb(category="handle_run", data=result) result["status"] = "completed" except Exception as e: @@ -345,7 +345,6 @@ async def async_wrapper( requester_id: str, user_id: str, args: Dict, - db: str, mock: bool = False, ): try: @@ -353,10 +352,11 @@ async def async_wrapper( args = self.prepare_args(args) sentry_sdk.add_breadcrumb(category="handle_start_task", data=args) cost = self.calculate_cost(args) - user = User.from_mongo(user_id, db=db) + user = User.from_mongo(user_id) if "freeTools" in (user.featureFlags or []): cost = 0 - user.check_manna(cost) + requester = User.from_mongo(requester_id) + requester.check_manna(cost) except Exception as e: print(traceback.format_exc()) @@ -373,7 +373,7 @@ async def async_wrapper( mock=mock, cost=cost, ) - task.save(db=db) + task.save() sentry_sdk.add_breadcrumb(category="handle_start_task", data=task.model_dump()) # start task @@ -381,7 +381,7 @@ async def async_wrapper( if mock: handler_id = eden_utils.random_string() output = {"output": eden_utils.mock_image(args)} - result = eden_utils.upload_result(output, db=db) + result = eden_utils.upload_result(output) task.update( handler_id=handler_id, status="completed", @@ -396,8 +396,8 @@ async def async_wrapper( handler_id = await start_task_function(self, task) task.update(handler_id=handler_id) - user.spend_manna(task.cost) - + task.spend_manna() + except Exception as e: print(traceback.format_exc()) task.update(status="failed", error=str(e)) @@ -431,12 +431,7 @@ def handle_cancel(cancel_function): async def async_wrapper(self, task: Task): await cancel_function(self, task) - n_samples = task.args.get("n_samples", 1) - refund_amount = ( - (task.cost or 0) * (n_samples - len(task.result or [])) / n_samples - ) - user = User.from_mongo(task.user, db=task.db) - user.refund_manna(refund_amount) + task.refund_manna() task.update(status="cancelled") return async_wrapper @@ -457,13 +452,13 @@ async def async_wait(self): async def async_cancel(self): pass - def run(self, args: Dict, db: str, mock: bool = False): - return asyncio.run(self.async_run(args, db, mock)) + def run(self, args: Dict, mock: bool = False): + return asyncio.run(self.async_run(args, mock)) def start_task( - self, requester_id: str, user_id: str, args: Dict, db: str, mock: bool = False + self, requester_id: str, user_id: str, args: Dict, mock: bool = False ): - return asyncio.run(self.async_start_task(requester_id, user_id, args, db, mock)) + return asyncio.run(self.async_start_task(requester_id, user_id, args, mock)) def wait(self, task: Task): return 
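
For context on the calculate_cost logic touched above: a yaml cost_estimate expression may use a C-style ternary, which is rewritten into a Python conditional expression and then evaluated against the prepared args. The formula and argument values below are hypothetical; note that the hunk above appears to drop the line computing cost_estimate from the formula, so the sketch keeps that step explicitly.

    import re

    # Hypothetical cost_estimate string from a tool's api.yaml (not a real config).
    cost_formula = "0.5 * n_samples + (use_upscale ? 10 : 0)"
    args = {"n_samples": 4, "use_upscale": True}

    # Same ternary rewrite applied in Tool.calculate_cost before eval().
    cost_formula = re.sub(
        r"(\w+)\s*\?\s*([^:]+)\s*:\s*([^,\s]+)", r"\2 if \1 else \3", cost_formula
    )
    cost_estimate = eval(cost_formula, args.copy())
    assert isinstance(cost_estimate, (int, float))
    print(cost_estimate)  # 12.0
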
asyncio.run(self.async_wait(task)) @@ -491,14 +486,13 @@ def get_tools_from_api_files( def get_tools_from_mongo( - db: str, tools: List[str] = None, include_inactive: bool = False, cache: bool = False, ) -> Dict[str, Tool]: """Get all tools from mongo""" - tools_collection = get_collection(Tool.collection_name, db=db) + tools_collection = get_collection(Tool.collection_name) # Batch fetch all tools and their parents filter = {"key": {"$in": tools}} if tools else {} @@ -510,8 +504,8 @@ def get_tools_from_mongo( if tool.get("key") in _tool_cache: tool = _tool_cache[tool.get("key")] else: - tool = Tool.convert_from_mongo(tool, db=db) - tool = Tool.from_schema(tool, db=db, from_yaml=False) + tool = Tool.convert_from_mongo(tool) + tool = Tool.from_schema(tool, from_yaml=False) if cache: _tool_cache[tool.key] = tool if tool.status != "inactive" and not include_inactive: diff --git a/eve/tools/comfyui_tool.py b/eve/tools/comfyui_tool.py index c82ac1e..dc4a42f 100644 --- a/eve/tools/comfyui_tool.py +++ b/eve/tools/comfyui_tool.py @@ -1,4 +1,5 @@ import modal +import os from pydantic import BaseModel, Field from typing import List, Optional, Dict @@ -35,19 +36,22 @@ def convert_from_yaml(cls, schema: dict, file_path: str = None) -> dict: return super().convert_from_yaml(schema, file_path) @Tool.handle_run - async def async_run(self, args: Dict, db: str): + async def async_run(self, args: Dict): + db = os.getenv("DB") cls = modal.Cls.lookup( f"comfyui-{self.workspace}-{db}", "ComfyUI", environment_name="main" ) - result = await cls().run.remote.aio(self.parent_tool or self.key, args, db) + result = await cls().run.remote.aio(self.parent_tool or self.key, args) return result @Tool.handle_start_task async def async_start_task(self, task: Task): + db = os.getenv("DB") cls = modal.Cls.lookup( - f"comfyui-{self.workspace}-{task.db}", + # f"comfyui-{self.workspace}-{task.db}", + f"comfyui-{self.workspace}-{db}", "ComfyUI", environment_name="main" ) diff --git a/eve/tools/elevenlabs/handler.py b/eve/tools/elevenlabs/handler.py index f00c210..63ded08 100644 --- a/eve/tools/elevenlabs/handler.py +++ b/eve/tools/elevenlabs/handler.py @@ -11,7 +11,7 @@ eleven = ElevenLabs() -async def handler(args: dict, db: str): +async def handler(args: dict): # print("args", args) args["stability"] = args.get("stability", 0.5) args["similarity_boost"] = args.get("similarity_boost", 0.75) @@ -60,7 +60,7 @@ def clone_voice(name, description, voice_urls): with NamedTemporaryFile(delete=False) as file: file = eden_utils.download_file(url, file.name) voice_files.append(file) - voice = eleven.clone(name, description, voice_files) + voice = eleven.clone(name, voice_files, description) for file in voice_files: os.remove(file) return voice diff --git a/eve/tools/example_tool/handler.py b/eve/tools/example_tool/handler.py index c95c334..d175143 100644 --- a/eve/tools/example_tool/handler.py +++ b/eve/tools/example_tool/handler.py @@ -9,8 +9,8 @@ async def handler(args: dict, db: str): # you can call other tools #from ...tool import Tool - #txt2img = Tool.load(key="txt2img", db=db) - #result = await txt2img.run(args, db=db) + #txt2img = Tool.load(key="txt2img") + #result = await txt2img.run(args) result = { "output": image_path, diff --git a/eve/tools/flux_dev_lora/test.json b/eve/tools/flux_dev_lora/test.json index d209237..922b35a 100644 --- a/eve/tools/flux_dev_lora/test.json +++ b/eve/tools/flux_dev_lora/test.json @@ -1,5 +1,5 @@ { - "prompt": "hello
Banny, a yellow cartoon bananaman at the podium giving a speech, as president of the United States. In the background are military generals adorned with medals.", + "prompt": "hey Banny, a yellow cartoon bananaman at the podium giving a speech, as president of the United States. In the background are military generals adorned with medals.", "lora": "67611d1943808b38016c62c3", "lora_strength": 1.0, "aspect_ratio": "16:9" diff --git a/eve/tools/gcp_tool.py b/eve/tools/gcp_tool.py index 06368f5..ad994ba 100644 --- a/eve/tools/gcp_tool.py +++ b/eve/tools/gcp_tool.py @@ -14,7 +14,7 @@ class GCPTool(Tool): gpu: str @Tool.handle_run - async def async_run(self, args: Dict, db: str): + async def async_run(self, args: Dict): raise NotImplementedError("Not implemented yet, need a GCP Task ID") @Tool.handle_start_task @@ -25,7 +25,6 @@ async def async_start_task(self, task: Task): gpu=self.gpu, gpu_count=1, task_id=str(task.id), - db=task.db ) return handler_id @@ -81,9 +80,9 @@ def submit_job( machine_type, gpu, gpu_count, - task_id, - db + task_id ): + db = os.getenv("DB") aiplatform = get_ai_platform_client() job_name = f"flux-{task_id}" job = aiplatform.CustomJob( diff --git a/eve/tools/local_tool.py b/eve/tools/local_tool.py index 2448b6b..f1e0ddb 100644 --- a/eve/tools/local_tool.py +++ b/eve/tools/local_tool.py @@ -13,8 +13,8 @@ def __init__(self, *args, **kwargs): self._tasks = {} @Tool.handle_run - async def async_run(self, args: Dict, db: str): - result = await handlers[self.parent_tool or self.key](args, db=db) + async def async_run(self, args: Dict): + result = await handlers[self.parent_tool or self.key](args) return result @Tool.handle_start_task @@ -49,5 +49,5 @@ async def async_cancel(self, task: Task): @task_handler_func -async def run_task(tool_key: str, args: dict, db: str): - return await handlers[tool_key](args, db=db) +async def run_task(tool_key: str, args: dict): + return await handlers[tool_key](args) diff --git a/eve/tools/modal_tool.py b/eve/tools/modal_tool.py index c5d519d..e847f67 100644 --- a/eve/tools/modal_tool.py +++ b/eve/tools/modal_tool.py @@ -7,13 +7,13 @@ class ModalTool(Tool): @Tool.handle_run - async def async_run(self, args: Dict, db: str): + async def async_run(self, args: Dict): func = modal.Function.lookup( "modal_tools", "run", environment_name="main" ) - result = await func.remote.aio(tool_key=self.parent_tool or self.key, args=args, db=db) + result = await func.remote.aio(tool_key=self.parent_tool or self.key, args=args) return result @Tool.handle_start_task diff --git a/eve/tools/reel/handler.py b/eve/tools/reel/handler.py index 1ed0470..9be6e63 100644 --- a/eve/tools/reel/handler.py +++ b/eve/tools/reel/handler.py @@ -62,7 +62,7 @@ """ - +from bson.objectid import ObjectId import math import asyncio import tempfile @@ -84,6 +84,9 @@ # from ...tools import load_tool # from ... 
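
The reel handler that follows now loads and runs tools without passing a db around; the database comes from the DB environment variable instead. A hedged, standalone sketch of that calling pattern (the prompt and dimensions are placeholders, and it assumes DB plus Mongo/S3 credentials are already exported):

    import asyncio
    from eve.tool import Tool
    from eve import eden_utils

    async def demo():
        # Loads the tool document from Mongo; no db kwarg anymore.
        flux = Tool.load("flux_dev")
        result = await flux.async_run({
            "prompt": "a lighthouse at dusk, cinematic",  # placeholder prompt
            "width": 1280,
            "height": 768,
        })
        result = eden_utils.prepare_result(result)
        return result["output"][0]["url"]

    print(asyncio.run(demo()))
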
import voice +from ...tools.elevenlabs.handler import select_random_voice +from ...tool import Tool +from ...mongo import get_collection @@ -306,37 +309,23 @@ class VisualPrompts(BaseModel): -# async def go(): -# speech_audio = await elevenlabs.handler({ -# "text": "this is a test", -# "voice_id": "j6Fbg1nV1BgnjZqPvN1d" -# }, db="STAGE") -# return speech_audio -# import asyncio -# asyncio.run(go()) - -from bson.objectid import ObjectId -async def handler(args: dict, db: str): - - from ...tools.elevenlabs.handler import select_random_voice - from ...tool import Tool - from ...mongo import get_collection - elevenlabs = Tool.load("elevenlabs", db=db) - musicgen = Tool.load("musicgen", db=db) - flux = Tool.load("flux_dev", db=db) - runway = Tool.load("runway", db=db) - video_concat = Tool.load("video_concat", db=db) - audio_video_combine = Tool.load("audio_video_combine", db=db) +async def handler(args: dict): + elevenlabs = Tool.load("elevenlabs") + musicgen = Tool.load("musicgen") + flux = Tool.load("flux_dev") + runway = Tool.load("runway") + video_concat = Tool.load("video_concat") + audio_video_combine = Tool.load("audio_video_combine") instructions = None use_lora = args.get("use_lora", False) if use_lora: lora = args.get("lora") - loras = get_collection("models", db=db) + loras = get_collection("models") lora_doc = loras.find_one({"_id": ObjectId(lora)}) lora_name = lora_doc.get("name") caption_prefix = lora_doc["args"]["caption_prefix"] @@ -360,12 +349,12 @@ async def handler(args: dict, db: str): speech_audio = await elevenlabs.async_run({ "text": reel.voiceover, "voice_id": voice - }, db=db) + }) if speech_audio.get("error"): raise Exception(f"Speech generation failed: {speech_audio['error']}") - speech_audio_url = s3.get_full_url(speech_audio['output'][0]['filename'], db=db) + speech_audio_url = s3.get_full_url(speech_audio['output'][0]['filename']) # download to temp file response = requests.get(speech_audio_url) @@ -396,17 +385,22 @@ async def handler(args: dict, db: str): if args.get("use_music"): + print("music_prompt", args.get("music_prompt")) music_prompt = args.get("music_prompt") or reel.music_prompt + print("music_prompt", music_prompt) + print("run") music_audio = await musicgen.async_run({ "prompt": music_prompt, "duration": int(duration) - }, db=db) + }) + print("run2") + print("music_audio", music_audio) # music_audio = {'output': {'mediaAttributes': {'mimeType': 'audio/mpeg', 'duration': 20.052}, 'url': 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/430eb06b9a9bd66bece456fd3cd10f8c6d99fb75c1d05a1da6c317247ac171c6.mp3'}, 'status': 'completed'} if music_audio.get("error"): raise Exception(f"Music generation failed: {music_audio['error']}") - music_audio = eden_utils.prepare_result(music_audio, db=db) + music_audio = eden_utils.prepare_result(music_audio) print("MUSIC AUDIO 55", music_audio) @@ -434,9 +428,14 @@ async def handler(args: dict, db: str): else: audio = music_audio + print("lfg", audio) + if audio: + print("go1") audio_url, _ = s3.upload_audio_segment(audio) + print("audio_url", audio_url) + print("go2") # get resolution orientation = args.get("orientation") print("TE ORIENTATION IS", orientation) @@ -494,8 +493,8 @@ async def handler(args: dict, db: str): images = [] for i in range(num_clips): - image = await flux.async_run(flux_args[i], db=db) - image = eden_utils.prepare_result(image, db=db) + image = await flux.async_run(flux_args[i]) + image = eden_utils.prepare_result(image) output_url = image['output'][0]["url"] images.append(output_url) # 
images =['https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/6af97716cf3a4703877576e07823d5c6492a0355c2c7a55148b8f6a4cc8d97a7.png', 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/4bbcee84993883fe767502a29cdbe615e5f16b962de5d92a77e50ca466ef6564.png'] @@ -521,19 +520,19 @@ async def handler(args: dict, db: str): video = await runway.async_run({ "prompt_image": image, "prompt_text": flux_args[i]["prompt"], - "duration": str(durations[i]), + "duration": durations[i], "ratio": "16:9" if orientation == "landscape" else "9:16" - }, db=db) + }) print("video!!", video) - video = eden_utils.prepare_result(video, db=db) + video = eden_utils.prepare_result(video) print("video", video) video = video['output'][0]['url'] videos.append(video) - video = await video_concat.async_run({"videos": videos}, db=db) - video = eden_utils.prepare_result(video, db=db) + video = await video_concat.async_run({"videos": videos}) + video = eden_utils.prepare_result(video) print("video", video) video_url = video['output'][0]['url'] @@ -541,10 +540,10 @@ async def handler(args: dict, db: str): output = await audio_video_combine.async_run({ "audio": audio_url, "video": video_url - }, db=db) + }) print("OUTPTU!") print(output) - final_video = eden_utils.prepare_result(output, db=db) + final_video = eden_utils.prepare_result(output) print(final_video) final_video_url = final_video['output'][0]['url'] print("a 5") @@ -562,301 +561,3 @@ async def handler(args: dict, db: str): } } - -async def handler2(args: dict, db: str): - # try: - if 1: - - vid = voice.select_random_voice("A gruff and intimidating voice") - print("vid", vid) - - - - - prompt = args.get("prompt") - music = args.get("use_music") - music_prompt = (args.get("music_prompt") or "").strip() - - narrator = args.get("use_narrator") - narration = (args.get("narration") or "").strip() if narrator else "" - narration = narration[:600] - if narration: # remove everything after the last space - last_space_idx = narration.rindex(" ") - narration = narration[:last_space_idx] - - min_duration = args.get("min_duration") - - # resolution = args.get("resolution", "none") - # width = args.get("width", None) - # height = args.get("width", None) - - # print("resolution", resolution) - # print("width", width) - # print("height", height) - - - orientation = args.get("orientation") - if orientation == "landscape": - width = 1280 - height = 768 - else: - width = 768 - height = 1280 - - - speech_boost = 5 - - if not min_duration: - raise Exception("min_duration is required") - - print("ALL ARGS ARE", args) - - characters = extract_characters(prompt) - - if narrator: - characters.append(Character(name="narrator", description="The narrator of the reel is a voiceover artist who provides some narration for the reel")) - - print("characters :: ", characters) - - voices = { - c.name: voice.select_random_voice(c.description) - for c in characters - } - - story = write_reel(prompt, characters, narration, music, music_prompt) - - print("story", story) - - duration = min_duration - - print("characters", characters) - print("voices", voices) - print("story", story) - - metadata = { - "reel": story.model_dump(), - "characters": [c.model_dump() for c in characters], - } - - print("metadata", metadata) - - speech_audio = None - music_audio = None - print("NEXT") - # generate speech - print(" ---1-1-1 lets go") - print(voices) - # print(voices[story.speaker]) - - if story.speech: - speech_audio = voice.run( - text=story.speech, - voice_id=voices[story.speaker] - ) - print("generated 
speech", story.speech) - speech_audio = AudioSegment.from_file(BytesIO(speech_audio)) - silence1 = AudioSegment.silent(duration=2000) - silence2 = AudioSegment.silent(duration=3000) - speech_audio = silence1 + speech_audio + silence2 - duration = max(duration, len(speech_audio) / 1000) - metadata["speech"], _ = s3.upload_audio_segment(speech_audio) - - # # generate music - if music and story.music_prompt: - from eve.tool import Tool - musicgen = Tool.load("musicgen", db="STAGE") - music = await musicgen.async_run({ - "prompt": story.music_prompt, - "duration": int(duration) - }, db=db) - print("THE MUSIC IS DONE!") - print(music) - print("generated music", story.music_prompt) - music_bytes = requests.get(music[0]['url']).content - music_audio = AudioSegment.from_file(BytesIO(music_bytes)) - metadata["music"], _ = s3.upload_audio_segment(music_audio) - - # mix audio - audio = None - if speech_audio and music_audio: - diff_db = ratio_to_db(speech_audio.rms / music_audio.rms) - music_audio = music_audio + diff_db - speech_audio = speech_audio + speech_boost - audio = music_audio.overlay(speech_audio) - elif speech_audio: - audio = speech_audio - elif music_audio: - audio = music_audio - - print("THE AUDIO IS DONE!") - print(audio) - - - - - print("MAKE THE VIDEO!") - - flux_args = { - "prompt": story.image_prompt, - "width": width, - "height": height - } - print("flux_args", flux_args) - use_lora = args.get("use_lora", False) - if use_lora: - lora = args.get("lora") - lora_strength = args.get("lora_strength") - flux_args.update({ - "use_lora": True, - "lora": lora, - "lora_strength": lora_strength - }) - - print("flux_args", flux_args) - - - num_clips = math.ceil(duration / 10) - print("num_clips", num_clips) - - flux_args = [flux_args.copy()] * num_clips - - if num_clips > 1: - prompts = prompt_variations(prompt, num_clips) - print("ORIGINAL PROMPT", prompt) - print("-----") - print("PROMPT VARIATIONS") - for p, new_prompt in enumerate(prompts): - print(p) - print("-----") - flux_args[p]["prompt"] = new_prompt - - - txt2img = load_tool("../../workflows/workspaces/flux/workflows/flux_dev") - images = [] - for i in range(num_clips): - print("i", i) - image = await txt2img.async_run(flux_args[i], db=db) - print("THE IMAGE IS DONE!") - print(image) - output_url = image[0]["url"] - images.append(output_url) - - print("run runway") - runway = load_tool("tools/runway") - - # print("images", images) - - # num_clips = 1 - # images = ["https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/53bc5b8d715c6b243db787ab2ca15718f983dd80811f470f2a8e9aa4c8f518cc.png"] - # orientation = "portrait" - # duration = 5 - - videos = [] - dur = 10 - for i in range(num_clips): - if i == num_clips - 1 and duration % 10 < 5: - dur = 5 - print("video", i) - video = await runway.async_run({ - "prompt_image": images[i], - "prompt_text": "A panorama of a sand castle", #story.image_prompt, - "duration": str(dur), - "ratio": "16:9" if orientation == "landscape" else "9:16" - }, db=db) - print("video is done", i) - print(video) - videos.append(video[0]) - - print("videos", videos) - - # download videos - # videos = [eden_utils.get_file_handler(".mp4", v) for v in videos] - - video_concat = load_tool("tools/media_utils/video_concat") - video = await video_concat.async_run({"videos": [v["url"] for v in videos]}, db=db) - print("video", video) - video = video[0]['url'] - - - # txt2vid = load_tool("../workflows/workspaces/video/workflows/txt2vid") - # video = await txt2vid.async_run({ - # "prompt": story.image_prompt, - # 
"n_frames": 128, - # "width": width, - # "height": height - # }, db=db) - print("THE VIDEO IS DONE!") - # video = [{'mediaAttributes': {'mimeType': 'video/mp4', 'width': 1280, 'height': 768, 'aspectRatio': 1.6666666666666667, 'duration': 31.6}, 'url': 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/75bf55b76a8e4cadbf824b4eee1673a8c41c24f6688a1d5f2f90723c237c4ae6.mp4'}] - print(video) - # output_url = video[0]["url"] - # output_url = video - # video = "output.mp4" - - # print("txt2vid", output_url) - - print("a 1") - if audio: - print("a 2") - buffer = BytesIO() - print("a 3") - audio.export(buffer, format="mp3") - print("a 4") - # print("URL IS", video[0]["url"]) - output = eden_utils.make_audiovideo_clip(video, buffer) - print(output) - print("a 5") - # output_url, _ = s3.upload_file(output) - print("a 6") - - # print("output_url", output_url) - print("metadata", metadata) - - print("LETS GO!!!! ...") - # print("output_url", output_url) - print("story", story) - print("characters", characters) - print("images", ["images"]) - print("videos", ["videos"]) - print("music", music) - zz = { - "output": output, - "intermediate_outputs": { - "story": story.model_dump(), - "characters": [c.model_dump() for c in characters], - "images": images, - "videos": videos, - "music": music, - # "speech": speech_audio - } - } - - # zz = {'output': '/var/folders/h_/8038q2513yz414f7j3yqy_580000gn/T/tmpkjf59iem.mp4', 'intermediate_outputs': {'story': {'image_prompt': 'A cinematic asteroid view of Mars hurtling through space and colliding dramatically with Earth, causing an immense explosion.', 'music_prompt': 'Intense orchestral music building to a crescendo, evoking tension and epic disaster.', 'speaker': 'narrator', 'speech': "Witness the catastrophic collision of Mars and Earth, a cosmic dance of destruction, captured with stunning simulation, as the red planet meets our blue world in an inevitable, fiery embrace. 
Watch as continents crumble and atmospheres collide, forever altering the solar system's story."}, 'characters': [{'name': 'narrator', 'description': 'The narrator of the reel is a voiceover artist who provides some narration for the reel'}], 'images': ['https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/6532b48aa71c98b56a9ab41f63a24c09029527360af26b9e089218de4043e8f8.png', 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/e423f8290876ee4694f811bb1716e5d70acdf6ab6b6ea3480357ca5ae6af2f2b.png', 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/3892e536589147b729ff8d31ae93457f24361cb01450de707a700ef798828bc8.png'], 'videos': [{'mediaAttributes': {'mimeType': 'video/mp4', 'width': 1280, 'height': 768, 'aspectRatio': 1.6666666666666667, 'duration': 10.54}, 'url': 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/5020c31bf1fdf2f590113a75148a021aae38eb809532d2799b9c434f3548f832.mp4'}, {'mediaAttributes': {'mimeType': 'video/mp4', 'width': 1280, 'height': 768, 'aspectRatio': 1.6666666666666667, 'duration': 10.54}, 'url': 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/2559354dcfbfe2921e468580e8ed66823f332924191bdeb5f6aed4d3ae4a19ba.mp4'}, {'mediaAttributes': {'mimeType': 'video/mp4', 'width': 1280, 'height': 768, 'aspectRatio': 1.6666666666666667, 'duration': 10.54}, 'url': 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/1ddbcdbfaa3c4a8ab218a79cbbf1d95f92cc105b0d5e30fe8c5cd0bc8f00bfa4.mp4'}], 'music': [{'mediaAttributes': {'mimeType': 'audio/mpeg', 'duration': 28.044}, 'url': 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/e3b1438800d80293a2cc87a6371cd6947ad9e10bd449b5bfe27e4891dbab9448.mp3'}]}} - - # zz = {'output': 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/911d8cbe1775cfa52ddf3900fa2d5e55698de63860eb00a4be246baf5c174912.mp4', 'intermediate_outputs': {'story': {'image_prompt': 'A dramatic simulation showing Mars approaching and colliding with Earth, with both planets breaking apart and creating a cosmic explosion.', 'music_prompt': "213413", 'speaker': "None222", 'speech': "2342"}, 'characters': ["SDFA"], 'images': ['images'], 'videos': ['videos'], 'music': "ddd"}} - - # zz = { - # 'output': 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/911d8cbe1775cfa52ddf3900fa2d5e55698de63860eb00a4be246baf5c174912.mp4', - # 'intermediate_outputs': { - # 'story': {'image_prompt': 'A dramatic simulation showing Mars approaching and colliding with Earth, with both planets breaking apart and creating a cosmic explosion.'} - # } - # } - - print("zz", zz) - - from pprint import pprint - pprint(zz) - - return zz - - - # except asyncio.CancelledError as e: - # print("asyncio CancelledError") - # print(e) - # except Exception as e: - # print("normal error") - # print(e) - - -# import eve.eden_utils -# zz = {'output': '/Users/gene/Eden/dev/eve/97468b465a993c272b8d12990095027ec67f86ddfea6093c36be8925503d41a4.mp4', 'intermediate_outputs': {'story': {'image_prompt': 'A cinematic asteroid view of Mars hurtling through space and colliding dramatically with Earth, causing an immense explosion.', 'music_prompt': 'Intense orchestral music building to a crescendo, evoking tension and epic disaster.', 'speaker': 'narrator', 'speech': "Witness the catastrophic collision of Mars and Earth, a cosmic dance of destruction, captured with stunning simulation, as the red planet meets our blue world in an inevitable, fiery embrace. 
Watch as continents crumble and atmospheres collide, forever altering the solar system's story."}, 'characters': [{'name': 'narrator', 'description': 'The narrator of the reel is a voiceover artist who provides some narration for the reel'}], 'images': ['https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/6532b48aa71c98b56a9ab41f63a24c09029527360af26b9e089218de4043e8f8.png', 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/e423f8290876ee4694f811bb1716e5d70acdf6ab6b6ea3480357ca5ae6af2f2b.png', 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/3892e536589147b729ff8d31ae93457f24361cb01450de707a700ef798828bc8.png'], 'videos': [{'mediaAttributes': {'mimeType': 'video/mp4', 'width': 1280, 'height': 768, 'aspectRatio': 1.6666666666666667, 'duration': 10.54}, 'url': 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/5020c31bf1fdf2f590113a75148a021aae38eb809532d2799b9c434f3548f832.mp4'}, {'mediaAttributes': {'mimeType': 'video/mp4', 'width': 1280, 'height': 768, 'aspectRatio': 1.6666666666666667, 'duration': 10.54}, 'url': 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/2559354dcfbfe2921e468580e8ed66823f332924191bdeb5f6aed4d3ae4a19ba.mp4'}, {'mediaAttributes': {'mimeType': 'video/mp4', 'width': 1280, 'height': 768, 'aspectRatio': 1.6666666666666667, 'duration': 10.54}, 'url': 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/1ddbcdbfaa3c4a8ab218a79cbbf1d95f92cc105b0d5e30fe8c5cd0bc8f00bfa4.mp4'}], 'music': [{'mediaAttributes': {'mimeType': 'audio/mpeg', 'duration': 28.044}, 'url': 'https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/e3b1438800d80293a2cc87a6371cd6947ad9e10bd449b5bfe27e4891dbab9448.mp3'}]}} -# eve.eden_utils.upload_result(zz, "STAGE") diff --git a/eve/tools/replicate_tool.py b/eve/tools/replicate_tool.py index a17ead2..ceb6deb 100644 --- a/eve/tools/replicate_tool.py +++ b/eve/tools/replicate_tool.py @@ -24,10 +24,10 @@ class ReplicateTool(Tool): output_handler: str = "normal" @Tool.handle_run - async def async_run(self, args: Dict, db: str): + async def async_run(self, args: Dict): check_replicate_api_token() - args = self._format_args_for_replicate(args) if self.version: + args = self._format_args_for_replicate(args) prediction = await self._create_prediction(args, webhook=False) prediction.wait() if self.output_handler == "eden": @@ -41,10 +41,11 @@ async def async_run(self, args: Dict, db: str): result = {"output": prediction.output} else: replicate_model = self._get_replicate_model(args) + args = self._format_args_for_replicate(args) result = { "output": replicate.run(replicate_model, input=args) } - result = eden_utils.upload_result(result, db=db) + result = eden_utils.upload_result(result) return result @Tool.handle_start_task @@ -103,29 +104,32 @@ def _format_args_for_replicate(self, args: dict): is_number = parameter.get('type') in ['integer', 'float'] alias = parameter.get('alias') lora = parameter.get('type') == 'lora' + if field in new_args: if lora: - lora_doc = get_collection(Model.collection_name, db=self.db).find_one({"_id": ObjectId(args[field])}) if args[field] else None + loras = get_collection(Model.collection_name) + lora_doc = loras.find_one({"_id": ObjectId(args[field])}) if args[field] else None if lora_doc: - lora_url = lora_doc.get("checkpoint") + lora_url = s3.get_full_url(lora_doc.get("checkpoint")) lora_name = lora_doc.get("name") - caption_prefix = lora_doc.get("args", {}).get("caption_prefix") + lora_trigger_text = lora_doc.get("lora_trigger_text") new_args[field] = lora_url if "prompt" in new_args: 
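
As a worked illustration of the prompt rewriting in this replicate hunk (the lora name and trigger text are hypothetical): any case-insensitive mention of the lora's display name in the prompt is replaced by the stored lora_trigger_text before the request is sent.

    import re

    lora_name = "Banny"                          # hypothetical lora display name
    lora_trigger_text = "<s0><s1> banny banana"  # hypothetical stored trigger text
    prompt = "banny gives a speech at the podium"

    pattern = re.compile(re.escape(lora_name), re.IGNORECASE)
    print(pattern.sub(lora_trigger_text, prompt))
    # -> <s0><s1> banny banana gives a speech at the podium
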
pattern = re.compile(re.escape(lora_name), re.IGNORECASE) - new_args["prompt"] = pattern.sub(caption_prefix, new_args['prompt']) + new_args["prompt"] = pattern.sub(lora_trigger_text, new_args['prompt']) if is_number: new_args[field] = float(args[field]) elif is_array: new_args[field] = "|".join([str(p) for p in args[field]]) if alias: new_args[alias] = new_args.pop(field) + return new_args def _get_replicate_model(self, args: dict): """Use default model or a substitute model conditional on an arg""" replicate_model = self.replicate_model - + if self.replicate_model_substitutions: for cond, model in self.replicate_model_substitutions.items(): if args.get(cond): @@ -135,11 +139,10 @@ def _get_replicate_model(self, args: dict): async def _create_prediction(self, args: dict, webhook=True): replicate_model = self._get_replicate_model(args) - user, model = replicate_model.split('/', 1) - + user, model = replicate_model.split('/', 1) webhook_url = get_webhook_url() if webhook else None webhook_events_filter = ["start", "completed"] if webhook else None - + if self.version == "deployment": deployment = await replicate.deployments.async_get(f"{user}/{model}") prediction = await deployment.predictions.async_create( @@ -170,18 +173,12 @@ def replicate_update_task(task: Task, status, error, output, output_handler): if status == "failed": task.update(status="failed", error=error) - n_samples = task.args.get("n_samples", 1) - refund_amount = (task.cost or 0) * (n_samples - len(task.result or [])) / n_samples - user = User.from_mongo(task.user, db=task.db) - user.refund_manna(refund_amount) + task.refund_manna() return {"status": "failed", "error": error} elif status == "canceled": task.update(status="cancelled") - n_samples = task.args.get("n_samples", 1) - refund_amount = (task.cost or 0) * (n_samples - len(task.result or [])) / n_samples - user = User.from_mongo(task.user, db=task.db) - user.refund_manna(refund_amount) + task.refund_manna() return {"status": "cancelled"} elif status == "processing": @@ -195,10 +192,10 @@ def replicate_update_task(task: Task, status, error, output, output_handler): if output_handler in ["eden", "trainer"]: thumbnails = output[-1]["thumbnails"] output = output[-1]["files"] - output = eden_utils.upload_result(output, db=task.db, save_thumbnails=True, save_blurhash=True) + output = eden_utils.upload_result(output, save_thumbnails=True, save_blurhash=True) result = [{"output": [out]} for out in output] else: - output = eden_utils.upload_result(output, db=task.db, save_thumbnails=True, save_blurhash=True) + output = eden_utils.upload_result(output, save_thumbnails=True, save_blurhash=True) result = [{"output": [out]} for out in output] for r, res in enumerate(result): @@ -207,11 +204,10 @@ def replicate_update_task(task: Task, status, error, output, output_handler): filename = output["filename"] thumbnail = eden_utils.upload_media( thumbnails[0], - db=task.db, save_thumbnails=False, save_blurhash=False ) if thumbnails else None - url = s3.get_full_url(filename, db=task.db) + url = s3.get_full_url(filename) model = Model( name=task.args["name"], user=task.user, @@ -236,7 +232,7 @@ def replicate_update_task(task: Task, status, error, output, output_handler): mediaAttributes=output["mediaAttributes"], name=name ) - creation.save(db=task.db) + creation.save() result[r]["output"][o]["creation"] = creation.id run_time = (datetime.now(timezone.utc) - task.createdAt).total_seconds() diff --git a/eve/tools/runway/api.yaml b/eve/tools/runway/api.yaml index ab023f1..1ec5d07 100644 --- 
a/eve/tools/runway/api.yaml +++ b/eve/tools/runway/api.yaml @@ -3,7 +3,7 @@ description: Text-guided, realistic image animation with Runway Gen3a tip: |- This tool can be used for creating a realistic animation of an image. Specific camera motion can be obtained by putting such directions in the prompt text. thumbnail: app/runway-tree-orb-woman2-opt.mp4 -cost_estimate: 10 * int(duration) +cost_estimate: 10 * duration output_type: video base_mopel: runway status: prod @@ -19,11 +19,11 @@ parameters: description: The prompt to guide the animation required: true duration: - type: string + type: integer label: Duration description: The duration of the video in seconds - default: '5' - choices: ['5', '10'] + default: 5 + choices: [5, 10] ratio: type: string label: Ratio diff --git a/eve/tools/runway/handler.py b/eve/tools/runway/handler.py index d99396a..03290c3 100644 --- a/eve/tools/runway/handler.py +++ b/eve/tools/runway/handler.py @@ -9,16 +9,22 @@ """ -async def handler(args: dict, db: str): +async def handler(args: dict): client = RunwayML() + try: + ratio = "1280:768" if args["ratio"] == "16:9" else "768:1280" + task = client.image_to_video.create( model='gen3a_turbo', prompt_image=args["prompt_image"], - prompt_text=args["prompt_text"][:512] + prompt_text=args["prompt_text"][:512], + duration=int(args["duration"]), + ratio=ratio, + watermark=False ) except runwayml.APIConnectionError as e: print("The server could not be reached") @@ -57,6 +63,15 @@ async def handler(args: dict, db: str): print("task finished2", task.status) print(task) + + """ + + task finished2 FAILED +TaskRetrieveResponse(id='48947b97-c260-492e-b662-bec5aa725ebf', created_at=datetime.datetime(2025, 1, 1, 20, 43, 5, 303000, tzinfo=datetime.timezone.utc), status='FAILED', failure='An unexpected error occurred.', failure_code='INTERNAL.BAD_OUTPUT.CODE01', output=None, progress=None, createdAt='2025-01-01T20:43:05.303Z', failureCode='INTERNAL.BAD_OUTPUT.CODE01') +Error An unexpected error occurred. + + """ + if task.status == "FAILED": print("Error", task.failure) raise Exception(task.failure) diff --git a/eve/tools/twitter/__init__.py b/eve/tools/twitter/__init__.py index be2f70b..10fada7 100644 --- a/eve/tools/twitter/__init__.py +++ b/eve/tools/twitter/__init__.py @@ -1,6 +1,8 @@ import time import logging import requests +from datetime import datetime, timedelta +from dotenv import load_dotenv from requests_oauthlib import OAuth1Session from ...agent import Agent @@ -20,6 +22,9 @@ def __init__(self, agent: Agent): self.last_processed_id = None self.oauth = self._init_oauth_session() + print("GET BEARER TOKEN") + print(self.bearer_token) + def _init_oauth_session(self): """Initializes OAuth1 session.""" return OAuth1Session( @@ -69,6 +74,21 @@ def fetch_mentions(self): ) return response.json() if response else {} + + def fetch_followings(self): + """Fetches the latest followings of the user.""" + response = self._make_request( + 'post', + f"https://api.twitter.com/2/users/{self.user_id}/following", + headers={"Authorization": f"Bearer {self.bearer_token}"}, + params={} + ) + return response.json() if response else {} + + + + + def get_newest_tweet(self, data): """Gets the newest tweet from the data.""" @@ -244,3 +264,156 @@ def post(self, tweet_text, media_urls=None, reply_to_tweet_id=None): else: logging.error("Failed to post tweet: None response from _make_request.") raise Exception("Failed to post tweet. 
+
+
+
+
+
+
+    def get_following222(self, usernames):
+        """Fetches the list of accounts each specified username is following."""
+        following_data = {}
+
+        for username in usernames:
+            response = self._make_request(
+                'get',
+                f"https://api.twitter.com/2/users/by/username/{username}",
+                headers={"Authorization": f"Bearer {self.bearer_token}"}
+            )
+
+            if not response:
+                logging.error(f"Failed to fetch user info for {username}.")
+                following_data[username] = []
+                continue
+
+            user_id = response.json().get("data", {}).get("id")
+
+            if not user_id:
+                logging.error(f"User ID not found for {username}.")
+                following_data[username] = []
+                continue
+
+            follows_response = self._make_request(
+                'get',
+                f"https://api.twitter.com/2/users/{user_id}/following",
+                headers={"Authorization": f"Bearer {self.bearer_token}"},
+                params={"max_results": 1000}  # Adjust as needed for pagination.
+            )
+
+            if follows_response:
+                following_data[username] = [
+                    follow.get("username") for follow in follows_response.json().get("data", [])
+                ]
+            else:
+                following_data[username] = []
+
+        return following_data
+
+    def get_recent_tweets(self, usernames, timeframe_minutes=60):
+        """Fetches tweets from the given users within the specified timeframe."""
+        recent_tweets = {}
+        time_threshold = datetime.utcnow() - timedelta(minutes=timeframe_minutes)
+
+        for username in usernames:
+            response = self._make_request(
+                'get',
+                f"https://api.twitter.com/2/users/by/username/{username}",
+                headers={"Authorization": f"Bearer {self.bearer_token}"}
+            )
+
+            if not response:
+                logging.error(f"Failed to fetch user info for {username}.")
+                recent_tweets[username] = []
+                continue
+
+            user_id = response.json().get("data", {}).get("id")
+
+            if not user_id:
+                logging.error(f"User ID not found for {username}.")
+                recent_tweets[username] = []
+                continue
+
+            tweets_response = self._make_request(
+                'get',
+                f"https://api.twitter.com/2/users/{user_id}/tweets",
+                headers={"Authorization": f"Bearer {self.bearer_token}"},
+                params={"max_results": 100, "tweet.fields": "created_at"}
+            )
+
+            if not tweets_response:
+                logging.error(f"Failed to fetch tweets for {username}.")
+                recent_tweets[username] = []
+                continue
+
+            tweets = tweets_response.json().get("data", [])
+            recent_tweets[username] = [
+                tweet for tweet in tweets
+                if datetime.strptime(tweet["created_at"], "%Y-%m-%dT%H:%M:%S.%fZ") >= time_threshold
+            ]
+
+        return recent_tweets
+
+
+
+
+
+
+    def get_all_followings(self, user_ids, max_results=1000):
+        """
+        Retrieves all followings for each user in user_ids using the Twitter v2 endpoint.
+        Returns a dict mapping user_id -> list_of_following_users.
+
+        Note: Each user requires its own call. If a user has a large following list,
+        we will paginate until we've retrieved them all, using 'next_token'.
+        """
+        url_template = "https://api.twitter.com/2/users/{}/following"
+        headers = {
+            "Authorization": f"Bearer {self.bearer_token}",
+            "User-Agent": "v2UserFollowingLookupPython"
+        }
+
+        all_followings = {}
+
+        for user_id in user_ids:
+            followings = []
+            pagination_token = None
+
+            while True:
+                params = {
+                    "max_results": max_results
+                }
+                if pagination_token:
+                    params["pagination_token"] = pagination_token
+
+                # Use _make_request but set oauth=False to use Bearer token
+                # (the v2 'following' endpoint typically uses Bearer token).
+                response = self._make_request(
+                    "get",
+                    url_template.format(user_id),
+                    oauth=False,
+                    headers=headers,
+                    params=params
+                )
+
+                if not response:
+                    logging.error(f"Error fetching followings for user {user_id}")
+                    break
+
+                data = response.json()
+                # If Twitter returns an error structure, log it
+                if "errors" in data:
+                    logging.error(f"Error fetching followings for user {user_id}: {data}")
+                    break
+
+                followings_page = data.get("data", [])
+                followings.extend(followings_page)
+
+                meta = data.get("meta", {})
+                pagination_token = meta.get("next_token")
+                if not pagination_token:
+                    break
+
+            # Store all followings for this user
+            all_followings[user_id] = followings
+
+        return all_followings
\ No newline at end of file
diff --git a/eve/tools/twitter/get_tweets/handler.py b/eve/tools/twitter/get_tweets/handler.py
index 6507ebc..18df6d7 100644
--- a/eve/tools/twitter/get_tweets/handler.py
+++ b/eve/tools/twitter/get_tweets/handler.py
@@ -6,7 +6,7 @@
 
 
-async def handler(args: dict, db: str):
+async def handler(args: dict):
     pass
     # mentions = X_client.fetch_mentions()
     # mentions = {'output': ['{"data": [{"text": "Hey buddy @SillySmile21038 @Akhil_lavadya @AjayVasava73118 @rsalusse @MelikianRa40150 @rajkumarnahar13 @elonfan1854 @Ernesto25954832 @US_Hot_Dog @JeremiahEr30464 https://t.co/RIgXToQVLM", "author_id": "1728653464601759744", "edit_history_tweet_ids": ["1814809535283916828"], "id": "1814809535283916828"}, {"text": "@cinebuzzbr @marinavgregory @rsalusse @paramountplusbr T\\u00e1 errado meu @ HAHAHAHAHH", "author_id": "1424172460153442308", "edit_history_tweet_ids": ["1533944964685635585"], "id": "1533944964685635585"}, {"text": "@rsalusse Nosso Amigo @ufc_matogrosso est\\u00e1 vendendo 4 ingressos para o UFC 142. Interessados add no msn diego_torrezini@live.com rf", "author_id": "456950209", "edit_history_tweet_ids": ["155611948959019009"], "id": "155611948959019009"}], "includes": {"users": [{"id": "1728653464601759744", "name": "idnaniotkirb", "username": "idnaniotki15704"}, {"id": "1424172460153442308", "name": "Bibi Lambe-Picas", "username": "servidorrrr"}, {"id": "456950209", "name": "Vivane Lins", "username": "VLins24625nes"}]}, "meta": {"result_count": 3, "newest_id": "1814809535283916828", "oldest_id": "155611948959019009"}}'], 'status': 'completed'}
diff --git a/eve/tools/twitter/tweet/handler.py b/eve/tools/twitter/tweet/handler.py
index 8f901a4..2580b19 100644
--- a/eve/tools/twitter/tweet/handler.py
+++ b/eve/tools/twitter/tweet/handler.py
@@ -3,8 +3,8 @@
 from .. import X
 
 
-async def handler(args: dict, db: str):
-    agent = Agent.load(args["agent"], db=db)
+async def handler(args: dict):
+    agent = Agent.load(args["agent"])
     x = X(agent)
diff --git a/eve/tools/wallet/send_eth/handler.py b/eve/tools/wallet/send_eth/handler.py
index 7a9a966..0bd7573 100644
--- a/eve/tools/wallet/send_eth/handler.py
+++ b/eve/tools/wallet/send_eth/handler.py
@@ -2,7 +2,7 @@
 from web3 import Web3
 
 
-async def handler(args: dict, db: str):
+async def handler(args: dict):
     # Initialize Web3 with Base Sepolia RPC URL
     w3 = Web3(Web3.HTTPProvider(os.getenv("BASE_SEPOLIA_RPC_URL")))
diff --git a/eve/user.py b/eve/user.py
index aea8970..9423151 100644
--- a/eve/user.py
+++ b/eve/user.py
@@ -4,6 +4,7 @@
 from .mongo import Document, Collection, get_collection, MongoDocumentNotFound
 
 
+
 @Collection("mannas")
 class Manna(Document):
     user: ObjectId
@@ -11,16 +12,15 @@ class Manna(Document):
     subscriptionBalance: float = 0
 
     @classmethod
-    def load(cls, user: ObjectId | str, db=None):
+    def load(cls, user: ObjectId):
         try:
-            user = ObjectId(user) if isinstance(user, str) else user
-            return super().load(user=user, db=db)
+            return super().load(user=user)
         except MongoDocumentNotFound as e:
             # if mannas not found, check if user exists, and create a new manna document
-            user = User.from_mongo(user, db=db)
+            user = User.from_mongo(user)
             if not user:
                 raise Exception(f"User {user} not found")
-            manna = Manna(user=user.id, db=db)
+            manna = Manna(user=user.id)
             manna.save()
             return manna
         except Exception as e:
@@ -41,6 +41,14 @@ def refund(self, amount: float):
         self.save()
 
 
+@Collection("transactions")
+class Transaction(Document):
+    manna: ObjectId
+    task: ObjectId
+    amount: float
+    type: Literal["spend", "refund"]
+
+
 @Collection("users3")
 class User(Document):
     # todo underscore
@@ -77,82 +85,67 @@ class User(Document):
     farcasterUsername: Optional[str] = None
 
     def check_manna(self, amount: float):
-        manna = Manna.load(self.id, db=self.db)
+        manna = Manna.load(self.id)
         total_balance = manna.balance + manna.subscriptionBalance
         if total_balance < amount:
             raise Exception(
                 f"Insufficient manna balance. Need {amount} but only have {total_balance}"
             )
 
-    def spend_manna(self, amount: float):
-        if amount == 0:
-            return
-        manna = Manna.load(self.id, db=self.db)
-        manna.spend(amount)
-
-    def refund_manna(self, amount: float):
-        if amount == 0:
-            return
-        manna = Manna.load(self.id, db=self.db)
-        manna.refund(amount)
-
     @classmethod
-    def from_discord(cls, discord_id, discord_username, db):
+    def from_discord(cls, discord_id, discord_username):
         discord_id = str(discord_id)
         discord_username = str(discord_username)
 
-        users = get_collection(cls.collection_name, db=db)
+        users = get_collection(cls.collection_name)
         user = users.find_one({"discordId": discord_id})
 
         if not user:
-            username = cls._get_unique_username(f"discord_{discord_username}", db=db)
+            username = cls._get_unique_username(f"discord_{discord_username}")
             new_user = cls(
-                db=db,
                 discordId=discord_id,
                 discordUsername=discord_username,
                 username=username,
             )
             new_user.save()
             return new_user
 
-        return cls(**user, db=db)
+        return cls(**user)
 
     @classmethod
-    def from_farcaster(cls, farcaster_id, farcaster_username, db):
+    def from_farcaster(cls, farcaster_id, farcaster_username):
         farcaster_id = str(farcaster_id)
         farcaster_username = str(farcaster_username)
 
-        users = get_collection(cls.collection_name, db=db)
+        users = get_collection(cls.collection_name)
         user = users.find_one({"farcasterId": farcaster_id})
 
         if not user:
-            username = cls._get_unique_username(f"farcaster_{farcaster_username}", db=db)
+            username = cls._get_unique_username(f"farcaster_{farcaster_username}")
             new_user = cls(
-                db=db,
                 farcasterId=farcaster_id,
                 farcasterUsername=farcaster_username,
                 username=username,
             )
             new_user.save()
             return new_user
 
-        return cls(**user, db=db)
+        return cls(**user)
 
     @classmethod
-    def from_telegram(cls, telegram_id, telegram_username, db):
+    def from_telegram(cls, telegram_id, telegram_username):
         telegram_id = str(telegram_id)
         telegram_username = str(telegram_username)
 
-        users = get_collection(cls.collection_name, db=db)
+        users = get_collection(cls.collection_name)
         user = users.find_one({"telegramId": telegram_id})
 
         if not user:
-            username = cls._get_unique_username(f"telegram_{telegram_username}", db=db)
+            username = cls._get_unique_username(f"telegram_{telegram_username}")
             new_user = cls(
-                db=db,
                 telegramId=telegram_id,
                 telegramUsername=telegram_username,
                 username=username,
             )
             new_user.save()
             return new_user
 
-        return cls(**user, db=db)
+        return cls(**user)
 
     @classmethod
-    def _get_unique_username(cls, base_username, db):
-        users = get_collection(cls.collection_name, db=db)
+    def _get_unique_username(cls, base_username):
+        users = get_collection(cls.collection_name)
         username = base_username
         counter = 2
         while users.find_one({"username": username}):
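Note on the eve/user.py hunk above: the new "transactions" collection is defined but nothing in this changeset writes to it yet, and Transaction.type relies on Literal being available from the file's existing typing imports. The following is only a hedged sketch of one way a spend could be recorded against it; the helper name and call site are assumptions, not part of this diff.

    # Hypothetical sketch (not from this diff): recording a Transaction alongside a manna spend
    def record_spend(manna: Manna, task_id: ObjectId, amount: float):
        manna.spend(amount)  # existing Manna.spend, unchanged by this diff
        Transaction(manna=manna.id, task=task_id, amount=amount, type="spend").save()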
diff --git a/modal_tool.py b/modal_tool.py
index 1ceb219..f4d185e 100644
--- a/modal_tool.py
+++ b/modal_tool.py
@@ -31,11 +31,11 @@
 )
 
 @app.function(image=image, timeout=3600)
-async def run(tool_key: str, args: dict, db: str):
-    result = await handlers[tool_key](args, db=db)
-    return eden_utils.upload_result(result, db=db)
+async def run(tool_key: str, args: dict):
+    result = await handlers[tool_key](args)
+    return eden_utils.upload_result(result)
 
 @app.function(image=image, timeout=3600)
 @task_handler_func
-async def run_task(tool_key: str, args: dict, db: str):
-    return await handlers[tool_key](args, db=db)
+async def run_task(tool_key: str, args: dict):
+    return await handlers[tool_key](args)
diff --git a/pyproject.toml b/pyproject.toml
index 7140fb1..438810e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,8 +44,7 @@ dependencies = [
     "farcaster>=0.7.11",
     "ably>=2.0.7",
     "colorama>=0.4.6",
-    "diffusers==0.31.0",
-    "web3<7.6.1"
+    "web3<7.6.1",
 ]
 
 [build-system]
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 151f318..ccdb25d 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -460,7 +460,6 @@ typing-extensions==4.12.2
     # via anthropic
     # via anyio
     # via elevenlabs
-    # via eth-rlp
     # via eth-typing
     # via fastapi
     # via modal
diff --git a/requirements.lock b/requirements.lock
index e2607c0..d675832 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -431,7 +431,6 @@ typing-extensions==4.12.2
    # via anthropic
     # via anyio
     # via elevenlabs
-    # via eth-rlp
     # via eth-typing
     # via fastapi
     # via modal
diff --git a/tests/test_client.py b/tests/test_client.py
index 5505e01..36081b1 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1,19 +1,12 @@
-#uvicorn eve.api:web_app --host 0.0.0.0 --port 8000 --reload
-
 import os
 import json
 import time
 import subprocess
 import requests
-import eve
-#from dotenv import load_dotenv
-#load_dotenv(os.path.expanduser("~/.eve"))
-
 EDEN_ADMIN_KEY = os.getenv("EDEN_ADMIN_KEY")
 headers = {
-    # "X-Api-Key": api_key,
     "Authorization": f"Bearer {EDEN_ADMIN_KEY}",
     "Content-Type": "application/json",
 }
@@ -39,7 +32,6 @@ def run_chat(server_url):
         "agent_id": "675fd3c379e00297cdac16fb",
         "user_message": {
             "content": "verdelis make a picture of yourself on the beach. use flux_dev_lora and make sure to mention 'Verdelis' in the prompt",
-            # "content": "make a high quality picture of a fancy cat in your favorite location. use flux dev",
         }
     }
     response = requests.post(server_url+"/chat", json=request, headers=headers)
@@ -47,26 +39,23 @@ def run_chat(server_url):
     print(json.dumps(response.json(), indent=2))
 
 
-def test_client():
-    run_server = False
+def test_client():
+    server_url = None
     try:
-        if run_server:
-            # uvicorn eve.api:web_app --host 0.0.0.0 --port 8000 --reload
+        if not server_url:
+            print("Starting server...")
             server = subprocess.Popen(
-                ["uvicorn", "eve.api:web_app", "--host", "0.0.0.0", "--port", "8000", "--reload"],
+                ["rye", "run", "eve", "api", "--db", "STAGE"],
                 stdout=subprocess.PIPE,
                 stderr=subprocess.PIPE
             )
-            time.sleep(2)
+            time.sleep(3)
             server_url = "http://localhost:8000"
-        else:
-            server_url = "https://edenartlab--api-stage-fastapi-app-dev.modal.run"
-            # server_url = "http://localhost:8000"
 
         print("server_url", server_url)
 
         print("\nRunning create test...")
-        # run_create(server_url)
+        run_create(server_url)
 
         print("\nRunning chat test...")
         run_chat(server_url)
@@ -80,7 +69,3 @@ def test_client():
             server.terminate()
             server.wait()
-
-if __name__ == "__main__":
-    test_client()
-
diff --git a/tests/test_llm.py b/tests/test_llm.py
index a90077a..597c617 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -7,9 +7,9 @@
 # todo: since prompt_thread handles exceptions, this won't actually fail if there are errors
 def test_prompting():
-    user = get_my_eden_user(db="STAGE")
+    user = get_my_eden_user()
 
-    agent = Agent.load("eve", db="STAGE")
+    agent = Agent.load("eve")
     tools = agent.get_tools()
     thread = agent.request_thread()
 
@@ -20,7 +20,6 @@ def test_prompting():
     ]
 
     for msg in prompt_thread(
-        db="STAGE",
         user=user,
         agent=agent,
         thread=thread,
@@ -32,7 +31,7 @@ def test_prompting():
 
 
 def test_prompting2():
-    user = get_my_eden_user(db="STAGE")
+    user = get_my_eden_user()
 
     messages = [
         UserMessage(name="jim", content="i have an apple."),
@@ -41,12 +40,11 @@ def test_prompting2():
         UserMessage(name="kate", content="what is my name?"),
     ]
 
-    agent = Agent.load("eve", db="STAGE")
+    agent = Agent.load("eve")
     tools = agent.get_tools()
     thread = agent.request_thread()
 
     for msg in prompt_thread(
-        db="STAGE",
         user=user,
         agent=agent,
         thread=thread,
diff --git a/tests/test_mongo.py b/tests/test_mongo.py
index 4e0a6ae..c7aacf8 100644
--- a/tests/test_mongo.py
+++ b/tests/test_mongo.py
@@ -1,6 +1,6 @@
 """
 Todo:
-VersionableMongoModel.load(t1.id, collection_name="stories", db="STAGE")
+VersionableMongoModel.load(t1.id, collection_name="stories")
   -> schema = recreate_base_model(document['schema'])
   * this works but strong typing is not working.
@@ -38,7 +38,6 @@ class MongoModelTest(Document):
     user: ObjectId
 
     t = MongoModelTest(
-        db="STAGE",
         num=2,
         args={"foo": "bar"},
         user=ObjectId("666666663333366666666666")
@@ -46,10 +45,9 @@ class MongoModelTest(Document):
 
     t.save()
 
-    t2 = MongoModelTest.from_mongo(t.id, db="STAGE")
+    t2 = MongoModelTest.from_mongo(t.id)
 
     assert t2 == MongoModelTest(
-        db="STAGE",
         num=2,
         args={"foo": "bar"},
         user=ObjectId("666666663333366666666666"),
@@ -61,11 +59,11 @@ class MongoModelTest(Document):
     # t2.update(invalid_arg="this is ignored", num=7, args={"foo": "hello world"})
     t2.update(num=7, args={"foo": "hello world"})
 
-    t3 = MongoModelTest.from_mongo(t2.id, db="STAGE")
+    t3 = MongoModelTest.from_mongo(t2.id)
 
     assert t.id == t2.id == t3.id
 
-    assert t3 == MongoModelTest(db="STAGE", num=7, args={"foo": "hello world"}, user=ObjectId("666666663333366666666666"), id=t2.id, createdAt=t3.createdAt, updatedAt=t3.updatedAt)
+    assert t3 == MongoModelTest(num=7, args={"foo": "hello world"}, user=ObjectId("666666663333366666666666"), id=t2.id, createdAt=t3.createdAt, updatedAt=t3.updatedAt)
 
 
@@ -85,7 +83,6 @@ def _test_versionable_base_model():
             base_model_field=InnerModel(string_field="test5", number_field=7)
         ),
         collection_name="stories",
-        db="STAGE"
     )
 
     t1.save()
@@ -118,7 +115,7 @@ def _test_versionable_base_model():
     print("T2 a")
     print(t1.id)
 
-    t2 = VersionableMongoModel.load(t1.id, collection_name="stories", db="STAGE")
+    t2 = VersionableMongoModel.load(t1.id, collection_name="stories")
 
     print(t2)
     print("T2 b")
@@ -143,7 +140,7 @@ def _test_versionable_base_model():
     # t2.save()
 
     # print("T3 a")
-    # t3 = VersionableMongoModel.load(t1.id, collection_name="stories", db="STAGE")
+    # t3 = VersionableMongoModel.load(t1.id, collection_name="stories")
     # print(t3)
 
     # print("T3 b")
@@ -166,7 +163,7 @@ def _test_versionable_base_model():
 
     # t3.save()
 
-    # t4 = VersionableMongoModel.load(t1.id, collection_name="stories", db="STAGE")
+    # t4 = VersionableMongoModel.load(t1.id, collection_name="stories")
 
     # assert t4.current.model_dump() == t3_expected.model_dump()
diff --git a/tests/test_tools.py b/tests/test_tools.py
index faadf6c..b577672 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -5,28 +5,27 @@
 from eve.auth import get_my_eden_user
 
 
-async def async_run_tool(tool, api: bool, db: str, mock: bool):
+async def async_run_tool(tool, api: bool, mock: bool):
     """Run a single tool test"""
     if api:
-        user = get_my_eden_user(db=db)
-        task = await tool.async_start_task(user.id, user.id, tool.test_args, db=db, mock=mock)
+        user = get_my_eden_user()
+        task = await tool.async_start_task(user.id, user.id, tool.test_args, mock=mock)
         return await tool.async_wait(task)
-    return await tool.async_run(tool.test_args, db=db, mock=mock)
+    return await tool.async_run(tool.test_args, mock=mock)
 
 
 async def async_run_all_tools(
     tools: list[str],
     yaml: bool = False,
-    db: str = "STAGE",
     api: bool = False,
     parallel: bool = True,
     mock: bool = True
 ):
     """Test multiple tools with their test args"""
 
     # Get tools from either yaml files or mongo
-    tool_dict = get_tools_from_api_files(tools=tools, include_inactive=True) if yaml else get_tools_from_mongo(db=db, tools=tools)
+    tool_dict = get_tools_from_api_files(tools=tools, include_inactive=True) if yaml else get_tools_from_mongo(tools=tools)
 
     # Create and run tasks
-    tasks = [async_run_tool(tool, api, db, mock) for tool in tool_dict.values()]
+    tasks = [async_run_tool(tool, api, mock) for tool in tool_dict.values()]
     results = await asyncio.gather(*tasks) if parallel else [await task for task in tasks]
 
     # Collect errors
@@ -55,7 +54,6 @@ def test_tools():
             "elevenlabs"
         ],
         yaml=False,
-        db="STAGE",
         api=False,
         parallel=True,
         mock=True
@@ -66,7 +64,6 @@ def test_tools():
     results = asyncio.run(async_run_all_tools(
         tools=["legacy_create"],
         yaml=True,
-        db="STAGE",
         api=False,
         parallel=True,
         mock=True