Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor and add stories #12

Merged
merged 1 commit into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,24 @@ Run

`rye run uvicorn app.server:app --reload`

Send a command to the server


curl -X POST 'http://localhost:5050/tasks/create' \
-H 'x-api-key: YOUR_API_KEY' \
-H 'x-api-secret: YOUR_API_SECRET' \
-H 'Content-Type: application/json' \
-d '{
"generatorName": "monologue",
"config": {
"characterId": "6577e5d5c77b37642c252423",
"prompt": "who are you?"
}
}'



Tests

`rye run pytest -s tests`
`rye run pytest -s tests`

3 changes: 3 additions & 0 deletions app/animations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .monologue import animated_monologue
from .dialogue import animated_dialogue
from .story import animated_story
48 changes: 48 additions & 0 deletions app/animations/animation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Optional, Tuple

from ..plugins import replicate, elevenlabs, s3
from ..character import EdenCharacter
from ..utils import combine_speech_video


def talking_head(
    character: EdenCharacter,
    text: str,
    width: Optional[int] = None,
    height: Optional[int] = None,
) -> Tuple[str, str]:
    """Generate a lip-synced talking-head video of ``character`` speaking ``text``.

    The text is synthesized to speech with the character's voice, the audio is
    uploaded to S3, and the resulting URL is passed to the wav2lip model
    together with the character's image.

    Args:
        character: Character supplying the ``voice`` and ``image`` to use.
        text: The line to be spoken.
        width: Optional output width forwarded to wav2lip.
        height: Optional output height forwarded to wav2lip.

    Returns:
        The ``(output_url, thumbnail_url)`` pair produced by
        ``replicate.wav2lip``.
    """
    audio_bytes = elevenlabs.tts(
        text,
        voice=character.voice,
    )
    audio_url = s3.upload(audio_bytes, "mp3")
    output_url, thumbnail_url = replicate.wav2lip(
        face_url=character.image,
        speech_url=audio_url,
        gfpgan=False,
        gfpgan_upscale=1,
        width=width,
        height=height,
    )
    return output_url, thumbnail_url


def screenplay_clip(
    character: EdenCharacter,
    speech: str,
    image_text: str,
    width: Optional[int] = None,
    height: Optional[int] = None,
) -> Tuple[str, str]:
    """Build one story clip: generated video overlaid with the character's speech.

    The speech is synthesized with the character's voice and uploaded to S3,
    a video is generated from ``image_text`` via ``replicate.txt2vid``, and the
    two are merged with ``combine_speech_video``.

    Args:
        character: Character supplying the ``voice`` used for the speech.
        speech: The line to be spoken over the clip.
        image_text: Text prompt passed to txt2vid as the only interpolation text.
        width: Optional output width forwarded to txt2vid.
        height: Optional output height forwarded to txt2vid.

    Returns:
        ``(output_filename, thumbnail_url)``: the value returned by
        ``combine_speech_video`` and the thumbnail returned by txt2vid.
    """
    audio_bytes = elevenlabs.tts(
        speech,
        voice=character.voice,
    )
    audio_url = s3.upload(audio_bytes, "mp3")
    video_url, thumbnail_url = replicate.txt2vid(
        interpolation_texts=[image_text],
        width=width,
        height=height,
    )
    output_filename = combine_speech_video(audio_url, video_url)
    return output_filename, thumbnail_url
65 changes: 65 additions & 0 deletions app/animations/dialogue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
import requests
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed

from .. import utils
from .animation import talking_head
from ..plugins import replicate, elevenlabs, s3
from ..character import EdenCharacter
from ..scenarios import dialogue
from ..models import DialogueRequest

# Pixel budget passed to utils.calculate_target_dimensions when sizing the video.
MAX_PIXELS = 1024 * 1024
# Passed as max_workers to utils.process_in_parallel for the dialogue segments.
MAX_WORKERS = 3


def animated_dialogue(request: DialogueRequest):
    """Render a multi-character dialogue as one concatenated talking-head video.

    The dialogue is generated by ``dialogue(request)``; each message is turned
    into a talking-head segment (in parallel), downloaded to a temporary file,
    and the segments are concatenated and uploaded to S3. A thumbnail is built
    from the characters' images.

    Args:
        request: Dialogue request; ``request.character_ids`` selects the
            characters taking part.

    Returns:
        ``(output_url, thumbnail_url)`` as returned by ``s3.upload`` for the
        final video and the thumbnail.

    Raises:
        requests.HTTPError: If downloading a generated segment fails.
    """
    result = dialogue(request)
    characters = {
        character_id: EdenCharacter(character_id)
        for character_id in request.character_ids
    }
    images = [
        characters[character_id].image
        for character_id in request.character_ids
    ]
    width, height = utils.calculate_target_dimensions(images, MAX_PIXELS)

    def run_talking_head_segment(message):
        """Generate one segment and download it to a local temp file."""
        character = characters[message["character_id"]]
        output, _ = talking_head(
            character,
            message["message"],
            width,
            height,
        )
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
            try:
                # Context manager releases the streamed connection when done.
                with requests.get(output, stream=True) as response:
                    response.raise_for_status()
                    for chunk in response.iter_content(chunk_size=8192):
                        temp_file.write(chunk)
                temp_file.flush()
            except Exception:
                # delete=False means nobody else will remove a partial download.
                temp_file.close()
                os.remove(temp_file.name)
                raise
        return temp_file.name

    video_files = utils.process_in_parallel(
        result.dialogue,
        run_talking_head_segment,
        max_workers=MAX_WORKERS,
    )

    # Every path in this list is removed in the finally block, so a failure in
    # concatenation or upload no longer leaves temp files behind.
    temp_paths = list(video_files)
    try:
        # concatenate the final video clips
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_output_file:
            temp_paths.append(temp_output_file.name)
            utils.concatenate_videos(video_files, temp_output_file.name)
            with open(temp_output_file.name, 'rb') as f:
                video_bytes = f.read()
            output_url = s3.upload(video_bytes, "mp4")
    finally:
        for path in temp_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                # Already gone; do not mask the primary error during cleanup.
                pass

    # generate thumbnail
    thumbnail = utils.create_dialogue_thumbnail(*images, width, height)
    thumbnail_url = s3.upload(thumbnail, "webp")

    return output_url, thumbnail_url
17 changes: 17 additions & 0 deletions app/animations/monologue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import requests

from .animation import talking_head
from ..plugins import replicate, elevenlabs, s3
from ..character import EdenCharacter
from ..scenarios import monologue
from ..models import MonologueRequest


def animated_monologue(request: MonologueRequest):
    """Render a single-character monologue as a talking-head video.

    The monologue text comes from ``monologue(request)``; it is animated with
    ``talking_head``, and the resulting video is downloaded and re-uploaded
    to S3.

    Args:
        request: Monologue request; ``request.character_id`` selects the speaker.

    Returns:
        ``(output_url, thumbnail_url)``: the S3 upload result for the video and
        the thumbnail returned by ``talking_head``.

    Raises:
        requests.HTTPError: If downloading the generated video fails.
    """
    character = EdenCharacter(request.character_id)
    result = monologue(request)
    output, thumbnail_url = talking_head(character, result.monologue)
    response = requests.get(output)
    # Without this check an HTTP error body would be uploaded as the "mp4".
    response.raise_for_status()
    output_url = s3.upload(response.content, "mp4")
    return output_url, thumbnail_url
70 changes: 70 additions & 0 deletions app/animations/story.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os
import requests
import tempfile

from .. import utils
from ..plugins import replicate, elevenlabs, s3
from ..character import EdenCharacter
from ..scenarios import story
from ..models import StoryRequest, StoryResult
from .animation import screenplay_clip

# Pixel budget passed to utils.calculate_target_dimensions when sizing the video.
MAX_PIXELS = 1024 * 1024
# Passed as max_workers to utils.process_in_parallel for the story clips.
MAX_WORKERS = 3


def animated_story(request: StoryRequest):
    """Render a story screenplay as one concatenated, voiced-over video.

    The screenplay comes from ``story(request)``. Each clip is rendered with
    ``screenplay_clip`` (in parallel), voiced either by the named character or
    by the narrator, then the clips are concatenated and uploaded to S3.

    Args:
        request: Story request; ``request.character_ids`` and
            ``request.narrator_id`` select the voices available to the clips.

    Returns:
        ``(output_url, thumbnail_url)``: the S3 upload result for the final
        video and the thumbnail of the first clip.
    """
    screenplay = story(request)

    characters = {
        character_id: EdenCharacter(character_id)
        for character_id in request.character_ids + [request.narrator_id]
    }

    # Clips refer to characters by name, so map names back to their ids.
    character_name_lookup = {
        character.name: character_id
        for character_id, character in characters.items()
    }

    images = [
        characters[character_id].image
        for character_id in request.character_ids
    ]

    width, height = utils.calculate_target_dimensions(images, MAX_PIXELS)

    def run_story_segment(clip):
        """Render one clip, voiced by its character or by the narrator."""
        if clip['voiceover'] == 'character':
            character_id = character_name_lookup[clip['character']]
            character = characters[character_id]
        else:
            character = characters[request.narrator_id]
        output_filename, thumbnail_url = screenplay_clip(
            character,
            clip['speech'],
            clip['image_description'],
            width,
            height,
        )
        return output_filename, thumbnail_url

    results = utils.process_in_parallel(
        screenplay['clips'],
        run_story_segment,
        max_workers=MAX_WORKERS,
    )

    video_files = [video_file for video_file, _ in results]
    thumbnail_url = results[0][1]

    try:
        with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_output_file:
            utils.concatenate_videos(video_files, temp_output_file.name)
            with open(temp_output_file.name, 'rb') as f:
                video_bytes = f.read()
            output_url = s3.upload(video_bytes, "mp4")
    finally:
        # clean up clips even when concatenation or upload fails
        for video_file in video_files:
            try:
                os.remove(video_file)
            except FileNotFoundError:
                # Already gone; do not mask the primary error during cleanup.
                pass

    return output_url, thumbnail_url
3 changes: 2 additions & 1 deletion app/character.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from pydantic import Field, BaseModel, ValidationError

from .mongo import get_character_data
from .routers.tasks import summary, SummaryRequest
from .scenarios.tasks import summary
from .llm import LLM
from .llm.models import ChatMessage
from .models import SummaryRequest
from .prompt_templates.assistant import (
identity_template,
reply_template,
Expand Down
49 changes: 19 additions & 30 deletions app/routers/generator.py → app/generator.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,28 @@
from typing import Optional, List
from fastapi import APIRouter, Request, BackgroundTasks
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from fastapi import BackgroundTasks
import uuid
import traceback

import requests

from .dags import monologue_dag, dialogue_dag
from .story import cinema
from ..mongo import get_character_data
from ..llm import LLM
from ..prompt_templates import monologue_template, dialogue_template

from ..models import MonologueRequest, MonologueOutput
from ..models import DialogueRequest, DialogueOutput, CinemaRequest
from ..models import TaskRequest, TaskUpdate, TaskOutput

router = APIRouter()
from .animations import animated_monologue, animated_dialogue, animated_story
from .models import MonologueRequest, MonologueResult
from .models import DialogueRequest, DialogueResult, StoryRequest
from .models import TaskRequest, TaskUpdate, TaskResult


def process_task(task_id: str, request: TaskRequest, task_type: str):
def process_task(task_id: str, request: TaskRequest):
print("config", request.config)

task_type = request.generatorName
webhook_url = request.webhookUrl

update = TaskUpdate(
id=task_id,
output=TaskOutput(progress=0),
output=TaskResult(progress=0),
status="processing",
error=None,
)

requests.post(webhook_url, json=update.dict())
if webhook_url:
requests.post(webhook_url, json=update.dict())

try:
if task_type == "monologue":
Expand All @@ -42,7 +32,7 @@ def process_task(task_id: str, request: TaskRequest, task_type: str):
character_id=character_id,
prompt=prompt,
)
output_url, thumbnail_url = monologue_dag(task_req)
output_url, thumbnail_url = animated_monologue(task_req)

elif task_type == "dialogue":
character_ids = request.config.get("characterIds")
Expand All @@ -51,19 +41,18 @@ def process_task(task_id: str, request: TaskRequest, task_type: str):
character_ids=character_ids,
prompt=prompt,
)
output_url, thumbnail_url = dialogue_dag(task_req)
output_url, thumbnail_url = animated_dialogue(task_req)

elif task_type == "story":
character_ids = request.config.get("characterIds")
prompt = request.config.get("prompt")
task_req = CinemaRequest(
task_req = StoryRequest(
character_ids=character_ids,
prompt=prompt,
)
output_url = cinema(task_req)
thumbnail_url = "https://edenartlab-prod-data.s3.us-east-1.amazonaws.com/e745b8c200bb10efe744caa800c7c7f89c3ae05c39fa4aa0595bdd138117c592.png"
output_url, thumbnail_url = animated_story(task_req)

output = TaskOutput(
output = TaskResult(
files=[output_url],
thumbnails=[thumbnail_url],
name=prompt,
Expand All @@ -75,7 +64,7 @@ def process_task(task_id: str, request: TaskRequest, task_type: str):
error = None

except Exception as e:
output = TaskOutput(
output = TaskResult(
files=[],
thumbnails=[],
name=prompt,
Expand All @@ -94,12 +83,12 @@ def process_task(task_id: str, request: TaskRequest, task_type: str):
)
print("update", update.dict())

requests.post(webhook_url, json=update.dict())
if webhook_url:
requests.post(webhook_url, json=update.dict())


@router.post("/tasks/create")
async def generate_task(background_tasks: BackgroundTasks, request: TaskRequest):
task_id = str(uuid.uuid4())
if request.generatorName in ["monologue", "dialogue", "story"]:
background_tasks.add_task(process_task, task_id, request, request.generatorName)
background_tasks.add_task(process_task, task_id, request)
return {"id": task_id}
Loading
Loading