diff --git a/eve/api/api.py b/eve/api/api.py index 6b509a9..ecb0abb 100644 --- a/eve/api/api.py +++ b/eve/api/api.py @@ -18,6 +18,8 @@ from eve.postprocessing import ( generate_lora_thumbnails, cancel_stuck_tasks, + download_nsfw_models, + run_nsfw_detection ) from eve.api.handlers import ( handle_create, @@ -202,11 +204,18 @@ async def trigger_delete( .env({"DB": db, "MODAL_SERVE": os.getenv("MODAL_SERVE")}) .apt_install("git", "libmagic1", "ffmpeg", "wget") .pip_install_from_pyproject(str(root_dir / "pyproject.toml")) + .pip_install( + "numpy<2.0", + "torch==2.0.1", + "torchvision", + "transformers", + "Pillow" + ) .run_commands(["playwright install"]) + .run_function(download_nsfw_models) .copy_local_dir(str(workflows_dir), "/workflows") ) - @app.function( image=image, keep_warm=1, @@ -234,8 +243,12 @@ async def postprocessing(): except Exception as e: print(f"Error cancelling stuck tasks: {e}") + try: + await run_nsfw_detection() + except Exception as e: + print(f"Error running nsfw detection: {e}") + try: await generate_lora_thumbnails() except Exception as e: print(f"Error generating lora thumbnails: {e}") - diff --git a/eve/postprocessing.py b/eve/postprocessing.py index 43cc58e..72f0a47 100644 --- a/eve/postprocessing.py +++ b/eve/postprocessing.py @@ -46,12 +46,60 @@ def cancel_stuck_tasks(): tool.cancel(task, force=True) except Exception as e: - print("Error canceling task", e) + print(f"Error canceling task {str(task.id)} {task.tool}", e) task.update(status="failed", error="Tool not found") sentry_sdk.capture_exception(e) traceback.print_exc() +def download_nsfw_models(): + from transformers import AutoModelForImageClassification, ViTImageProcessor + AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection") + ViTImageProcessor.from_pretrained("Falconsai/nsfw_image_detection") + + +async def run_nsfw_detection(): + import torch + from PIL import Image + from transformers import AutoModelForImageClassification, ViTImageProcessor + + model = AutoModelForImageClassification.from_pretrained( + "Falconsai/nsfw_image_detection", + cache_dir="model-cache", + ) + processor = ViTImageProcessor.from_pretrained("Falconsai/nsfw_image_detection") + + image_paths = [ + "https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/62946527441201f82e0e3d667fda480e176e9940a2e04f4e54c5230665dfc6f6.jpg", + "https://edenartlab-prod-data.s3.us-east-1.amazonaws.com/bb88e857586a358ce3f02f92911588207fbddeabff62a3d6a479517a646f053c.jpg" + ] + + images = [ + Image.open(eden_utils.download_file(url, f"{i}.jpg")).convert('RGB') + for i, url in enumerate(image_paths) + ] + + with torch.no_grad(): + inputs = processor(images=images, return_tensors="pt") + outputs = model(**inputs) + logits = outputs.logits + + print(logits) + predicted_labels = logits.argmax(-1).tolist() + print(predicted_labels) + output = [model.config.id2label[predicted_label] for predicted_label in predicted_labels] + print(output) + + # Sort image paths based on safe logit values + first_logits = logits[:, 0].tolist() + sorted_pairs = sorted(zip(image_paths, first_logits), key=lambda x: x[1], reverse=True) + sorted_image_paths = [pair[0] for pair in sorted_pairs] + + print("\nImage paths sorted by first logit value (highest to lowest):") + for i, path in enumerate(sorted_image_paths): + print(f"{i+1}. {path} (logit: {sorted_pairs[i][1]:.4f})") + + async def generate_lora_thumbnails(): tasks = get_collection(Task.collection_name) models = get_collection(Model.collection_name)