Skip to content

Commit

Permalink
Merge pull request #44 from edenartlab/staging
Browse files Browse the repository at this point in the history
Staging and prod
  • Loading branch information
genekogan authored Jan 5, 2025
2 parents d7bde28 + 5270ba8 commit 9e24089
Show file tree
Hide file tree
Showing 61 changed files with 876 additions and 895 deletions.
56 changes: 33 additions & 23 deletions comfyui.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from eve.tool import Tool
from eve.mongo import get_collection
from eve.task import task_handler_method
from eve.s3 import get_full_url

GPUs = {
"A100": modal.gpu.A100(),
Expand Down Expand Up @@ -264,13 +265,13 @@ def _start(self, port=8188):
t2 = time.time()
self.launch_time = t2 - t1

def _execute(self, workflow_name: str, args: dict, db: str):
def _execute(self, workflow_name: str, args: dict):
try:
tool_path = f"/root/workspace/workflows/{workflow_name}"
tool = Tool.from_yaml(f"{tool_path}/api.yaml")
workflow = json.load(open(f"{tool_path}/workflow_api.json", 'r'))
self._validate_comfyui_args(workflow, tool)
workflow = self._inject_args_into_workflow(workflow, tool, args, db=db)
workflow = self._inject_args_into_workflow(workflow, tool, args)
prompt_id = self._queue_prompt(workflow)['prompt_id']
outputs = self._get_outputs(prompt_id)
output = outputs[str(tool.comfyui_output_node_id)]
Expand All @@ -290,14 +291,14 @@ def _execute(self, workflow_name: str, args: dict, db: str):
raise

@modal.method()
def run(self, tool_key: str, args: dict, db: str):
result = self._execute(tool_key, args, db=db)
return eden_utils.upload_result(result, db=db)
def run(self, tool_key: str, args: dict):
result = self._execute(tool_key, args)
return eden_utils.upload_result(result)

@modal.method()
@task_handler_method
async def run_task(self, tool_key: str, args: dict, db: str):
return self._execute(tool_key, args, db=db)
async def run_task(self, tool_key: str, args: dict):
return self._execute(tool_key, args)

@modal.enter()
def enter(self):
Expand Down Expand Up @@ -348,8 +349,8 @@ def test_workflows(self):
test_name = f"{workflow}_{os.path.basename(test)}"
print(f"Running test: {test_name}")
t1 = time.time()
result = self._execute(workflow, test_args, db="STAGE")
result = eden_utils.upload_result(result, db="STAGE")
result = self._execute(workflow, test_args)
result = eden_utils.upload_result(result)
t2 = time.time()
results[test_name] = result
results["_performance"][test_name] = t2 - t1
Expand Down Expand Up @@ -460,16 +461,16 @@ def _inject_embedding_mentions_sdxl(self, text, embedding_trigger, embeddings_fi

return user_prompt, lora_prompt

def _inject_embedding_mentions_flux(self, text, embedding_trigger, caption_prefix):
def _inject_embedding_mentions_flux(self, text, embedding_trigger, lora_trigger_text):
pattern = r'(<{0}>|<{1}>|{0}|{1})'.format(
re.escape(embedding_trigger),
re.escape(embedding_trigger.lower())
)
text = re.sub(pattern, caption_prefix, text, flags=re.IGNORECASE)
text = re.sub(r'(<concept>)', caption_prefix, text, flags=re.IGNORECASE)
text = re.sub(pattern, lora_trigger_text, text, flags=re.IGNORECASE)
text = re.sub(r'(<concept>)', lora_trigger_text, text, flags=re.IGNORECASE)

if caption_prefix not in text: # Make sure the concept is always triggered:
text = f"{caption_prefix}, {text}"
if lora_trigger_text not in text: # Make sure the concept is always triggered:
text = f"{lora_trigger_text}, {text}"

return text

Expand Down Expand Up @@ -611,7 +612,7 @@ def _validate_comfyui_args(self, workflow, tool):
if not all(choice in remap.map.keys() for choice in choices):
raise Exception(f"Remap parameter {key} is missing original choices: {choices}")

def _inject_args_into_workflow(self, workflow, tool, args, db="STAGE"):
def _inject_args_into_workflow(self, workflow, tool, args):

# Helper function to validate and normalize URLs
def validate_url(url):
Expand All @@ -624,7 +625,7 @@ def validate_url(url):
pprint(args)

embedding_trigger = None
caption_prefix = None
lora_trigger_text = None

# download and transport files
for key, param in tool.model.model_fields.items():
Expand Down Expand Up @@ -653,14 +654,18 @@ def validate_url(url):
args["lora_strength"] = 0
print("REMOVE LORA")
continue

print("LORA ID", lora_id)
print(type(lora_id))

models = get_collection("models", db=db)
models = get_collection("models3")
lora = models.find_one({"_id": ObjectId(lora_id)})
base_model = lora.get("base_model")
print("LORA", lora)
print("found lora", lora)

if not lora:
raise Exception(f"Lora {lora_id} not found")

base_model = lora.get("base_model")
lora_url = lora.get("checkpoint")
#lora_name = lora.get("name")
#pretrained_model = lora.get("args").get("sd_model_version")
Expand All @@ -670,17 +675,20 @@ def validate_url(url):
else:
print("LORA URL", lora_url)

lora_url = get_full_url(lora_url)
print("lora url", lora_url)
print("base model", base_model)

if base_model == "sdxl":
lora_filename, embeddings_filename, embedding_trigger, lora_mode = self._transport_lora_sdxl(lora_url)
elif base_model == "flux-dev":
lora_filename = self._transport_lora_flux(lora_url)
embedding_trigger = lora.get("args", {}).get("name")
caption_prefix = lora.get("args", {}).get("caption_prefix")
lora_trigger_text = lora.get("lora_trigger_text")

args[key] = lora_filename
print("lora filename", lora_filename)
args["use_lora"] = True
print("lora filename", lora_filename)

# inject args
# comfyui_map = {
Expand All @@ -700,8 +708,10 @@ def validate_url(url):
# if there's a lora, replace mentions with embedding name
if key == "prompt" and embedding_trigger:
lora_strength = args.get("lora_strength", 0.5)
if base_model == "flux_dev":
value = self._inject_embedding_mentions_flux(value, embedding_trigger, caption_prefix)
if base_model == "flux-dev":
print("INJECTING LORA TRIGGER TEXT", lora_trigger_text)
value = self._inject_embedding_mentions_flux(value, embedding_trigger, lora_trigger_text)
print("INJECTED LORA TRIGGER TEXT", value)
elif base_model == "sdxl":
no_token_prompt, value = self._inject_embedding_mentions_sdxl(value, embedding_trigger, embeddings_filename, lora_mode, lora_strength)

Expand Down
73 changes: 54 additions & 19 deletions eve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,63 @@
import os

home_dir = str(Path.home())
eve_path = os.path.join(home_dir, ".eve")
env_path = ".env"

# First try ENV_PATH from environment
env_path_override = os.getenv("ENV_PATH")
if env_path_override and os.path.exists(env_path_override):
load_dotenv(env_path_override)

# Then try ~/.eve
if os.path.exists(eve_path):
load_dotenv(eve_path, override=True)
EDEN_API_KEY = None

# Finally fall back to .env
if os.path.exists(env_path):
load_dotenv(env_path, override=True)

# start sentry
sentry_dsn = os.getenv("SENTRY_DSN")
sentry_sdk.init(dsn=sentry_dsn, traces_sample_rate=1.0, profiles_sample_rate=1.0)
def load_env(db):
global EDEN_API_KEY

# load api keys
EDEN_API_KEY = SecretStr(os.getenv("EDEN_API_KEY", ""))
db = db.upper()
if db not in ["STAGE", "PROD"]:
raise ValueError(f"Invalid database: {db}")

os.environ["DB"] = db

if not EDEN_API_KEY:
print("WARNING: EDEN_API_KEY is not set")
# First try ~/.eve
stage = db == "STAGE"
env_file = ".env.STAGE" if stage else ".env"
eve_file = ".eve.STAGE" if stage else ".eve"
eve_path = os.path.join(home_dir, eve_file)

if os.path.exists(eve_path):
load_dotenv(eve_path, override=True)

# Then try ENV_PATH from environment or .env
env_path_override = os.getenv("ENV_PATH")
if env_path_override and os.path.exists(env_path_override):
load_dotenv(env_path_override, override=True)
elif os.path.exists(env_file):
load_dotenv(env_file, override=True)

# start sentry
sentry_dsn = os.getenv("SENTRY_DSN")
sentry_sdk.init(dsn=sentry_dsn, traces_sample_rate=1.0, profiles_sample_rate=1.0)

# load api keys
EDEN_API_KEY = SecretStr(os.getenv("EDEN_API_KEY", ""))

if not EDEN_API_KEY:
print("WARNING: EDEN_API_KEY is not set")

verify_env()


def verify_env():
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION_NAME = os.getenv("AWS_REGION_NAME")
MONGO_URI = os.getenv("MONGO_URI")

if not all([AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME]):
print(
"WARNING: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_REGION_NAME must be set in the environment"
)

if not MONGO_URI:
print("WARNING: MONGO_URI must be set in the environment")


db = os.getenv("DB", "STAGE")
load_env(db)
Loading

0 comments on commit 9e24089

Please sign in to comment.