From 8e642e9baa4be4961385d06e9fe2b5a8c868ff85 Mon Sep 17 00:00:00 2001 From: genekogan Date: Mon, 17 Jun 2024 22:29:15 -0400 Subject: [PATCH] story validation hack, gpt-4o --- logos/animations/story.py | 7 +++++++ logos/character.py | 25 +++++++++++++++---------- logos/llm/llm.py | 2 +- logos/llm/session.py | 1 + logos/models/scenarios.py | 1 - logos/scenarios/reel.py | 2 +- logos/scenarios/story.py | 26 ++++++++++++++++++++++++-- tests/test_stories.py | 1 + 8 files changed, 50 insertions(+), 15 deletions(-) diff --git a/logos/animations/story.py b/logos/animations/story.py index d82dcf3..eb4f52c 100644 --- a/logos/animations/story.py +++ b/logos/animations/story.py @@ -19,6 +19,13 @@ def animated_story(request: StoryRequest, callback=None): music_prompt = screenplay.get("music_prompt") + print("===== animated story =====") + for s in screenplay["clips"]: + print(s) + print("-----") + print(music_prompt) + print("-----") + if callback: callback(progress=0.1) diff --git a/logos/character.py b/logos/character.py index 8f9453f..bb01196 100644 --- a/logos/character.py +++ b/logos/character.py @@ -54,7 +54,7 @@ def __init__( story_creation_enabled=False, concept=None, smart_reply=False, - chat_model="gpt-3.5-turbo", + chat_model="gpt-4o", #"gpt-3.5-turbo", image=None, voice=None, ): @@ -100,7 +100,7 @@ def update( story_creation_enabled=False, concept=None, smart_reply=False, - chat_model="gpt-3.5-turbo", + chat_model="gpt-4o", #"gpt-3.5-turbo", image=None, voice=None, ): @@ -223,7 +223,7 @@ def think( prompt=user_message, output_schema=Thought, save_messages=False, - model="gpt-4-1106-preview", + model="gpt-4o", #"gpt-4-1106-preview", ) probability = result["probability"] @@ -247,7 +247,7 @@ def _route_( index = self.router( prompt=router_prompt, save_messages=False, - model="gpt-4-1106-preview", + model="gpt-4o", #"gpt-4-1106-preview", ) match = re.match(r"-?\d+", index) if match: @@ -266,7 +266,8 @@ def _chat_( image=message.attachments[0] if message.attachments else None, id=session_id, save_messages=False, - model=self.chat_model, + #model=self.chat_model, + model="gpt-4o", ) user_message = ChatMessage(role="user", content=message.message) assistant_message = ChatMessage(role="assistant", content=response) @@ -278,7 +279,8 @@ def _qa_(self, message, session_id=None) -> dict: prompt=message.message, id=session_id, save_messages=False, - model=self.chat_model, + #model=self.chat_model, + model="gpt-4o", ) user_message = ChatMessage(role="user", content=message.message) assistant_message = ChatMessage(role="assistant", content=response) @@ -295,7 +297,8 @@ def _create_( id=session_id, input_schema=CreatorInput, output_schema=CreatorOutput, - model="gpt-4-1106-preview", + #model="gpt-4-1106-preview", + model="gpt-4o", ) config = {k: v for k, v in response["config"].items() if v} @@ -361,7 +364,8 @@ class ContextOutput(BaseModel): prompt=story_context_prompt, id=session_id, output_schema=ContextOutput, - model="gpt-3.5-turbo", + #model="gpt-3.5-turbo", + model="gpt-4o", ) new_names = [ @@ -413,7 +417,7 @@ class StoryEditorOutput(BaseModel): id=session_id, # input_schema=CreatorInput, output_schema=StoryEditorOutput, - model="gpt-4-1106-preview", + model="gpt-4o", #"gpt-4-1106-preview", # model="gpt-3.5-turbo", ) @@ -542,7 +546,8 @@ def sync(self): abilities.get("story_creations", False) if abilities else False ) smart_reply = abilities.get("smart_reply", False) if abilities else False - chat_model = logos_data.get("chatModel", "gpt-4-1106-preview") + #chat_model = logos_data.get("chatModel", "gpt-4-1106-preview") + chat_model = logos_data.get("chatModel", "gpt-4o") image = character_data.get("image") voice = character_data.get("voice") diff --git a/logos/llm/llm.py b/logos/llm/llm.py index f4a231e..fa09f44 100644 --- a/logos/llm/llm.py +++ b/logos/llm/llm.py @@ -155,7 +155,7 @@ def __call__( tools: List[Any] = None, input_schema: Any = None, output_schema: Any = None, - model: str = "gpt-4-1106-preview", + model: str = "gpt-4o" #"gpt-4-1106-preview", ) -> str: sess = self.get_session(id) if tools: diff --git a/logos/llm/session.py b/logos/llm/session.py index e3e93e3..eac568d 100644 --- a/logos/llm/session.py +++ b/logos/llm/session.py @@ -13,6 +13,7 @@ ALLOWED_MODELS = [ + "gpt-4o", "gpt-3.5-turbo", "gpt-4-1106-preview", "gpt-4-vision-preview", diff --git a/logos/models/scenarios.py b/logos/models/scenarios.py index 17bac15..b98e2fb 100644 --- a/logos/models/scenarios.py +++ b/logos/models/scenarios.py @@ -7,7 +7,6 @@ NARRATOR_CHARACTER_ID = os.getenv("NARRATOR_CHARACTER_ID") - class MonologueRequest(BaseModel): character_id: str prompt: str diff --git a/logos/scenarios/reel.py b/logos/scenarios/reel.py index 6194927..168e26d 100644 --- a/logos/scenarios/reel.py +++ b/logos/scenarios/reel.py @@ -43,7 +43,7 @@ def reel(request: ReelRequest): ).strip() reelwriter = LLM( - model=request.model, + model="gpt-4o", #request.model, system_message=reelwriter_system_template.template, params=params, ) diff --git a/logos/scenarios/story.py b/logos/scenarios/story.py index b1995d3..e24e2a5 100644 --- a/logos/scenarios/story.py +++ b/logos/scenarios/story.py @@ -34,12 +34,34 @@ def story(request: StoryRequest): ).strip() screenwriter = LLM( - model=request.model, + model="gpt-4o", #request.model, system_message=screenwriter_system_template.template, params=params, ) - story = screenwriter(prompt, output_schema=StoryResult) + # hack to do story type validation, fixed in eden2 + finished = False + tries = 0 + max_tries = 5 + while not finished: + try: + story = screenwriter(prompt, output_schema=StoryResult) + clip_keys = [c.keys() for c in story["clips"]] + required_keys = {'voiceover', 'character', 'speech', 'image_prompt'} + all_clips_valid = all(required_keys.issubset(set(clip)) for clip in clip_keys) + if not all_clips_valid: + print("Missing keys in clips") + print(clip_keys) + raise ValueError("One or more clips are missing required keys.") + finished = True + break + except Exception as e: + print("Error:", e) + tries += 1 + if tries >= max_tries: + raise Exception("Max tries exceeded...") + + #story = screenwriter(prompt, output_schema=StoryResult) if request.music: if request.music_prompt and request.music_prompt.strip(): diff --git a/tests/test_stories.py b/tests/test_stories.py index dce795b..56fb176 100644 --- a/tests/test_stories.py +++ b/tests/test_stories.py @@ -42,3 +42,4 @@ def test_story(): print(response.json()) assert response.status_code == 200 +