Skip to content

Commit

Permalink
Merge pull request #13 from edenartlab/stage
Browse files Browse the repository at this point in the history
refactor, stories, reliability
  • Loading branch information
genekogan authored Jan 16, 2024
2 parents e14c4fb + 826f89d commit 882d0b5
Show file tree
Hide file tree
Showing 28 changed files with 572 additions and 646 deletions.
20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,24 @@ Run

`rye run uvicorn app.server:app --reload`

Send a command to the server


curl -X POST 'http://localhost:5050/tasks/create' \
-H 'x-api-key: YOUR_API_KEY' \
-H 'x-api-secret: YOUR_API_SECRET' \
-H 'Content-Type: application/json' \
-d '{
"generatorName": "monologue",
"config": {
"characterId": "6577e5d5c77b37642c252423",
"prompt": "who are you?"
}
}'



Tests

`rye run pytest -s tests`
`rye run pytest -s tests`

3 changes: 3 additions & 0 deletions app/animations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .monologue import animated_monologue
from .dialogue import animated_dialogue
from .story import animated_story
48 changes: 48 additions & 0 deletions app/animations/animation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Optional

from ..plugins import replicate, elevenlabs, s3
from ..character import EdenCharacter
from ..utils import combine_speech_video


def talking_head(
character: EdenCharacter,
text: str,
width: Optional[int] = None,
height: Optional[int] = None
) -> str:
audio_bytes = elevenlabs.tts(
text,
voice=character.voice
)
audio_url = s3.upload(audio_bytes, "mp3")
output_url, thumbnail_url = replicate.wav2lip(
face_url=character.image,
speech_url=audio_url,
gfpgan=False,
gfpgan_upscale=1,
width=width,
height=height,
)
return output_url, thumbnail_url


def screenplay_clip(
character: EdenCharacter,
speech: str,
image_text: str,
width: Optional[int] = None,
height: Optional[int] = None
) -> str:
audio_bytes = elevenlabs.tts(
speech,
voice=character.voice
)
audio_url = s3.upload(audio_bytes, "mp3")
video_url, thumbnail_url = replicate.txt2vid(
interpolation_texts=[image_text],
width=width,
height=height,
)
output_filename = combine_speech_video(audio_url, video_url)
return output_filename, thumbnail_url
67 changes: 67 additions & 0 deletions app/animations/dialogue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import requests
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed

from .. import utils
from .animation import talking_head
from ..plugins import replicate, elevenlabs, s3
from ..character import EdenCharacter
from ..scenarios import dialogue
from ..models import DialogueRequest

MAX_PIXELS = 1024 * 1024
MAX_WORKERS = 3


def animated_dialogue(request: DialogueRequest):
result = dialogue(request)
print(result)

characters = {
character_id: EdenCharacter(character_id)
for character_id in request.character_ids
}
images = [
characters[character_id].image
for character_id in request.character_ids
]
width, height = utils.calculate_target_dimensions(images, MAX_PIXELS)

def run_talking_head_segment(message):
character = characters[message["character_id"]]
output, _ = talking_head(
character,
message["message"],
width,
height
)
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
response = requests.get(output, stream=True)
response.raise_for_status()
for chunk in response.iter_content(chunk_size=8192):
temp_file.write(chunk)
temp_file.flush()
return temp_file.name

video_files = utils.process_in_parallel(
result.dialogue,
run_talking_head_segment,
max_workers=MAX_WORKERS
)

# concatenate the final video clips
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_output_file:
utils.concatenate_videos(video_files, temp_output_file.name)
with open(temp_output_file.name, 'rb') as f:
video_bytes = f.read()
output_url = s3.upload(video_bytes, "mp4")
os.remove(temp_output_file.name)
for video_file in video_files:
os.remove(video_file)

# generate thumbnail
thumbnail = utils.create_dialogue_thumbnail(*images, width, height)
thumbnail_url = s3.upload(thumbnail, "webp")

return output_url, thumbnail_url
17 changes: 17 additions & 0 deletions app/animations/monologue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import requests

from .animation import talking_head
from ..plugins import replicate, elevenlabs, s3
from ..character import EdenCharacter
from ..scenarios import monologue
from ..models import MonologueRequest


def animated_monologue(request: MonologueRequest):
character = EdenCharacter(request.character_id)
result = monologue(request)
print(result)
output, thumbnail_url = talking_head(character, result.monologue)
output_bytes = requests.get(output).content
output_url = s3.upload(output_bytes, "mp4")
return output_url, thumbnail_url
71 changes: 71 additions & 0 deletions app/animations/story.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os
import requests
import tempfile

from .. import utils
from ..plugins import replicate, elevenlabs, s3
from ..character import EdenCharacter
from ..scenarios import story
from ..models import StoryRequest, StoryResult
from .animation import screenplay_clip

MAX_PIXELS = 1024 * 1024
MAX_WORKERS = 3


def animated_story(request: StoryRequest):
screenplay = story(request)
print(screenplay)

characters = {
character_id: EdenCharacter(character_id)
for character_id in request.character_ids + [request.narrator_id]
}

character_name_lookup = {
character.name: character_id
for character_id, character in characters.items()
}

images = [
characters[character_id].image
for character_id in request.character_ids
]

width, height = utils.calculate_target_dimensions(images, MAX_PIXELS)

def run_story_segment(clip):
if clip['voiceover'] == 'character':
character_id = character_name_lookup[clip['character']]
character = characters[character_id]
else:
character = characters[request.narrator_id]
output_filename, thumbnail_url = screenplay_clip(
character,
clip['speech'],
clip['image_description'],
width,
height
)
return output_filename, thumbnail_url

results = utils.process_in_parallel(
screenplay['clips'],
run_story_segment,
max_workers=MAX_WORKERS
)

video_files = [video_file for video_file, thumbnail in results]
thumbnail_url = results[0][1]

with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_output_file:
utils.concatenate_videos(video_files, temp_output_file.name)
with open(temp_output_file.name, 'rb') as f:
video_bytes = f.read()
output_url = s3.upload(video_bytes, "mp4")

# clean up clips
for video_file in video_files:
os.remove(video_file)

return output_url, thumbnail_url
3 changes: 2 additions & 1 deletion app/character.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from pydantic import Field, BaseModel, ValidationError

from .mongo import get_character_data
from .routers.tasks import summary, SummaryRequest
from .scenarios.tasks import summary
from .llm import LLM
from .llm.models import ChatMessage
from .models import SummaryRequest
from .prompt_templates.assistant import (
identity_template,
reply_template,
Expand Down
49 changes: 19 additions & 30 deletions app/routers/generator.py → app/generator.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,28 @@
from typing import Optional, List
from fastapi import APIRouter, Request, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from fastapi import BackgroundTasks
import uuid
import traceback

import requests

from .dags import monologue_dag, dialogue_dag
from .story import cinema
from ..mongo import get_character_data
from ..llm import LLM
from ..prompt_templates import monologue_template, dialogue_template

from ..models import MonologueRequest, MonologueOutput
from ..models import DialogueRequest, DialogueOutput, CinemaRequest
from ..models import TaskRequest, TaskUpdate, TaskOutput

router = APIRouter()
from .animations import animated_monologue, animated_dialogue, animated_story
from .models import MonologueRequest, MonologueResult
from .models import DialogueRequest, DialogueResult, StoryRequest
from .models import TaskRequest, TaskUpdate, TaskResult


def process_task(task_id: str, request: TaskRequest, task_type: str):
def process_task(task_id: str, request: TaskRequest):
print("config", request.config)

task_type = request.generatorName
webhook_url = request.webhookUrl

update = TaskUpdate(
id=task_id,
output=TaskOutput(progress=0),
output=TaskResult(progress=0),
status="processing",
error=None,
)

requests.post(webhook_url, json=update.dict())
if webhook_url:
requests.post(webhook_url, json=update.dict())

try:
if task_type == "monologue":
Expand All @@ -42,7 +32,7 @@ def process_task(task_id: str, request: TaskRequest, task_type: str):
character_id=character_id,
prompt=prompt,
)
output_url, thumbnail_url = monologue_dag(task_req)
output_url, thumbnail_url = animated_monologue(task_req)

elif task_type == "dialogue":
character_ids = request.config.get("characterIds")
Expand All @@ -51,19 +41,18 @@ def process_task(task_id: str, request: TaskRequest, task_type: str):
character_ids=character_ids,
prompt=prompt,
)
output_url, thumbnail_url = dialogue_dag(task_req)
output_url, thumbnail_url = animated_dialogue(task_req)

elif task_type == "story":
character_ids = request.config.get("characterIds")
prompt = request.config.get("prompt")
task_req = CinemaRequest(
task_req = StoryRequest(
character_ids=character_ids,
prompt=prompt,
)
output_url = cinema(task_req)
thumbnail_url = "https://edenartlab-prod-data.s3.us-east-1.amazonaws.com/e745b8c200bb10efe744caa800c7c7f89c3ae05c39fa4aa0595bdd138117c592.png"
output_url, thumbnail_url = animated_story(task_req)

output = TaskOutput(
output = TaskResult(
files=[output_url],
thumbnails=[thumbnail_url],
name=prompt,
Expand All @@ -75,7 +64,7 @@ def process_task(task_id: str, request: TaskRequest, task_type: str):
error = None

except Exception as e:
output = TaskOutput(
output = TaskResult(
files=[],
thumbnails=[],
name=prompt,
Expand All @@ -94,12 +83,12 @@ def process_task(task_id: str, request: TaskRequest, task_type: str):
)
print("update", update.dict())

requests.post(webhook_url, json=update.dict())
if webhook_url:
requests.post(webhook_url, json=update.dict())


@router.post("/tasks/create")
async def generate_task(background_tasks: BackgroundTasks, request: TaskRequest):
task_id = str(uuid.uuid4())
if request.generatorName in ["monologue", "dialogue", "story"]:
background_tasks.add_task(process_task, task_id, request, request.generatorName)
background_tasks.add_task(process_task, task_id, request)
return {"id": task_id}
Loading

0 comments on commit 882d0b5

Please sign in to comment.