Skip to content

Commit

Permalink
new thread
Browse files Browse the repository at this point in the history
  • Loading branch information
genekogan committed Jul 4, 2024
1 parent f690370 commit e8e0de6
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions src/cogs/Eden2Cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
replace_mentions_with_usernames,
)

long_running_tools = ["txt2vid", "style_mixing", "img2vid", "vid2vid", "video_upscale"]

# ALLOWED_CHANNELS = [int(c) for c in os.getenv("ALLOWED_CHANNELS", "").split(",")]
EDEN_CHARACTER_ID = os.getenv("EDEN_CHARACTER_ID")

Expand All @@ -29,45 +31,41 @@ async def button_click(self, button: ui.Button, interaction: discord.Interaction





class Eden2Cog(commands.Cog):
def __init__(
self,
bot: commands.bot,
) -> None:
self.bot = bot
self.characterId = EDEN_CHARACTER_ID
self.thread_id = client.get_or_create_thread("discord-test3")
self.thread_id = client.get_or_create_thread("discord-test11")
print("thread id", self.thread_id)

@commands.Cog.listener("on_message")
async def on_message(self, message: discord.Message) -> None:
print("on... message ...", message.content)
# print("on... message ...", message.content)
if (
message.author.id == self.bot.user.id
or message.author.bot
):
return

print("check if mention")
trigger_reply = is_mentioned(message, self.bot.user)
print("trig reply", trigger_reply)
if not trigger_reply:
return

if message.channel.id != 1186378591118839808 and message.channel.id != 1006143747588898849:
return

print("got here..", message.content)
content = replace_bot_mention(message.content, only_first=True)
content = replace_mentions_with_usernames(content, message.mentions)
print("content", content)

# Check if the message is a reply
if message.reference:
source_message = await message.channel.fetch_message(message.reference.message_id)
# content = f"((Reply to {source_message.author.name}: {source_message.content[:120]} ...))\n\n{content}"
content = f"((Reply to {source_message.author.name}: {source_message.content[:50]} ...))\n\n{content}"
# content = f"(Reply to {source_message.author.name}: {source_message.content[:50]} ...))\n\n{content}"
content = f"(Replying to message: {source_message.content[:100]} ...)\n\n{content}"
# TODO: extract urls don't shorten them

chat_message = {
Expand All @@ -79,23 +77,29 @@ async def on_message(self, message: discord.Message) -> None:

ctx = await self.bot.get_context(message)
async with ctx.channel.typing():
print(chat_message)
import random
ran = random.randint(1, 10000)
print(ran, content)
# print(chat_message)
async for response in client.async_chat(chat_message, self.thread_id):
print(response)
error = response.get("error")
if error:
await reply(message, error)
print(ran, response)
error_message = response.get("error")
if error_message:
await reply(message, error_message)
continue
response = json.loads(response.get("message"))
content = response.get("content")
tool_calls = response.get("tool_calls")
if tool_calls:
tool_name = tool_calls[0].get("function").get("name")
if tool_name in ["txt2vid", "style_mixing", "img2vid", "vid2vid", "video_upscale"]:
if tool_name in long_running_tools:
args = json.loads(tool_calls[0].get("function").get("arguments"))
prompt = args.get("prompt")
content = f"Running {tool_name}: {prompt}. Please wait..."
await reply(message, content)
if prompt:
await reply(message, f"Running {tool_name}: {prompt}. Please wait...")
else:
await reply(message, f"Running {tool_name}. Please wait...")

if content:
await reply(message, content)

Expand Down

0 comments on commit e8e0de6

Please sign in to comment.