Skip to content

Commit

Permalink
Merge pull request #7 from edenartlab/feat/logos-svc
Browse files Browse the repository at this point in the history
Feat/logos svc
  • Loading branch information
jmilldotdev authored Dec 16, 2023
2 parents d1ee3d7 + 1836c2c commit b0b64bb
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/bots/eden-stg/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ ALLOWED_GUILDS=
ALLOWED_GUILDS_TEST=
ALLOWED_CHANNELS=

OPENAI_API_KEY=
LOGOS_URL=
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from discord.ext import commands
from cogs.CharacterCog import CharacterCog
from cogs.LogosCharacterCog import LogosCharacterCog


class EdenCharacterCog(CharacterCog):
class EdenLogosCharacterCog(LogosCharacterCog):
def __init__(self, bot: commands.bot) -> None:
super().__init__(bot)


def setup(bot: commands.Bot) -> None:
bot.add_cog(EdenCharacterCog(bot))
bot.add_cog(EdenLogosCharacterCog(bot))
145 changes: 145 additions & 0 deletions src/cogs/LogosCharacterCog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import os
import random
import discord
from attr import dataclass
from discord.ext import commands

from common.discord import (
get_source,
is_mentioned,
replace_bot_mention,
replace_mentions_with_usernames,
)
from common.eden import generation_loop, get_assistant
from common.logos import request_logos_assistant
from common.models import (
GenerationLoopInput,
SignInCredentials,
StableDiffusionConfig,
)

ALLOWED_CHANNELS = [int(c) for c in os.getenv("ALLOWED_CHANNELS", "").split(",")]

EDEN_API_URL = os.getenv("EDEN_API_URL")
LOGOS_URL = os.getenv("LOGOS_API_URL")
EDEN_FRONTEND_URL = EDEN_API_URL.replace("api", "app")
EDEN_API_KEY = os.getenv("EDEN_API_KEY")
EDEN_API_SECRET = os.getenv("EDEN_API_SECRET")
EDEN_CHARACTER_ID = os.getenv("EDEN_CHARACTER_ID")


@dataclass
class LoraInput:
lora_id: str
lora_strength: float
lora_trigger: str
require_lora_trigger: bool


class LogosCharacterCog(commands.Cog):
def __init__(
self,
bot: commands.bot,
) -> None:
print("LogosCharacterCog init...")
self.bot = bot
self.eden_credentials = SignInCredentials(
apiKey=EDEN_API_KEY, apiSecret=EDEN_API_SECRET
)
self.characterId = EDEN_CHARACTER_ID

@commands.Cog.listener("on_message")
async def on_message(self, message: discord.Message) -> None:
if (
message.channel.id not in ALLOWED_CHANNELS
or message.author.id == self.bot.user.id
or message.author.bot
):
return

trigger_reply = is_mentioned(message, self.bot.user)

if trigger_reply:
ctx = await self.bot.get_context(message)
async with ctx.channel.typing():
prompt = self.message_preprocessor(message)
attachment_urls = [attachment.url for attachment in message.attachments]

assistant, concept = get_assistant(
api_url=EDEN_API_URL,
character_id=self.characterId,
credentials=self.eden_credentials,
)

interaction = {
"prompt": prompt,
"attachments": attachment_urls,
"author_id": str(message.author.id),
}

response = request_logos_assistant(LOGOS_URL, assistant, interaction)
reply = response.get("message")[:2000]
reply_message = await message.reply(reply)

# check if there is a config
config = response.get("config")

if not config:
return

mode = config.pop("generator")

if config.get("text_input"):
text_input = config["text_input"]
elif config.get("interpolation_texts"):
text_input = " to ".join(config["interpolation_texts"])
else:
text_input = mode

if not config.get("seed"):
config["seed"] = random.randint(1, 1e8)

config = StableDiffusionConfig(generator_name=mode, **config)

config = self.add_lora(config, concept)

source = get_source(ctx)

is_video_request = mode in ["interpolate", "real2real"]

start_bot_message = f"**{text_input}** - <@!{ctx.author.id}>\n"
original_text = (
f"{reply[0:1950-len(start_bot_message)]}\n\n{start_bot_message}"
)

generation_loop_input = GenerationLoopInput(
api_url=EDEN_API_URL,
frontend_url=EDEN_FRONTEND_URL,
message=reply_message,
start_bot_message=original_text,
source=source,
config=config,
prefer_gif=False,
is_video_request=is_video_request,
)
await generation_loop(
generation_loop_input, eden_credentials=self.eden_credentials
)

def message_preprocessor(self, message: discord.Message) -> str:
message_content = replace_bot_mention(message.content, only_first=True)
message_content = replace_mentions_with_usernames(
message_content,
message.mentions,
)
message_content = message_content.strip()
return message_content

def add_lora(self, config: StableDiffusionConfig, concept: str):
if concept:
config.lora = concept
config.lora_strength = 0.6
return config

def check_lora_trigger_provided(message: str, lora_trigger: str):
return lora_trigger in message
18 changes: 7 additions & 11 deletions src/common/eden.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import io
import os
from typing import Optional
from logos.scenarios import EdenAssistant


import aiohttp
Expand Down Expand Up @@ -308,21 +307,18 @@ def get_assistant(api_url: str, character_id: str, credentials: SignInCredential
"x-api-secret": credentials.apiSecret,
}

print(f"{api_url}/characters/{character_id}")
response = requests.get(f"{api_url}/characters/{character_id}", headers=header)
json = response.json()
print(json)
character = json.get("character")
description = character.get("description")
logosData = character.get("logosData")

assistant = EdenAssistant(
character_description=description,
creator_prompt=logosData.get("creatorPrompt"),
documentation_prompt=logosData.get("documentationPrompt"),
documentation=logosData.get("documentation"),
router_prompt=logosData.get("routerPrompt"),
)
assistant = {
"character_description": description,
"creator_prompt": logosData.get("creatorPrompt"),
"documentation_prompt": logosData.get("documentationPrompt"),
"documentation": logosData.get("documentation"),
"router_prompt": logosData.get("routerPrompt"),
}
concept = character.get("concept")

return assistant, concept
14 changes: 14 additions & 0 deletions src/common/logos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import requests


def request_logos_assistant(api_url: str, assistant: dict, interaction: dict):
response = requests.post(
f"{api_url}/interact",
json={
"assistant": assistant,
"interaction": interaction,
},
)
print(response.json())

return response.json()

0 comments on commit b0b64bb

Please sign in to comment.