diff --git a/src/bots/eden-stg/.env.example b/src/bots/eden-stg/.env.example index a38084d..004dd0c 100644 --- a/src/bots/eden-stg/.env.example +++ b/src/bots/eden-stg/.env.example @@ -11,4 +11,4 @@ ALLOWED_GUILDS= ALLOWED_GUILDS_TEST= ALLOWED_CHANNELS= -OPENAI_API_KEY= +LOGOS_URL= diff --git a/src/bots/eden-stg/EdenCharacterCog.py b/src/bots/eden-stg/EdenLogosCharacterCog.py similarity index 51% rename from src/bots/eden-stg/EdenCharacterCog.py rename to src/bots/eden-stg/EdenLogosCharacterCog.py index 23c4b43..ecd2f7c 100644 --- a/src/bots/eden-stg/EdenCharacterCog.py +++ b/src/bots/eden-stg/EdenLogosCharacterCog.py @@ -1,11 +1,11 @@ from discord.ext import commands -from cogs.CharacterCog import CharacterCog +from cogs.LogosCharacterCog import LogosCharacterCog -class EdenCharacterCog(CharacterCog): +class EdenLogosCharacterCog(LogosCharacterCog): def __init__(self, bot: commands.bot) -> None: super().__init__(bot) def setup(bot: commands.Bot) -> None: - bot.add_cog(EdenCharacterCog(bot)) + bot.add_cog(EdenLogosCharacterCog(bot)) diff --git a/src/cogs/LogosCharacterCog.py b/src/cogs/LogosCharacterCog.py new file mode 100644 index 0000000..119e452 --- /dev/null +++ b/src/cogs/LogosCharacterCog.py @@ -0,0 +1,145 @@ +import os +import random +import discord +from attr import dataclass +from discord.ext import commands + +from common.discord import ( + get_source, + is_mentioned, + replace_bot_mention, + replace_mentions_with_usernames, +) +from common.eden import generation_loop, get_assistant +from common.logos import request_logos_assistant +from common.models import ( + GenerationLoopInput, + SignInCredentials, + StableDiffusionConfig, +) + +ALLOWED_CHANNELS = [int(c) for c in os.getenv("ALLOWED_CHANNELS", "").split(",")] + +EDEN_API_URL = os.getenv("EDEN_API_URL") +LOGOS_URL = os.getenv("LOGOS_API_URL") +EDEN_FRONTEND_URL = EDEN_API_URL.replace("api", "app") +EDEN_API_KEY = os.getenv("EDEN_API_KEY") +EDEN_API_SECRET = os.getenv("EDEN_API_SECRET") +EDEN_CHARACTER_ID = os.getenv("EDEN_CHARACTER_ID") + + +@dataclass +class LoraInput: + lora_id: str + lora_strength: float + lora_trigger: str + require_lora_trigger: bool + + +class LogosCharacterCog(commands.Cog): + def __init__( + self, + bot: commands.bot, + ) -> None: + print("LogosCharacterCog init...") + self.bot = bot + self.eden_credentials = SignInCredentials( + apiKey=EDEN_API_KEY, apiSecret=EDEN_API_SECRET + ) + self.characterId = EDEN_CHARACTER_ID + + @commands.Cog.listener("on_message") + async def on_message(self, message: discord.Message) -> None: + if ( + message.channel.id not in ALLOWED_CHANNELS + or message.author.id == self.bot.user.id + or message.author.bot + ): + return + + trigger_reply = is_mentioned(message, self.bot.user) + + if trigger_reply: + ctx = await self.bot.get_context(message) + async with ctx.channel.typing(): + prompt = self.message_preprocessor(message) + attachment_urls = [attachment.url for attachment in message.attachments] + + assistant, concept = get_assistant( + api_url=EDEN_API_URL, + character_id=self.characterId, + credentials=self.eden_credentials, + ) + + interaction = { + "prompt": prompt, + "attachments": attachment_urls, + "author_id": str(message.author.id), + } + + response = request_logos_assistant(LOGOS_URL, assistant, interaction) + reply = response.get("message")[:2000] + reply_message = await message.reply(reply) + + # check if there is a config + config = response.get("config") + + if not config: + return + + mode = config.pop("generator") + + if config.get("text_input"): + text_input = config["text_input"] + elif config.get("interpolation_texts"): + text_input = " to ".join(config["interpolation_texts"]) + else: + text_input = mode + + if not config.get("seed"): + config["seed"] = random.randint(1, 1e8) + + config = StableDiffusionConfig(generator_name=mode, **config) + + config = self.add_lora(config, concept) + + source = get_source(ctx) + + is_video_request = mode in ["interpolate", "real2real"] + + start_bot_message = f"**{text_input}** - <@!{ctx.author.id}>\n" + original_text = ( + f"{reply[0:1950-len(start_bot_message)]}\n\n{start_bot_message}" + ) + + generation_loop_input = GenerationLoopInput( + api_url=EDEN_API_URL, + frontend_url=EDEN_FRONTEND_URL, + message=reply_message, + start_bot_message=original_text, + source=source, + config=config, + prefer_gif=False, + is_video_request=is_video_request, + ) + await generation_loop( + generation_loop_input, eden_credentials=self.eden_credentials + ) + + def message_preprocessor(self, message: discord.Message) -> str: + message_content = replace_bot_mention(message.content, only_first=True) + message_content = replace_mentions_with_usernames( + message_content, + message.mentions, + ) + message_content = message_content.strip() + return message_content + + def add_lora(self, config: StableDiffusionConfig, concept: str): + if concept: + config.lora = concept + config.lora_strength = 0.6 + return config + + def check_lora_trigger_provided(message: str, lora_trigger: str): + return lora_trigger in message diff --git a/src/common/eden.py b/src/common/eden.py index 526bd3a..839c74f 100644 --- a/src/common/eden.py +++ b/src/common/eden.py @@ -2,7 +2,6 @@ import io import os from typing import Optional -from logos.scenarios import EdenAssistant import aiohttp @@ -308,21 +307,18 @@ def get_assistant(api_url: str, character_id: str, credentials: SignInCredential "x-api-secret": credentials.apiSecret, } - print(f"{api_url}/characters/{character_id}") response = requests.get(f"{api_url}/characters/{character_id}", headers=header) json = response.json() - print(json) character = json.get("character") description = character.get("description") logosData = character.get("logosData") - - assistant = EdenAssistant( - character_description=description, - creator_prompt=logosData.get("creatorPrompt"), - documentation_prompt=logosData.get("documentationPrompt"), - documentation=logosData.get("documentation"), - router_prompt=logosData.get("routerPrompt"), - ) + assistant = { + "character_description": description, + "creator_prompt": logosData.get("creatorPrompt"), + "documentation_prompt": logosData.get("documentationPrompt"), + "documentation": logosData.get("documentation"), + "router_prompt": logosData.get("routerPrompt"), + } concept = character.get("concept") return assistant, concept diff --git a/src/common/logos.py b/src/common/logos.py new file mode 100644 index 0000000..cd9ff2a --- /dev/null +++ b/src/common/logos.py @@ -0,0 +1,14 @@ +import requests + + +def request_logos_assistant(api_url: str, assistant: dict, interaction: dict): + response = requests.post( + f"{api_url}/interact", + json={ + "assistant": assistant, + "interaction": interaction, + }, + ) + print(response.json()) + + return response.json()