From e8e0de627f0b39f209eb178422b4af17dd82e1ba Mon Sep 17 00:00:00 2001 From: genekogan Date: Thu, 4 Jul 2024 11:44:33 +0100 Subject: [PATCH] new thread --- src/cogs/Eden2Cog.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/cogs/Eden2Cog.py b/src/cogs/Eden2Cog.py index 7b39332..1b77224 100644 --- a/src/cogs/Eden2Cog.py +++ b/src/cogs/Eden2Cog.py @@ -14,6 +14,8 @@ replace_mentions_with_usernames, ) +long_running_tools = ["txt2vid", "style_mixing", "img2vid", "vid2vid", "video_upscale"] + # ALLOWED_CHANNELS = [int(c) for c in os.getenv("ALLOWED_CHANNELS", "").split(",")] EDEN_CHARACTER_ID = os.getenv("EDEN_CHARACTER_ID") @@ -29,8 +31,6 @@ async def button_click(self, button: ui.Button, interaction: discord.Interaction - - class Eden2Cog(commands.Cog): def __init__( self, @@ -38,36 +38,34 @@ def __init__( ) -> None: self.bot = bot self.characterId = EDEN_CHARACTER_ID - self.thread_id = client.get_or_create_thread("discord-test3") + self.thread_id = client.get_or_create_thread("discord-test11") print("thread id", self.thread_id) @commands.Cog.listener("on_message") async def on_message(self, message: discord.Message) -> None: - print("on... message ...", message.content) + # print("on... message ...", message.content) if ( message.author.id == self.bot.user.id or message.author.bot ): return - print("check if mention") trigger_reply = is_mentioned(message, self.bot.user) - print("trig reply", trigger_reply) if not trigger_reply: return if message.channel.id != 1186378591118839808 and message.channel.id != 1006143747588898849: return - print("got here..", message.content) content = replace_bot_mention(message.content, only_first=True) content = replace_mentions_with_usernames(content, message.mentions) - print("content", content) + # Check if the message is a reply if message.reference: source_message = await message.channel.fetch_message(message.reference.message_id) # content = f"((Reply to {source_message.author.name}: {source_message.content[:120]} ...))\n\n{content}" - content = f"((Reply to {source_message.author.name}: {source_message.content[:50]} ...))\n\n{content}" + # content = f"(Reply to {source_message.author.name}: {source_message.content[:50]} ...))\n\n{content}" + content = f"(Replying to message: {source_message.content[:100]} ...)\n\n{content}" # TODO: extract urls don't shorten them chat_message = { @@ -79,23 +77,29 @@ async def on_message(self, message: discord.Message) -> None: ctx = await self.bot.get_context(message) async with ctx.channel.typing(): - print(chat_message) + import random + ran = random.randint(1, 10000) + print(ran, content) + # print(chat_message) async for response in client.async_chat(chat_message, self.thread_id): - print(response) - error = response.get("error") - if error: - await reply(message, error) + print(ran, response) + error_message = response.get("error") + if error_message: + await reply(message, error_message) continue response = json.loads(response.get("message")) content = response.get("content") tool_calls = response.get("tool_calls") if tool_calls: tool_name = tool_calls[0].get("function").get("name") - if tool_name in ["txt2vid", "style_mixing", "img2vid", "vid2vid", "video_upscale"]: + if tool_name in long_running_tools: args = json.loads(tool_calls[0].get("function").get("arguments")) prompt = args.get("prompt") - content = f"Running {tool_name}: {prompt}. Please wait..." - await reply(message, content) + if prompt: + await reply(message, f"Running {tool_name}: {prompt}. Please wait...") + else: + await reply(message, f"Running {tool_name}. Please wait...") + if content: await reply(message, content)