Skip to content

Commit

Permalink
story validation hack, gpt-4o
Browse files Browse the repository at this point in the history
  • Loading branch information
genekogan committed Jun 18, 2024
1 parent fe092ec commit 8e642e9
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 15 deletions.
7 changes: 7 additions & 0 deletions logos/animations/story.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def animated_story(request: StoryRequest, callback=None):

music_prompt = screenplay.get("music_prompt")

print("===== animated story =====")
for s in screenplay["clips"]:
print(s)
print("-----")
print(music_prompt)
print("-----")

if callback:
callback(progress=0.1)

Expand Down
25 changes: 15 additions & 10 deletions logos/character.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
story_creation_enabled=False,
concept=None,
smart_reply=False,
chat_model="gpt-3.5-turbo",
chat_model="gpt-4o", #"gpt-3.5-turbo",
image=None,
voice=None,
):
Expand Down Expand Up @@ -100,7 +100,7 @@ def update(
story_creation_enabled=False,
concept=None,
smart_reply=False,
chat_model="gpt-3.5-turbo",
chat_model="gpt-4o", #"gpt-3.5-turbo",
image=None,
voice=None,
):
Expand Down Expand Up @@ -223,7 +223,7 @@ def think(
prompt=user_message,
output_schema=Thought,
save_messages=False,
model="gpt-4-1106-preview",
model="gpt-4o", #"gpt-4-1106-preview",
)

probability = result["probability"]
Expand All @@ -247,7 +247,7 @@ def _route_(
index = self.router(
prompt=router_prompt,
save_messages=False,
model="gpt-4-1106-preview",
model="gpt-4o", #"gpt-4-1106-preview",
)
match = re.match(r"-?\d+", index)
if match:
Expand All @@ -266,7 +266,8 @@ def _chat_(
image=message.attachments[0] if message.attachments else None,
id=session_id,
save_messages=False,
model=self.chat_model,
#model=self.chat_model,
model="gpt-4o",
)
user_message = ChatMessage(role="user", content=message.message)
assistant_message = ChatMessage(role="assistant", content=response)
Expand All @@ -278,7 +279,8 @@ def _qa_(self, message, session_id=None) -> dict:
prompt=message.message,
id=session_id,
save_messages=False,
model=self.chat_model,
#model=self.chat_model,
model="gpt-4o",
)
user_message = ChatMessage(role="user", content=message.message)
assistant_message = ChatMessage(role="assistant", content=response)
Expand All @@ -295,7 +297,8 @@ def _create_(
id=session_id,
input_schema=CreatorInput,
output_schema=CreatorOutput,
model="gpt-4-1106-preview",
#model="gpt-4-1106-preview",
model="gpt-4o",
)

config = {k: v for k, v in response["config"].items() if v}
Expand Down Expand Up @@ -361,7 +364,8 @@ class ContextOutput(BaseModel):
prompt=story_context_prompt,
id=session_id,
output_schema=ContextOutput,
model="gpt-3.5-turbo",
#model="gpt-3.5-turbo",
model="gpt-4o",
)

new_names = [
Expand Down Expand Up @@ -413,7 +417,7 @@ class StoryEditorOutput(BaseModel):
id=session_id,
# input_schema=CreatorInput,
output_schema=StoryEditorOutput,
model="gpt-4-1106-preview",
model="gpt-4o", #"gpt-4-1106-preview",
# model="gpt-3.5-turbo",
)

Expand Down Expand Up @@ -542,7 +546,8 @@ def sync(self):
abilities.get("story_creations", False) if abilities else False
)
smart_reply = abilities.get("smart_reply", False) if abilities else False
chat_model = logos_data.get("chatModel", "gpt-4-1106-preview")
#chat_model = logos_data.get("chatModel", "gpt-4-1106-preview")
chat_model = logos_data.get("chatModel", "gpt-4o")
image = character_data.get("image")
voice = character_data.get("voice")

Expand Down
2 changes: 1 addition & 1 deletion logos/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __call__(
tools: List[Any] = None,
input_schema: Any = None,
output_schema: Any = None,
model: str = "gpt-4-1106-preview",
model: str = "gpt-4o" #"gpt-4-1106-preview",
) -> str:
sess = self.get_session(id)
if tools:
Expand Down
1 change: 1 addition & 0 deletions logos/llm/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@


ALLOWED_MODELS = [
"gpt-4o",
"gpt-3.5-turbo",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
Expand Down
1 change: 0 additions & 1 deletion logos/models/scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

NARRATOR_CHARACTER_ID = os.getenv("NARRATOR_CHARACTER_ID")


class MonologueRequest(BaseModel):
character_id: str
prompt: str
Expand Down
2 changes: 1 addition & 1 deletion logos/scenarios/reel.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def reel(request: ReelRequest):
).strip()

reelwriter = LLM(
model=request.model,
model="gpt-4o", #request.model,
system_message=reelwriter_system_template.template,
params=params,
)
Expand Down
26 changes: 24 additions & 2 deletions logos/scenarios/story.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,34 @@ def story(request: StoryRequest):
).strip()

screenwriter = LLM(
model=request.model,
model="gpt-4o", #request.model,
system_message=screenwriter_system_template.template,
params=params,
)

story = screenwriter(prompt, output_schema=StoryResult)
# hack to do story type validation, fixed in eden2
finished = False
tries = 0
max_tries = 5
while not finished:
try:
story = screenwriter(prompt, output_schema=StoryResult)
clip_keys = [c.keys() for c in story["clips"]]
required_keys = {'voiceover', 'character', 'speech', 'image_prompt'}
all_clips_valid = all(required_keys.issubset(set(clip)) for clip in clip_keys)
if not all_clips_valid:
print("Missing keys in clips")
print(clip_keys)
raise ValueError("One or more clips are missing required keys.")
finished = True
break
except Exception as e:
print("Error:", e)
tries += 1
if tries >= max_tries:
raise Exception("Max tries exceeded...")

#story = screenwriter(prompt, output_schema=StoryResult)

if request.music:
if request.music_prompt and request.music_prompt.strip():
Expand Down
1 change: 1 addition & 0 deletions tests/test_stories.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ def test_story():
print(response.json())

assert response.status_code == 200

0 comments on commit 8e642e9

Please sign in to comment.