From 69ce5b43a380ccc57a3cf3995a33420bddeeb6ec Mon Sep 17 00:00:00 2001
From: hibobmaster <32976627+hibobmaster@users.noreply.github.com>
Date: Tue, 23 Apr 2024 20:18:21 +0800
Subject: [PATCH] Support thread level context (#29)
---
.gitignore | 1 +
CHANGELOG.md | 3 ++
README.md | 17 ++++++----
compose.yaml | 3 +-
src/bot.py | 78 +++++++++++++++++++++++++++++++++++++++++++++
src/gptbot.py | 73 +++++++++++++++++++++++++++++-------------
src/send_message.py | 37 +++++++++++++++------
7 files changed, 173 insertions(+), 39 deletions(-)
diff --git a/.gitignore b/.gitignore
index f0bc170..73a8f5f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -173,3 +173,4 @@ cython_debug/
sync_db
manage_db
element-keys.txt
+context.db
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ef8dcd3..2d90547 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,8 @@
# Changelog
+## 1.7.0
+- Support thread level context
+
## 1.6.0
- Add GPT Vision
diff --git a/README.md b/README.md
index 7b8b1c7..9833638 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,7 @@ This is a simple Matrix bot that supports using OpenAI API, Langchain to generate
4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise))
5. Image Generation with [DALLĀ·E](https://platform.openai.com/docs/api-reference/images/create) or [LocalAI](https://localai.io/features/image-generation/) or [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)
6. GPT Vision(openai or [GPT Vision API](https://platform.openai.com/docs/guides/vision) compatible such as [LocalAI](https://localai.io/features/gpt-vision/))
+7. Room-level and thread-level chat context
## Installation and Setup
@@ -21,10 +22,10 @@ For explanations and complete parameter list see: https://github.com/hibobmaste
 Create three empty files, for database persistence only
```bash
-touch sync_db manage_db
+touch sync_db context.db manage_db
sudo docker compose up -d
```
-manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database
+manage_db (can be ignored) is for the langchain agent, sync_db is for the matrix sync database, and context.db is for the bot's chat context
Normal Method:
 system dependency: libolm-dev
@@ -115,12 +116,16 @@ LangChain(flowise) admin: https://github.com/hibobmaster/matrix_chatgpt_bot/wiki
![demo2](https://i.imgur.com/BKZktWd.jpg)
https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/
+## Thread-level Context
+Mention the bot with a prompt and it will reply in a thread.
+
+To keep the context, send follow-up prompts directly in the thread without mentioning the bot.
+
+![thread level context 1](https://i.imgur.com/4vLvNCt.jpeg)
+![thread level context 2](https://i.imgur.com/1eb1Lmd.jpeg)
+
## Thanks
1. [matrix-nio](https://github.com/poljar/matrix-nio)
2. [acheong08](https://github.com/acheong08)
3. [8go](https://github.com/8go/)
-
-
-
-
diff --git a/compose.yaml b/compose.yaml
index 76b61e2..6e31742 100644
--- a/compose.yaml
+++ b/compose.yaml
@@ -12,8 +12,9 @@ services:
# use env file or config.json
# - ./config.json:/app/config.json
# use touch to create empty db file, for persist database only
- # manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database
+      # manage_db (can be ignored) is for the langchain agent, sync_db is for the matrix sync database, context.db is for the bot's chat context
- ./sync_db:/app/sync_db
+ - ./context.db:/app/context.db
# - ./manage_db:/app/manage_db
# import_keys path
# - ./element-keys.txt:/app/element-keys.txt
diff --git a/src/bot.py b/src/bot.py
index 465389d..24cf640 100644
--- a/src/bot.py
+++ b/src/bot.py
@@ -227,6 +227,8 @@ def __init__(
self.new_prog = re.compile(r"\s*!new\s+(.+)$")
async def close(self, task: asyncio.Task) -> None:
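+        # the chat context is persisted in sqlite (see src/gptbot.py); close it before shutdown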
+ self.chatbot.cursor.close()
+ self.chatbot.conn.close()
await self.httpx_client.aclose()
if self.lc_admin is not None:
self.lc_manager.c.close()
@@ -251,6 +253,9 @@ async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> No
# sender_id
sender_id = event.sender
+ # event source
+ event_source = event.source
+
# user_message
raw_user_message = event.body
@@ -265,6 +270,48 @@ async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> No
# remove newline character from event.body
content_body = re.sub("\r\n|\r|\n", " ", raw_user_message)
+ # @bot and reply in thread
+ if "m.mentions" in event_source["content"]:
+ if "user_ids" in event_source["content"]["m.mentions"]:
+ # @bot
+ if (
+ self.user_id
+ in event_source["content"]["m.mentions"]["user_ids"]
+ ):
+ try:
+ asyncio.create_task(
+ self.thread_chat(
+ room_id,
+ reply_to_event_id,
+ sender_id=sender_id,
+ thread_root_id=reply_to_event_id,
+ prompt=content_body,
+ )
+ )
+ except Exception as e:
+                        logger.error(e, exc_info=True)
+
+        # thread conversation
+        if "m.relates_to" in event_source["content"]:
+            relation = event_source["content"]["m.relates_to"]
+            if relation.get("rel_type") == "m.thread":
+                thread_root_id = relation["event_id"]
+ # thread is created by @bot
+ if thread_root_id in self.chatbot.conversation:
+ try:
+ asyncio.create_task(
+ self.thread_chat(
+ room_id,
+ reply_to_event_id,
+ sender_id=sender_id,
+ thread_root_id=thread_root_id,
+ prompt=content_body,
+ )
+ )
+ except Exception as e:
+                        logger.error(e, exc_info=True)
+
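+        # for reference, the event content shapes inspected above (field names
+        # per the Matrix client-server spec; values are illustrative):
+        #   {"m.mentions": {"user_ids": ["@bot:example.org"]}}
+        #   {"m.relates_to": {"rel_type": "m.thread", "event_id": "$root_event_id"}}
+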
+        # common commands
+
# !gpt command
if (
self.openai_api_key is not None
@@ -1300,6 +1347,37 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None:
estr = traceback.format_exc()
logger.info(estr)
+ # thread chat
+ async def thread_chat(
+ self, room_id, reply_to_event_id, thread_root_id, prompt, sender_id
+ ):
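+        # the thread root event id doubles as the conversation id, so each
+        # thread keeps an isolated context in the chatbot's conversation store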
+ try:
+ await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
+ content = await self.chatbot.ask_async_v2(
+ prompt=prompt,
+ convo_id=thread_root_id,
+ )
+ await send_room_message(
+ self.client,
+ room_id,
+ reply_message=content,
+ reply_to_event_id=reply_to_event_id,
+ sender_id=sender_id,
+ reply_in_thread=True,
+ thread_root_id=thread_root_id,
+ )
+ except Exception as e:
+            logger.error(e, exc_info=True)
+ await send_room_message(
+ self.client,
+ room_id,
+ reply_message=GENERAL_ERROR_MESSAGE,
+ sender_id=sender_id,
+ reply_to_event_id=reply_to_event_id,
+ reply_in_thread=True,
+ thread_root_id=thread_root_id,
+ )
+
# !chat command
async def chat(self, room_id, reply_to_event_id, prompt, sender_id, user_message):
try:
diff --git a/src/gptbot.py b/src/gptbot.py
index c9cfed4..88bde58 100644
--- a/src/gptbot.py
+++ b/src/gptbot.py
@@ -2,6 +2,7 @@
Code derived from https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
A simple wrapper for the official ChatGPT API
"""
+import sqlite3
import json
from typing import AsyncGenerator
from tenacity import retry, wait_random_exponential, stop_after_attempt
@@ -9,16 +10,7 @@
import tiktoken
-ENGINES = [
- "gpt-3.5-turbo",
- "gpt-3.5-turbo-16k",
- "gpt-3.5-turbo-0613",
- "gpt-3.5-turbo-16k-0613",
- "gpt-4",
- "gpt-4-32k",
- "gpt-4-0613",
- "gpt-4-32k-0613",
-]
+ENGINES = ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "gpt-4-turbo"]
class Chatbot:
@@ -41,6 +33,7 @@ def __init__(
reply_count: int = 1,
truncate_limit: int = None,
system_prompt: str = None,
+ db_path: str = "context.db",
) -> None:
"""
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
@@ -53,23 +46,24 @@ def __init__(
or "You are ChatGPT, \
a large language model trained by OpenAI. Respond conversationally"
)
+ # https://platform.openai.com/docs/models
self.max_tokens: int = max_tokens or (
- 31000
+ 127000
+ if "gpt-4-turbo" in engine
+ else 31000
if "gpt-4-32k" in engine
else 7000
if "gpt-4" in engine
- else 15000
- if "gpt-3.5-turbo-16k" in engine
- else 4000
+ else 16000
)
self.truncate_limit: int = truncate_limit or (
- 30500
+ 126500
+ if "gpt-4-turbo" in engine
+ else 30500
if "gpt-4-32k" in engine
else 6500
if "gpt-4" in engine
- else 14500
- if "gpt-3.5-turbo-16k" in engine
- else 3500
+ else 15500
)
self.temperature: float = temperature
self.top_p: float = top_p
@@ -80,17 +74,49 @@ def __init__(
self.aclient = aclient
- self.conversation: dict[str, list[dict]] = {
+ self.db_path = db_path
+
+ self.conn = sqlite3.connect(self.db_path)
+ self.cursor = self.conn.cursor()
+
+ self._create_tables()
+
+ self.conversation = self._load_conversation()
+
+ if self.get_token_count("default") > self.max_tokens:
+ raise Exception("System prompt is too long")
+
+ def _create_tables(self) -> None:
+ self.conn.execute(
+ """
+ CREATE TABLE IF NOT EXISTS conversations(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ convo_id TEXT UNIQUE,
+ messages TEXT
+ )
+ """
+ )
+
+ def _load_conversation(self) -> dict[str, list[dict]]:
+ conversations: dict[str, list[dict]] = {
"default": [
{
"role": "system",
- "content": system_prompt,
+ "content": self.system_prompt,
},
],
}
+ self.cursor.execute("SELECT convo_id, messages FROM conversations")
+ for convo_id, messages in self.cursor.fetchall():
+ conversations[convo_id] = json.loads(messages)
+ return conversations
- if self.get_token_count("default") > self.max_tokens:
- raise Exception("System prompt is too long")
+    def _save_conversation(self, convo_id: str) -> None:
+ self.conn.execute(
+ "INSERT OR REPLACE INTO conversations (convo_id, messages) VALUES (?, ?)",
+ (convo_id, json.dumps(self.conversation[convo_id])),
+ )
+ self.conn.commit()
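+
+    # a hypothetical round trip: after add_to_conversation("hi", "user", "$root"),
+    # the conversations table holds a row like
+    #   ("$root", '[{"role": "system", ...}, {"role": "user", "content": "hi"}]')
+    # which _load_conversation() restores into self.conversation on restart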
def add_to_conversation(
self,
@@ -102,6 +128,7 @@ def add_to_conversation(
Add a message to the conversation
"""
self.conversation[convo_id].append({"role": role, "content": message})
+ self._save_conversation(convo_id)
def __truncate_conversation(self, convo_id: str = "default") -> None:
"""
@@ -116,6 +143,7 @@ def __truncate_conversation(self, convo_id: str = "default") -> None:
self.conversation[convo_id].pop(1)
else:
break
+ self._save_conversation(convo_id)
# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def get_token_count(self, convo_id: str = "default") -> int:
@@ -305,6 +333,7 @@ def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
self.conversation[convo_id] = [
{"role": "system", "content": system_prompt or self.system_prompt},
]
+ self._save_conversation(convo_id)
@retry(wait=wait_random_exponential(min=2, max=5), stop=stop_after_attempt(3))
async def oneTimeAsk(
diff --git a/src/send_message.py b/src/send_message.py
index 26179d6..7ee5b60 100644
--- a/src/send_message.py
+++ b/src/send_message.py
@@ -12,6 +12,8 @@ async def send_room_message(
sender_id: str = "",
user_message: str = "",
reply_to_event_id: str = "",
+ reply_in_thread: bool = False,
+ thread_root_id: str = "",
) -> None:
if reply_to_event_id == "":
content = {
@@ -23,6 +25,23 @@ async def send_room_message(
extensions=["nl2br", "tables", "fenced_code"],
),
}
+ elif reply_in_thread and thread_root_id:
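+        # rel_type "m.thread" groups this message into the thread rooted at
+        # thread_root_id; "is_falling_back" plus "m.in_reply_to" lets clients
+        # without thread support render it as a plain reply (Matrix spec, MSC3440)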
+ content = {
+ "msgtype": "m.text",
+ "body": reply_message,
+ "format": "org.matrix.custom.html",
+ "formatted_body": markdown.markdown(
+ reply_message,
+ extensions=["nl2br", "tables", "fenced_code"],
+ ),
+ "m.relates_to": {
+ "m.in_reply_to": {"event_id": reply_to_event_id},
+ "rel_type": "m.thread",
+ "event_id": thread_root_id,
+ "is_falling_back": True,
+ },
+ }
+
else:
body = "> <" + sender_id + "> " + user_message + "\n\n" + reply_message
format = r"org.matrix.custom.html"
@@ -51,13 +70,11 @@ async def send_room_message(
"formatted_body": formatted_body,
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}},
}
- try:
- await client.room_send(
- room_id,
- message_type="m.room.message",
- content=content,
- ignore_unverified_devices=True,
- )
- await client.room_typing(room_id, typing_state=False)
- except Exception as e:
- logger.error(e)
+
+ await client.room_send(
+ room_id,
+ message_type="m.room.message",
+ content=content,
+ ignore_unverified_devices=True,
+ )
+ await client.room_typing(room_id, typing_state=False)