-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1401 from arc53/discord-fix
fix: discord bot
- Loading branch information
Showing
1 changed file
with
101 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,138 @@ | ||
import os | ||
import re | ||
|
||
import logging | ||
import aiohttp | ||
import discord | ||
import requests | ||
from discord.ext import commands | ||
import dotenv | ||
|
||
dotenv.load_dotenv() | ||
|
||
# Replace 'YOUR_BOT_TOKEN' with your bot's token | ||
# Enable logging | ||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
# Bot configuration | ||
TOKEN = os.getenv("DISCORD_TOKEN") | ||
PREFIX = '@DocsGPT' | ||
BASE_API_URL = 'http://localhost:7091' | ||
PREFIX = '!' # Command prefix | ||
BASE_API_URL = os.getenv("API_BASE", "https://gptcloud.arc53.com") | ||
API_URL = BASE_API_URL + "/api/answer" | ||
API_KEY = os.getenv("API_KEY") | ||
|
||
intents = discord.Intents.default() | ||
intents.message_content = True | ||
|
||
bot = commands.Bot(command_prefix=PREFIX, intents=intents) | ||
|
||
# Store conversation history per user | ||
conversation_histories = {} | ||
|
||
def escape_markdown(text): | ||
"""Escapes Discord markdown characters.""" | ||
escape_chars = r'\*_$$$$()~>#+-=|{}.!' | ||
return re.sub(f'([{re.escape(escape_chars)}])', r'\\\1', text) | ||
|
||
def split_string(input_str): | ||
"""Splits the input string to detect bot mentions.""" | ||
pattern = r'^<@!?{0}>\s*'.format(bot.user.id) | ||
match = re.match(pattern, input_str) | ||
if match: | ||
content = input_str[match.end():].strip() | ||
return str(bot.user.id), content | ||
return None, input_str | ||
|
||
|
||
@bot.event | ||
async def on_ready(): | ||
print(f'{bot.user.name} has connected to Discord!') | ||
|
||
|
||
async def fetch_answer(question): | ||
data = { | ||
'sender': 'discord', | ||
'question': question, | ||
'history': '' | ||
async def generate_answer(question, messages, conversation_id): | ||
"""Generates an answer using the external API.""" | ||
payload = { | ||
"question": question, | ||
"api_key": API_KEY, | ||
"history": messages, | ||
"conversation_id": conversation_id | ||
} | ||
headers = {"Content-Type": "application/json", | ||
"Accept": "application/json"} | ||
response = requests.post(BASE_API_URL + '/api/answer', json=data, headers=headers) | ||
if response.status_code == 200: | ||
return response.json()['answer'] | ||
return 'Sorry, I could not fetch the answer.' | ||
|
||
headers = { | ||
"Content-Type": "application/json; charset=utf-8" | ||
} | ||
timeout = aiohttp.ClientTimeout(total=60) | ||
async with aiohttp.ClientSession(timeout=timeout) as session: | ||
async with session.post(API_URL, json=payload, headers=headers) as resp: | ||
if resp.status == 200: | ||
data = await resp.json() | ||
conversation_id = data.get("conversation_id") | ||
answer = data.get("answer", "Sorry, I couldn't find an answer.") | ||
return {"answer": answer, "conversation_id": conversation_id} | ||
else: | ||
return {"answer": "Sorry, I couldn't find an answer.", "conversation_id": None} | ||
|
||
@bot.command(name="start") | ||
async def start(ctx): | ||
"""Handles the /start command.""" | ||
await ctx.send(f"Hi {ctx.author.mention}! How can I assist you today?") | ||
|
||
@bot.command(name="custom_help") | ||
async def custom_help_command(ctx): | ||
"""Handles the /custom_help command.""" | ||
help_text = ( | ||
"Here are the available commands:\n" | ||
"`!start` - Begin a new conversation with the bot\n" | ||
"`!help` - Display this help message\n\n" | ||
"You can also mention me or send a direct message to ask a question!" | ||
) | ||
await ctx.send(help_text) | ||
|
||
@bot.event | ||
async def on_message(message): | ||
if message.author == bot.user: | ||
return | ||
|
||
content = message.content.strip() | ||
prefix, content = split_string(content) | ||
if prefix is None: | ||
return | ||
|
||
part_prefix = str(bot.user.id) | ||
if part_prefix == prefix: | ||
answer = await fetch_answer(content) | ||
await message.channel.send(answer) | ||
|
||
# Process commands first | ||
await bot.process_commands(message) | ||
|
||
|
||
bot.run(TOKEN) | ||
# Check if the message is in a DM channel | ||
if isinstance(message.channel, discord.DMChannel): | ||
content = message.content.strip() | ||
else: | ||
# In guild channels, check if the message mentions the bot at the start | ||
content = message.content.strip() | ||
prefix, content = split_string(content) | ||
if prefix is None: | ||
return | ||
part_prefix = str(bot.user.id) | ||
if part_prefix != prefix: | ||
return # Bot not mentioned at the start, so do not process | ||
|
||
# Now process the message | ||
user_id = message.author.id | ||
if user_id not in conversation_histories: | ||
conversation_histories[user_id] = { | ||
"history": [], | ||
"conversation_id": None | ||
} | ||
|
||
conversation = conversation_histories[user_id] | ||
conversation["history"].append({"prompt": content}) | ||
|
||
# Generate the answer | ||
response_doc = await generate_answer( | ||
content, | ||
conversation["history"], | ||
conversation["conversation_id"] | ||
) | ||
answer = response_doc["answer"] | ||
conversation_id = response_doc["conversation_id"] | ||
|
||
# Escape markdown characters | ||
answer = escape_markdown(answer) | ||
|
||
await message.channel.send(answer) | ||
|
||
conversation["history"][-1]["response"] = answer | ||
conversation["conversation_id"] = conversation_id | ||
|
||
# Keep conversation history to last 10 exchanges | ||
conversation["history"] = conversation["history"][-10:] | ||
|
||
bot.run(TOKEN) |