Skip to content

Commit

Permalink
nsfw detect
Browse files Browse the repository at this point in the history
  • Loading branch information
genekogan committed Jan 8, 2025
1 parent c6c9548 commit f6b619c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
17 changes: 15 additions & 2 deletions eve/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from eve.postprocessing import (
generate_lora_thumbnails,
cancel_stuck_tasks,
download_nsfw_models,
run_nsfw_detection
)
from eve.api.handlers import (
handle_create,
Expand Down Expand Up @@ -202,11 +204,18 @@ async def trigger_delete(
.env({"DB": db, "MODAL_SERVE": os.getenv("MODAL_SERVE")})
.apt_install("git", "libmagic1", "ffmpeg", "wget")
.pip_install_from_pyproject(str(root_dir / "pyproject.toml"))
.pip_install(
"numpy<2.0",
"torch==2.0.1",
"torchvision",
"transformers",
"Pillow"
)
.run_commands(["playwright install"])
.run_function(download_nsfw_models)
.copy_local_dir(str(workflows_dir), "/workflows")
)


@app.function(
image=image,
keep_warm=1,
Expand Down Expand Up @@ -234,8 +243,12 @@ async def postprocessing():
except Exception as e:
print(f"Error cancelling stuck tasks: {e}")

try:
await run_nsfw_detection()
except Exception as e:
print(f"Error running nsfw detection: {e}")

try:
await generate_lora_thumbnails()
except Exception as e:
print(f"Error generating lora thumbnails: {e}")

50 changes: 49 additions & 1 deletion eve/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,60 @@ def cancel_stuck_tasks():
tool.cancel(task, force=True)

except Exception as e:
print("Error canceling task", e)
print(f"Error canceling task {str(task.id)} {task.tool}", e)
task.update(status="failed", error="Tool not found")
sentry_sdk.capture_exception(e)
traceback.print_exc()


def download_nsfw_models():
from transformers import AutoModelForImageClassification, ViTImageProcessor
AutoModelForImageClassification.from_pretrained("Falconsai/nsfw_image_detection")
ViTImageProcessor.from_pretrained("Falconsai/nsfw_image_detection")


async def run_nsfw_detection():
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

model = AutoModelForImageClassification.from_pretrained(
"Falconsai/nsfw_image_detection",
cache_dir="model-cache",
)
processor = ViTImageProcessor.from_pretrained("Falconsai/nsfw_image_detection")

image_paths = [
"https://edenartlab-stage-data.s3.us-east-1.amazonaws.com/62946527441201f82e0e3d667fda480e176e9940a2e04f4e54c5230665dfc6f6.jpg",
"https://edenartlab-prod-data.s3.us-east-1.amazonaws.com/bb88e857586a358ce3f02f92911588207fbddeabff62a3d6a479517a646f053c.jpg"
]

images = [
Image.open(eden_utils.download_file(url, f"{i}.jpg")).convert('RGB')
for i, url in enumerate(image_paths)
]

with torch.no_grad():
inputs = processor(images=images, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits

print(logits)
predicted_labels = logits.argmax(-1).tolist()
print(predicted_labels)
output = [model.config.id2label[predicted_label] for predicted_label in predicted_labels]
print(output)

# Sort image paths based on safe logit values
first_logits = logits[:, 0].tolist()
sorted_pairs = sorted(zip(image_paths, first_logits), key=lambda x: x[1], reverse=True)
sorted_image_paths = [pair[0] for pair in sorted_pairs]

print("\nImage paths sorted by first logit value (highest to lowest):")
for i, path in enumerate(sorted_image_paths):
print(f"{i+1}. {path} (logit: {sorted_pairs[i][1]:.4f})")


async def generate_lora_thumbnails():
tasks = get_collection(Task.collection_name)
models = get_collection(Model.collection_name)
Expand Down

0 comments on commit f6b619c

Please sign in to comment.