Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update _id from /create #46

Merged
merged 21 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 68 additions & 49 deletions comfyui.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@
test_all = True if os.getenv("TEST_ALL") else False
skip_tests = os.getenv("SKIP_TESTS")

print("========================================")
print(f"db: {db}")
print(f"workspace: {workspace_name}")
print(f"test_workflows: {test_workflows}")
print(f"test_all: {test_all}")
print(f"skip_tests: {skip_tests}")
print("========================================")

def install_comfyui():
snapshot = json.load(open("/root/workspace/snapshot.json", 'r'))
comfyui_commit_sha = snapshot["comfyui"]
Expand Down Expand Up @@ -218,6 +226,7 @@ def download_files(force_redownload=False):
.env({"TEST_ALL": os.getenv("TEST_ALL")})
.apt_install("git", "git-lfs", "libgl1-mesa-glx", "libglib2.0-0", "libmagic1", "ffmpeg", "libegl1")
.pip_install_from_pyproject(str(root_dir / "pyproject.toml"))
.pip_install("diffusers==0.31.0")
.env({"WORKSPACE": workspace_name})
.copy_local_file(f"{root_workflows_folder}/workspaces/{workspace_name}/snapshot.json", "/root/workspace/snapshot.json")
.copy_local_file(f"{root_workflows_folder}/workspaces/{workspace_name}/downloads.json", "/root/workspace/downloads.json")
Expand Down Expand Up @@ -328,6 +337,9 @@ def test_workflows(self):
if not all([w in workflow_names for w in test_workflows]):
raise Exception(f"One or more invalid workflows found: {', '.join(test_workflows)}")
workflow_names = test_workflows
print(f"====> Running tests for subset of workflows: {' | '.join(workflow_names)}")
else:
print(f"====> Running tests for all workflows: {' | '.join(workflow_names)}")

if not workflow_names:
raise Exception("No workflows found!")
Expand All @@ -338,7 +350,7 @@ def test_workflows(self):
tests = glob.glob(f"/root/workspace/workflows/{workflow}/test*.json")
else:
tests = [f"/root/workspace/workflows/{workflow}/test.json"]
print("Running tests: ", tests)
print(f"====> Running tests for {workflow}: ", tests)
for test in tests:
tool = Tool.from_yaml(f"/root/workspace/workflows/{workflow}/api.yaml")
if tool.status == "inactive":
Expand All @@ -347,10 +359,12 @@ def test_workflows(self):
test_args = json.loads(open(test, "r").read())
test_args = tool.prepare_args(test_args)
test_name = f"{workflow}_{os.path.basename(test)}"
print(f"Running test: {test_name}")
print(f"====> Running test: {test_name}")
t1 = time.time()
result = self._execute(workflow, test_args)
result = eden_utils.upload_result(result)
result = eden_utils.prepare_result(result)
print(f"====> Final media url: {result}")
t2 = time.time()
results[test_name] = result
results["_performance"][test_name] = t2 - t1
Expand Down Expand Up @@ -460,20 +474,22 @@ def _inject_embedding_mentions_sdxl(self, text, embedding_trigger, embeddings_fi
lora_prompt = f"{reference}, {lora_prompt}"

return user_prompt, lora_prompt

def _inject_embedding_mentions_flux(self, text, embedding_trigger, lora_trigger_text):
pattern = r'(<{0}>|<{1}>|{0}|{1})'.format(
re.escape(embedding_trigger),
re.escape(embedding_trigger.lower())
)
text = re.sub(pattern, lora_trigger_text, text, flags=re.IGNORECASE)
text = re.sub(r'(<concept>)', lora_trigger_text, text, flags=re.IGNORECASE)

if lora_trigger_text not in text: # Make sure the concept is always triggered:
if not embedding_trigger: # Handles both None and empty string
text = re.sub(r'(<concept>)', lora_trigger_text, text, flags=re.IGNORECASE)
else:
pattern = r'(<{0}>|<{1}>|{0}|{1})'.format(
re.escape(embedding_trigger),
re.escape(embedding_trigger.lower())
)
text = re.sub(pattern, lora_trigger_text, text, flags=re.IGNORECASE)
text = re.sub(r'(<concept>)', lora_trigger_text, text, flags=re.IGNORECASE)

if lora_trigger_text not in text:
text = f"{lora_trigger_text}, {text}"

return text


def _transport_lora_flux(self, lora_url: str):
loras_folder = "/root/models/loras"
Expand Down Expand Up @@ -613,7 +629,7 @@ def _validate_comfyui_args(self, workflow, tool):
raise Exception(f"Remap parameter {key} is missing original choices: {choices}")

def _inject_args_into_workflow(self, workflow, tool, args):

base_model = "unknown"
# Helper function to validate and normalize URLs
def validate_url(url):
if not isinstance(url, str):
Expand All @@ -622,16 +638,18 @@ def validate_url(url):
url = 'https://' + url
return url

pprint(args)
print("===== Injecting comfyui args into workflow =====")
pprint(args)

embedding_trigger = None
lora_trigger_text = None
embedding_triggers = {"lora": None, "lora2": None}
lora_trigger_texts = {"lora": None, "lora2": None}

# download and transport files
for key, param in tool.model.model_fields.items():
metadata = param.json_schema_extra or {}
file_type = metadata.get('file_type')
is_array = metadata.get('is_array')
print(f"Parsing {key}, param: {param}")

if file_type and any(t in ["image", "video", "audio"] for t in file_type.split("|")):
if not args.get(key):
Expand All @@ -648,22 +666,21 @@ def validate_url(url):

elif file_type == "lora":
lora_id = args.get(key)
print("LORA ID", lora_id)

if not lora_id:
args[key] = None
args["lora_strength"] = 0
print("REMOVE LORA")
args[f"{key}_strength"] = 0
print(f"DISABLING {key}")
continue

print("LORA ID", lora_id)
print(type(lora_id))

print(f"Found {key} LORA ID: ", lora_id)

models = get_collection("models3")
lora = models.find_one({"_id": ObjectId(lora_id)})
print("found lora", lora)
#print("found lora:\n", lora)

if not lora:
raise Exception(f"Lora {lora_id} not found")
raise Exception(f"Lora {key} with id: {lora_id} not found!")

base_model = lora.get("base_model")
lora_url = lora.get("checkpoint")
Expand All @@ -683,18 +700,13 @@ def validate_url(url):
lora_filename, embeddings_filename, embedding_trigger, lora_mode = self._transport_lora_sdxl(lora_url)
elif base_model == "flux-dev":
lora_filename = self._transport_lora_flux(lora_url)
embedding_trigger = lora.get("args", {}).get("name")
lora_trigger_text = lora.get("lora_trigger_text")
embedding_triggers[key] = lora.get("args", {}).get("name")
try:
lora_trigger_texts[key] = lora.get("lora_trigger_text")
except: # old flux LoRA's:
lora_trigger_texts[key] = lora.get("args", {}).get("caption_prefix")

args[key] = lora_filename
args["use_lora"] = True
print("lora filename", lora_filename)

# inject args
# comfyui_map = {
# param.name: param.comfyui
# for param in tool_.parameters if param.comfyui
# }

for key, comfyui in tool.comfyui_map.items():

Expand All @@ -706,22 +718,29 @@ def validate_url(url):
continue

# if there's a lora, replace mentions with embedding name
if key == "prompt" and embedding_trigger:
lora_strength = args.get("lora_strength", 0.5)
if base_model == "flux-dev":
print("INJECTING LORA TRIGGER TEXT", lora_trigger_text)
value = self._inject_embedding_mentions_flux(value, embedding_trigger, lora_trigger_text)
print("INJECTED LORA TRIGGER TEXT", value)
if key == "prompt":
if "flux" in base_model:
for lora_key in ["lora", "lora2"]:
if args.get(f"use_{lora_key}", False):
lora_strength = args.get(f"{lora_key}_strength", 0.7)
value = self._inject_embedding_mentions_flux(
value,
embedding_triggers[lora_key],
lora_trigger_texts[lora_key]
)
print(f"====> INJECTED {lora_key} TRIGGER TEXT", value)
elif base_model == "sdxl":
no_token_prompt, value = self._inject_embedding_mentions_sdxl(value, embedding_trigger, embeddings_filename, lora_mode, lora_strength)

if "no_token_prompt" in args:
no_token_mapping = next((comfy_param for key, comfy_param in tool.comfyui_map.items() if key == "no_token_prompt"), None)
if no_token_mapping:
print("Updating no_token_prompt for SDXL: ", no_token_prompt)
workflow[str(no_token_mapping.node_id)][no_token_mapping.field][no_token_mapping.subfield] = no_token_prompt

print("prompt updated:", value)
if embedding_trigger:
lora_strength = args.get("lora_strength", 0.7)
no_token_prompt, value = self._inject_embedding_mentions_sdxl(value, embedding_trigger, embeddings_filename, lora_mode, lora_strength)

if "no_token_prompt" in args:
no_token_mapping = next((comfy_param for key, comfy_param in tool.comfyui_map.items() if key == "no_token_prompt"), None)
if no_token_mapping:
print("Updating no_token_prompt for SDXL: ", no_token_prompt)
workflow[str(no_token_mapping.node_id)][no_token_mapping.field][no_token_mapping.subfield] = no_token_prompt

print("====> Final updated prompt for workflow: ", value)

if comfyui.preprocessing is not None:
if comfyui.preprocessing == "csv":
Expand Down
75 changes: 35 additions & 40 deletions eve/agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import yaml
import time
import json
import traceback
from pathlib import Path
from bson import ObjectId
from typing import Optional, Literal, Any, Dict, List
from typing import Optional, Literal, Any, Dict, List, ClassVar
from dotenv import dotenv_values
from pydantic import SecretStr, Field
from pydantic.json_schema import SkipJsonSchema
Expand All @@ -15,6 +16,8 @@
from .user import User, Manna
from .models import Model

CHECK_INTERVAL = 30 # how often to check cached agents for updates

default_presets_flux = {
"flux_dev_lora": {},
"runway": {},
Expand Down Expand Up @@ -44,24 +47,23 @@ class Agent(User):
name: str
description: str
instructions: str
# models: Optional[Dict[str, ObjectId]] = None
model: Optional[ObjectId] = None
test_args: Optional[List[Dict[str, Any]]] = None

tools: Optional[Dict[str, Dict]] = None
tools_cache: SkipJsonSchema[Optional[Dict[str, Tool]]] = Field(None, exclude=True)

last_check: ClassVar[Dict[str, float]] = {} # seconds

def __init__(self, **data):
    """Initialize an Agent, normalizing ids and loading per-agent secrets.

    Converts a string 'owner' into an ObjectId, then populates
    data['secrets'] from the agent's .env file located under
    agents/<db>/<username>/ next to this module.
    """
    if isinstance(data.get('owner'), str):
        data['owner'] = ObjectId(data['owner'])
    # Load environment variables into the secrets dictionary.
    # NOTE(review): assumes the DB env var is set — .lower() raises
    # AttributeError on None; confirm callers guarantee this.
    db = os.getenv("DB")
    env_dir = Path(__file__).parent / "agents"
    env_vars = dotenv_values(f"{str(env_dir)}/{db.lower()}/{data['username']}/.env")
    data['secrets'] = {key: SecretStr(value) for key, value in env_vars.items()}
    super().__init__(**data)

@classmethod
def convert_from_yaml(cls, schema: dict, file_path: str = None) -> dict:
"""
Expand Down Expand Up @@ -109,9 +111,11 @@ def from_yaml(cls, file_path, cache=False):
@classmethod
def from_mongo(cls, document_id, cache=False):
    """Load an Agent by Mongo document id.

    With cache=True, serve from the module-level _agent_cache, loading on
    first access, and re-check staleness on every hit (throttled inside
    _check_for_updates). With cache=False, always hit the database.
    """
    if cache:
        cache_key = str(document_id)  # renamed from 'id' to avoid shadowing the builtin
        if cache_key not in _agent_cache:
            _agent_cache[cache_key] = super().from_mongo(document_id)
        # Runs on every cached access; _check_for_updates rate-limits the DB query.
        cls._check_for_updates(cache_key, document_id)
        return _agent_cache[cache_key]
    else:
        return super().from_mongo(document_id)

Expand All @@ -120,6 +124,7 @@ def load(cls, username, cache=False):
if cache:
if username not in _agent_cache:
_agent_cache[username] = super().load(username=username)
cls._check_for_updates(username, _agent_cache[username].id)
return _agent_cache[username]
else:
return super().load(username=username)
Expand Down Expand Up @@ -184,7 +189,7 @@ def _setup_tools(cls, schema: dict) -> dict:

return schema

def get_tools(self,cache=False):
def get_tools(self, cache=False):
if not hasattr(self, "tools") or not self.tools:
self.tools = {}

Expand All @@ -204,23 +209,18 @@ def get_tools(self,cache=False):
def get_tool(self, tool_name, cache=False):
    """Return the single Tool named tool_name (raises KeyError if absent)."""
    return self.get_tools(cache=cache)[tool_name]

@classmethod
def _check_for_updates(cls, cache_key: str, agent_id: ObjectId):
    """Check if a cached agent needs to be reloaded, based on updatedAt.

    Throttled: the database is queried at most once per CHECK_INTERVAL
    seconds for each cache_key; cls.last_check tracks the last query time.
    """
    current_time = time.time()
    last_check = cls.last_check.get(cache_key, 0)

    if current_time - last_check >= CHECK_INTERVAL:
        cls.last_check[cache_key] = current_time
        collection = get_collection(cls.collection_name)
        db_agent = collection.find_one({"_id": agent_id})
        # Reload only when the stored document differs from the cached copy.
        if db_agent and db_agent.get("updatedAt") != _agent_cache[cache_key].updatedAt:
            _agent_cache[cache_key].reload()


def get_agents_from_mongo(agents: List[str] = None, include_inactive: bool = False) -> Dict[str, Agent]:
Expand All @@ -243,33 +243,28 @@ def get_agents_from_mongo(agents: List[str] = None, include_inactive: bool = Fal

return agents

def get_api_files(root_dir: str = None) -> Dict[str, str]:
    """Get all agent api.yaml files inside a directory.

    Walks root_dir (or the built-in agents/<db> folder next to this module
    when root_dir is None) and collects every directory that contains both
    an api.yaml and a test.json.

    Args:
        root_dir: Optional directory to scan instead of the default.

    Returns:
        Mapping of agent key (the directory name) to its api.yaml path.
        (Annotation fixed: the function builds and returns a dict, not a list.)
    """
    # NOTE(review): raises AttributeError when the DB env var is unset — confirm intended.
    db = os.getenv("DB").lower()

    if root_dir:
        root_dirs = [root_dir]
    else:
        eve_root = os.path.dirname(os.path.abspath(__file__))
        root_dirs = [
            os.path.join(eve_root, agents_dir)
            for agents_dir in [f"agents/{db}"]
        ]

    api_files = {}
    for root_dir in root_dirs:
        for root, _, files in os.walk(root_dir):
            # An agent directory must ship both its schema and a test case.
            if "api.yaml" in files and "test.json" in files:
                api_path = os.path.join(root, "api.yaml")
                key = os.path.relpath(root).split("/")[-1]
                api_files[key] = api_path

    return api_files

Expand Down
4 changes: 4 additions & 0 deletions eve/agents/prod/abraham/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ clients:
enabled: true
telegram:
enabled: true

deployments:
- discord
- telegram
2 changes: 2 additions & 0 deletions eve/agents/prod/eve/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ tools:
lora_trainer:
flux_trainer:
news:
websearch:
weather:
stable_audio:
musicgen:
audio_split_stems:
Expand Down
Loading
Loading