Commit
Merge branch 'staging' of https://github.com/edenartlab/eve into staging
genekogan committed Jan 5, 2025
2 parents 89dd8dc + fcd3ac5 commit 8fa9e97
Showing 1 changed file with 66 additions and 48 deletions.
114 changes: 66 additions & 48 deletions comfyui.py
@@ -65,6 +65,14 @@
test_all = True if os.getenv("TEST_ALL") else False
skip_tests = os.getenv("SKIP_TESTS")

print("========================================")
print(f"db: {db}")
print(f"workspace: {workspace_name}")
print(f"test_workflows: {test_workflows}")
print(f"test_all: {test_all}")
print(f"skip_tests: {skip_tests}")
print("========================================")

def install_comfyui():
    snapshot = json.load(open("/root/workspace/snapshot.json", 'r'))
    comfyui_commit_sha = snapshot["comfyui"]
@@ -218,6 +226,7 @@ def download_files(force_redownload=False):
.env({"TEST_ALL": os.getenv("TEST_ALL")})
.apt_install("git", "git-lfs", "libgl1-mesa-glx", "libglib2.0-0", "libmagic1", "ffmpeg", "libegl1")
.pip_install_from_pyproject(str(root_dir / "pyproject.toml"))
.pip_install("diffusers==0.31.0")
.env({"WORKSPACE": workspace_name})
.copy_local_file(f"{root_workflows_folder}/workspaces/{workspace_name}/snapshot.json", "/root/workspace/snapshot.json")
.copy_local_file(f"{root_workflows_folder}/workspaces/{workspace_name}/downloads.json", "/root/workspace/downloads.json")
@@ -328,6 +337,9 @@ def test_workflows(self):
            if not all([w in workflow_names for w in test_workflows]):
                raise Exception(f"One or more invalid workflows found: {', '.join(test_workflows)}")
            workflow_names = test_workflows
            print(f"====> Running tests for subset of workflows: {' | '.join(workflow_names)}")
        else:
            print(f"====> Running tests for all workflows: {' | '.join(workflow_names)}")

        if not workflow_names:
            raise Exception("No workflows found!")
@@ -338,7 +350,7 @@ def test_workflows(self):
                tests = glob.glob(f"/root/workspace/workflows/{workflow}/test*.json")
            else:
                tests = [f"/root/workspace/workflows/{workflow}/test.json"]
            print("Running tests: ", tests)
            print(f"====> Running tests for {workflow}: ", tests)
            for test in tests:
                tool = Tool.from_yaml(f"/root/workspace/workflows/{workflow}/api.yaml")
                if tool.status == "inactive":
@@ -347,11 +359,12 @@ def test_workflows(self):
                test_args = json.loads(open(test, "r").read())
                test_args = tool.prepare_args(test_args)
                test_name = f"{workflow}_{os.path.basename(test)}"
                print(f"Running test: {test_name}")
                print(f"====> Running test: {test_name}")
                t1 = time.time()
                result = self._execute(workflow, test_args)
                result = eden_utils.upload_result(result)
                result = eden_utils.prepare_result(result)
                print(f"====> Final media url: {result}")
                t2 = time.time()
                results[test_name] = result
                results["_performance"][test_name] = t2 - t1
@@ -461,20 +474,22 @@ def _inject_embedding_mentions_sdxl(self, text, embedding_trigger, embeddings_fi
            lora_prompt = f"{reference}, {lora_prompt}"

        return user_prompt, lora_prompt

    def _inject_embedding_mentions_flux(self, text, embedding_trigger, lora_trigger_text):
        pattern = r'(<{0}>|<{1}>|{0}|{1})'.format(
            re.escape(embedding_trigger),
            re.escape(embedding_trigger.lower())
        )
        text = re.sub(pattern, lora_trigger_text, text, flags=re.IGNORECASE)
        text = re.sub(r'(<concept>)', lora_trigger_text, text, flags=re.IGNORECASE)

        if lora_trigger_text not in text: # Make sure the concept is always triggered:
        if not embedding_trigger: # Handles both None and empty string
            text = re.sub(r'(<concept>)', lora_trigger_text, text, flags=re.IGNORECASE)
        else:
            pattern = r'(<{0}>|<{1}>|{0}|{1})'.format(
                re.escape(embedding_trigger),
                re.escape(embedding_trigger.lower())
            )
            text = re.sub(pattern, lora_trigger_text, text, flags=re.IGNORECASE)
            text = re.sub(r'(<concept>)', lora_trigger_text, text, flags=re.IGNORECASE)

        if lora_trigger_text not in text:
            text = f"{lora_trigger_text}, {text}"

        return text


    def _transport_lora_flux(self, lora_url: str):
        loras_folder = "/root/models/loras"
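For readers skimming the hunk above: _inject_embedding_mentions_flux now branches on whether an embedding trigger exists at all, instead of always building the trigger regex. The standalone sketch below reproduces that behavior outside the class; the example prompts and trigger names are invented for illustration.

# sketch of the updated flux trigger injection (standalone, illustrative only)
import re

def inject_embedding_mentions_flux(text, embedding_trigger, lora_trigger_text):
    if not embedding_trigger:  # no named trigger: only the <concept> placeholder is replaced
        text = re.sub(r'(<concept>)', lora_trigger_text, text, flags=re.IGNORECASE)
    else:                      # named trigger: replace <Trigger>, <trigger>, Trigger, trigger
        pattern = r'(<{0}>|<{1}>|{0}|{1})'.format(
            re.escape(embedding_trigger),
            re.escape(embedding_trigger.lower())
        )
        text = re.sub(pattern, lora_trigger_text, text, flags=re.IGNORECASE)
        text = re.sub(r'(<concept>)', lora_trigger_text, text, flags=re.IGNORECASE)
    if lora_trigger_text not in text:  # guarantee the concept is always triggered
        text = f"{lora_trigger_text}, {text}"
    return text

# hypothetical examples
print(inject_embedding_mentions_flux("a photo of <Banny> at the beach", "Banny", "banny the banana"))
# -> "a photo of banny the banana at the beach"
print(inject_embedding_mentions_flux("<concept> riding a bike", "", "banny the banana"))
# -> "banny the banana riding a bike"
print(inject_embedding_mentions_flux("a quiet landscape", "Banny", "banny the banana"))
# -> "banny the banana, a quiet landscape"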
@@ -623,16 +638,18 @@ def validate_url(url):
                url = 'https://' + url
            return url

        pprint(args)
        print("===== Injecting comfyui args into workflow =====")
        pprint(args)

        embedding_trigger = None
        lora_trigger_text = None
        embedding_triggers = {"lora": None, "lora2": None}
        lora_trigger_texts = {"lora": None, "lora2": None}

        # download and transport files
        for key, param in tool.model.model_fields.items():
            metadata = param.json_schema_extra or {}
            file_type = metadata.get('file_type')
            is_array = metadata.get('is_array')
            print(f"Parsing {key}, param: {param}")

            if file_type and any(t in ["image", "video", "audio"] for t in file_type.split("|")):
                if not args.get(key):
@@ -649,22 +666,21 @@

elif file_type == "lora":
lora_id = args.get(key)
print("LORA ID", lora_id)

if not lora_id:
args[key] = None
args["lora_strength"] = 0
print("REMOVE LORA")
args[f"{key}_strength"] = 0
print(f"DISABLING {key}")
continue

print("LORA ID", lora_id)
print(type(lora_id))

print(f"Found {key} LORA ID: ", lora_id)

models = get_collection("models3")
lora = models.find_one({"_id": ObjectId(lora_id)})
print("found lora", lora)
#print("found lora:\n", lora)

if not lora:
raise Exception(f"Lora {lora_id} not found")
raise Exception(f"Lora {key} with id: {lora_id} not found!")

base_model = lora.get("base_model")
lora_url = lora.get("checkpoint")
@@ -684,18 +700,13 @@
                    lora_filename, embeddings_filename, embedding_trigger, lora_mode = self._transport_lora_sdxl(lora_url)
                elif base_model == "flux-dev":
                    lora_filename = self._transport_lora_flux(lora_url)
                    embedding_trigger = lora.get("args", {}).get("name")
                    lora_trigger_text = lora.get("lora_trigger_text")
                    embedding_triggers[key] = lora.get("args", {}).get("name")
                    try:
                        lora_trigger_texts[key] = lora.get("lora_trigger_text")
                    except: # old flux LoRA's:
                        lora_trigger_texts[key] = lora.get("args", {}).get("caption_prefix")

                args[key] = lora_filename
                args["use_lora"] = True
                print("lora filename", lora_filename)

        # inject args
        # comfyui_map = {
        #     param.name: param.comfyui
        #     for param in tool_.parameters if param.comfyui
        # }

        for key, comfyui in tool.comfyui_map.items():

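One detail in the hunk above: trigger text for the current LoRA slot is read from lora_trigger_text, with a fallback to args.caption_prefix for older flux LoRA records. The sketch below expresses the same lookup as a small helper; the document shapes are hypothetical and only the field names come from the diff.

# sketch: resolving trigger text for new vs. old flux LoRA records (illustrative only)
def resolve_trigger_text(lora: dict):
    # newer records store the trigger text directly
    trigger_text = lora.get("lora_trigger_text")
    if trigger_text:
        return trigger_text
    # older flux LoRAs: fall back to the training caption prefix
    return lora.get("args", {}).get("caption_prefix")

# hypothetical documents
new_style = {"lora_trigger_text": "banny the banana", "args": {"name": "Banny"}}
old_style = {"args": {"name": "Banny", "caption_prefix": "in the style of banny"}}

print(resolve_trigger_text(new_style))  # -> "banny the banana"
print(resolve_trigger_text(old_style))  # -> "in the style of banny"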
@@ -707,22 +718,29 @@
                continue

            # if there's a lora, replace mentions with embedding name
            if key == "prompt" and embedding_trigger:
                lora_strength = args.get("lora_strength", 0.5)
                if base_model == "flux-dev":
                    print("INJECTING LORA TRIGGER TEXT", lora_trigger_text)
                    value = self._inject_embedding_mentions_flux(value, embedding_trigger, lora_trigger_text)
                    print("INJECTED LORA TRIGGER TEXT", value)
            if key == "prompt":
                if "flux" in base_model:
                    for lora_key in ["lora", "lora2"]:
                        if args.get(f"use_{lora_key}", False):
                            lora_strength = args.get(f"{lora_key}_strength", 0.7)
                            value = self._inject_embedding_mentions_flux(
                                value,
                                embedding_triggers[lora_key],
                                lora_trigger_texts[lora_key]
                            )
                            print(f"====> INJECTED {lora_key} TRIGGER TEXT", value)
                elif base_model == "sdxl":
                    no_token_prompt, value = self._inject_embedding_mentions_sdxl(value, embedding_trigger, embeddings_filename, lora_mode, lora_strength)

                    if "no_token_prompt" in args:
                        no_token_mapping = next((comfy_param for key, comfy_param in tool.comfyui_map.items() if key == "no_token_prompt"), None)
                        if no_token_mapping:
                            print("Updating no_token_prompt for SDXL: ", no_token_prompt)
                            workflow[str(no_token_mapping.node_id)][no_token_mapping.field][no_token_mapping.subfield] = no_token_prompt

                    print("prompt updated:", value)
                    if embedding_trigger:
                        lora_strength = args.get("lora_strength", 0.7)
                        no_token_prompt, value = self._inject_embedding_mentions_sdxl(value, embedding_trigger, embeddings_filename, lora_mode, lora_strength)

                        if "no_token_prompt" in args:
                            no_token_mapping = next((comfy_param for key, comfy_param in tool.comfyui_map.items() if key == "no_token_prompt"), None)
                            if no_token_mapping:
                                print("Updating no_token_prompt for SDXL: ", no_token_prompt)
                                workflow[str(no_token_mapping.node_id)][no_token_mapping.field][no_token_mapping.subfield] = no_token_prompt

            print("====> Final updated prompt for workflow: ", value)

            if comfyui.preprocessing is not None:
                if comfyui.preprocessing == "csv":
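Taken together, the last hunk swaps the single embedding_trigger path for a loop over two LoRA slots, each with its own use_* flag, strength, trigger, and trigger text. A condensed, self-contained sketch of that control flow follows; the args values and the simplified inject helper are hypothetical stand-ins for the class method shown earlier.

# sketch: per-slot trigger injection for flux prompts (illustrative only)
import re

def inject(text, trigger, trigger_text):
    # simplified stand-in for _inject_embedding_mentions_flux
    if trigger:
        text = re.sub(re.escape(trigger), trigger_text, text, flags=re.IGNORECASE)
    if trigger_text not in text:
        text = f"{trigger_text}, {text}"
    return text

# hypothetical state assembled while parsing the tool's lora parameters
args = {"use_lora": True, "lora_strength": 0.8, "use_lora2": False}
embedding_triggers = {"lora": "Banny", "lora2": None}
lora_trigger_texts = {"lora": "banny the banana", "lora2": None}

base_model = "flux-dev"
value = "a portrait of Banny in a forest"
if "flux" in base_model:                      # base_model check, as in the diff
    for lora_key in ["lora", "lora2"]:
        if args.get(f"use_{lora_key}", False):
            lora_strength = args.get(f"{lora_key}_strength", 0.7)   # per-slot strength with default
            value = inject(value, embedding_triggers[lora_key], lora_trigger_texts[lora_key])
            print(f"====> INJECTED {lora_key} TRIGGER TEXT", value)
# -> only the first slot fires; lora2 is skipped because use_lora2 is False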
