diff --git a/.gitignore b/.gitignore
index f0bc170..73a8f5f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -173,3 +173,4 @@ cython_debug/
 sync_db
 manage_db
 element-keys.txt
+context.db
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ef8dcd3..2d90547 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,8 @@
 # Changelog
 
+## 1.7.0
+- Support thread-level context
+
 ## 1.6.0
 - Add GPT Vision
 
diff --git a/README.md b/README.md
index 7b8b1c7..9833638 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,7 @@ This is a simple Matrix bot that supports using OpenAI API, Langchain to generate
 4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise))
 5. Image Generation with [DALL·E](https://platform.openai.com/docs/api-reference/images/create) or [LocalAI](https://localai.io/features/image-generation/) or [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)
 6. GPT Vision(openai or [GPT Vision API](https://platform.openai.com/docs/guides/vision) compatible such as [LocalAI](https://localai.io/features/gpt-vision/))
+7. Room-level and thread-level chat context
 
 ## Installation and Setup
 
@@ -21,10 +22,10 @@ For explanations and complete parameter list see: https://github.com/hibobmaste
 
 Create empty files, used for database persistence only
 ```bash
-touch sync_db manage_db
+touch sync_db context.db manage_db
 sudo docker compose up -d
 ```
-manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database
+manage_db (can be ignored) is for the langchain agent, sync_db is for the matrix sync database, and context.db is for the bot chat context
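
context.db is a plain SQLite database with a single `conversations` table (created in the src/gptbot.py changes below), so you can peek at what the bot has persisted. A minimal sketch, assuming Python 3 is available on the host:

```python
import json
import sqlite3

# Each row holds one conversation: a convo_id plus a JSON-encoded message list.
# The schema matches the table created in src/gptbot.py.
conn = sqlite3.connect("context.db")
for convo_id, messages in conn.execute("SELECT convo_id, messages FROM conversations"):
    print(convo_id, "->", len(json.loads(messages)), "messages")
conn.close()
```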

Normal Method:
 system dependency: libolm-dev
@@ -115,12 +116,16 @@ LangChain(flowise) admin: https://github.com/hibobmaster/matrix_chatgpt_bot/wiki
 ![demo2](https://i.imgur.com/BKZktWd.jpg)
 https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/
+## Thread-level Context
+Mention the bot with a prompt and it will reply in a thread.
+
+To keep the context, just send your prompt inside the thread directly, without mentioning the bot.
+
+![thread level context 1](https://i.imgur.com/4vLvNCt.jpeg)
+![thread level context 2](https://i.imgur.com/1eb1Lmd.jpeg)
+
 ## Thanks
 1. [matrix-nio](https://github.com/poljar/matrix-nio)
 2. [acheong08](https://github.com/acheong08)
 3. [8go](https://github.com/8go/)
-
-
-JetBrains Logo (Main) logo.
-
diff --git a/compose.yaml b/compose.yaml
index 76b61e2..6e31742 100644
--- a/compose.yaml
+++ b/compose.yaml
@@ -12,8 +12,9 @@ services:
       # use env file or config.json
       # - ./config.json:/app/config.json
       # use touch to create empty db file, for persist database only
-      # manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database
+      # manage_db (can be ignored) is for the langchain agent, sync_db is for the matrix sync database, and context.db is for the bot chat context
       - ./sync_db:/app/sync_db
+      - ./context.db:/app/context.db
       # - ./manage_db:/app/manage_db
       # import_keys path
       # - ./element-keys.txt:/app/element-keys.txt
diff --git a/src/bot.py b/src/bot.py
index 465389d..24cf640 100644
--- a/src/bot.py
+++ b/src/bot.py
@@ -227,6 +227,8 @@ def __init__(
         self.new_prog = re.compile(r"\s*!new\s+(.+)$")
 
     async def close(self, task: asyncio.Task) -> None:
+        self.chatbot.cursor.close()
+        self.chatbot.conn.close()
         await self.httpx_client.aclose()
         if self.lc_admin is not None:
             self.lc_manager.c.close()
@@ -251,6 +253,9 @@ async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> No
         # sender_id
         sender_id = event.sender
 
+        # event source
+        event_source = event.source
+
         # user_message
         raw_user_message = event.body
 
@@ -265,6 +270,48 @@ async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> No
         # remove newline character from event.body
         content_body = re.sub("\r\n|\r|\n", " ", raw_user_message)
 
+        # @bot and reply in thread
+        if "m.mentions" in event_source["content"]:
+            if "user_ids" in event_source["content"]["m.mentions"]:
+                # @bot
+                if (
+                    self.user_id
+                    in event_source["content"]["m.mentions"]["user_ids"]
+                ):
+                    try:
+                        asyncio.create_task(
+                            self.thread_chat(
+                                room_id,
+                                reply_to_event_id,
+                                sender_id=sender_id,
+                                thread_root_id=reply_to_event_id,
+                                prompt=content_body,
+                            )
+                        )
+                    except Exception as e:
+                        logger.error(e, exc_info=True)
+
+        # thread conversation
+        if "m.relates_to" in event_source["content"]:
+            if "rel_type" in event_source["content"]["m.relates_to"]:
+                thread_root_id = event_source["content"]["m.relates_to"]["event_id"]
+                # thread was created by @bot
+                if thread_root_id in self.chatbot.conversation:
+                    try:
+                        asyncio.create_task(
+                            self.thread_chat(
+                                room_id,
+                                reply_to_event_id,
+                                sender_id=sender_id,
+                                thread_root_id=thread_root_id,
+                                prompt=content_body,
+                            )
+                        )
+                    except Exception as e:
+                        logger.error(e, exc_info=True)
+
+        # common command
+
         # !gpt command
         if (
             self.openai_api_key is not None
@@ -1300,6 +1347,37 @@ async def to_device_callback(self, event: KeyVerificationEvent) -> None:
             estr = traceback.format_exc()
             logger.info(estr)
 
+    # thread chat
+    async def thread_chat(
+        self, room_id, reply_to_event_id, thread_root_id, prompt, sender_id
+    ):
+        try:
+            await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
+            content = await self.chatbot.ask_async_v2(
+                prompt=prompt,
+                convo_id=thread_root_id,
+            )
+            await send_room_message(
+                self.client,
+                room_id,
+                reply_message=content,
+                reply_to_event_id=reply_to_event_id,
+                sender_id=sender_id,
+                reply_in_thread=True,
+                thread_root_id=thread_root_id,
+            )
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await send_room_message(
+                self.client,
+                room_id,
+                reply_message=GENERAL_ERROR_MESSAGE,
+                sender_id=sender_id,
+                reply_to_event_id=reply_to_event_id,
+                reply_in_thread=True,
+                thread_root_id=thread_root_id,
+            )
+
     # !chat command
     async def chat(self, room_id, reply_to_event_id, prompt, sender_id, user_message):
         try:
diff --git a/src/gptbot.py b/src/gptbot.py
index c9cfed4..88bde58 100644
--- a/src/gptbot.py
+++ b/src/gptbot.py
@@ -2,6 +2,7 @@
 Code derived from https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
 A simple wrapper for the official ChatGPT API
 """
+import sqlite3
 import json
 from typing import AsyncGenerator
 from tenacity import retry, wait_random_exponential, stop_after_attempt
@@ -9,16 +10,7 @@
 import tiktoken
 
-ENGINES = [
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-16k",
-    "gpt-3.5-turbo-0613",
-    "gpt-3.5-turbo-16k-0613",
-    "gpt-4",
-    "gpt-4-32k",
-    "gpt-4-0613",
-    "gpt-4-32k-0613",
-]
+ENGINES = ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "gpt-4-turbo"]
 
 
 class Chatbot:
@@ -41,6 +33,7 @@ def __init__(
         reply_count: int = 1,
         truncate_limit: int = None,
         system_prompt: str = None,
+        db_path: str = "context.db",
     ) -> None:
         """
         Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
@@ -53,23 +46,24 @@
             or "You are ChatGPT, \
 a large language model trained by OpenAI. Respond conversationally"
         )
+        # https://platform.openai.com/docs/models
         self.max_tokens: int = max_tokens or (
-            31000
+            127000
+            if "gpt-4-turbo" in engine
+            else 31000
             if "gpt-4-32k" in engine
             else 7000
             if "gpt-4" in engine
-            else 15000
-            if "gpt-3.5-turbo-16k" in engine
-            else 4000
+            else 16000
         )
         self.truncate_limit: int = truncate_limit or (
-            30500
+            126500
+            if "gpt-4-turbo" in engine
+            else 30500
             if "gpt-4-32k" in engine
             else 6500
             if "gpt-4" in engine
-            else 14500
-            if "gpt-3.5-turbo-16k" in engine
-            else 3500
+            else 15500
         )
         self.temperature: float = temperature
         self.top_p: float = top_p
@@ -80,17 +74,49 @@ def __init__(
 
         self.aclient = aclient
 
-        self.conversation: dict[str, list[dict]] = {
+        self.db_path = db_path
+
+        self.conn = sqlite3.connect(self.db_path)
+        self.cursor = self.conn.cursor()
+
+        self._create_tables()
+
+        self.conversation = self._load_conversation()
+
+        if self.get_token_count("default") > self.max_tokens:
+            raise Exception("System prompt is too long")
+
+    def _create_tables(self) -> None:
+        self.conn.execute(
+            """
+            CREATE TABLE IF NOT EXISTS conversations(
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                convo_id TEXT UNIQUE,
+                messages TEXT
+            )
+            """
+        )
+
+    def _load_conversation(self) -> dict[str, list[dict]]:
+        conversations: dict[str, list[dict]] = {
             "default": [
                 {
                     "role": "system",
-                    "content": system_prompt,
+                    "content": self.system_prompt,
                 },
             ],
         }
+        self.cursor.execute("SELECT convo_id, messages FROM conversations")
+        for convo_id, messages in self.cursor.fetchall():
+            conversations[convo_id] = json.loads(messages)
+        return conversations
 
-        if self.get_token_count("default") > self.max_tokens:
-            raise Exception("System prompt is too long")
+    def _save_conversation(self, convo_id) -> None:
+        self.conn.execute(
+            "INSERT OR REPLACE INTO conversations (convo_id, messages) VALUES (?, ?)",
+            (convo_id, json.dumps(self.conversation[convo_id])),
+        )
+        self.conn.commit()
 
     def add_to_conversation(
         self,
@@ -102,6 +128,7 @@ def add_to_conversation(
         Add a message to the conversation
         """
         self.conversation[convo_id].append({"role": role, "content": message})
+        self._save_conversation(convo_id)
 
     def __truncate_conversation(self, convo_id: str = "default") -> None:
         """
@@ -116,6 +143,7 @@ def __truncate_conversation(self, convo_id: str = "default") -> None:
                 self.conversation[convo_id].pop(1)
             else:
                 break
+        self._save_conversation(convo_id)
 
     # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
     def get_token_count(self, convo_id: str = "default") -> int:
@@ -305,6 +333,7 @@ def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
         self.conversation[convo_id] = [
             {"role": "system", "content": system_prompt or self.system_prompt},
         ]
+        self._save_conversation(convo_id)
 
     @retry(wait=wait_random_exponential(min=2, max=5), stop=stop_after_attempt(3))
     async def oneTimeAsk(
diff --git a/src/send_message.py b/src/send_message.py
index 26179d6..7ee5b60 100644
--- a/src/send_message.py
+++ b/src/send_message.py
@@ -12,6 +12,8 @@ async def send_room_message(
     sender_id: str = "",
     user_message: str = "",
     reply_to_event_id: str = "",
+    reply_in_thread: bool = False,
+    thread_root_id: str = "",
 ) -> None:
     if reply_to_event_id == "":
         content = {
@@ -23,6 +25,23 @@ async def send_room_message(
                 extensions=["nl2br", "tables", "fenced_code"],
             ),
         }
+    elif reply_in_thread and thread_root_id:
+        content = {
+            "msgtype": "m.text",
+            "body": reply_message,
+            "format": "org.matrix.custom.html",
+            "formatted_body": markdown.markdown(
+                reply_message,
+                extensions=["nl2br", "tables", "fenced_code"],
+            ),
+            "m.relates_to": {
+                "m.in_reply_to": {"event_id": reply_to_event_id},
+                "rel_type": "m.thread",
+                "event_id": thread_root_id,
+                "is_falling_back": True,
+            },
+        }
+
     else:
         body = "> <" + sender_id + "> " + user_message + "\n\n" + reply_message
         format = r"org.matrix.custom.html"
@@ -51,13 +70,11 @@ async def send_room_message(
             "formatted_body": formatted_body,
             "m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}},
         }
-    try:
-        await client.room_send(
-            room_id,
-            message_type="m.room.message",
-            content=content,
-            ignore_unverified_devices=True,
-        )
-        await client.room_typing(room_id, typing_state=False)
-    except Exception as e:
-        logger.error(e)
+
+    await client.room_send(
+        room_id,
+        message_type="m.room.message",
+        content=content,
+        ignore_unverified_devices=True,
+    )
+    await client.room_typing(room_id, typing_state=False)
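
For reference, the `reply_in_thread` branch above emits the standard Matrix `m.thread` relation. A sketch of the content a threaded bot reply carries, with made-up event IDs:

```python
# Shape of the event content built by the new branch in src/send_message.py.
# "$thread_root" and "$latest_event" are illustrative placeholders.
content = {
    "msgtype": "m.text",
    "body": "reply text",
    "m.relates_to": {
        "rel_type": "m.thread",
        "event_id": "$thread_root",  # root event of the thread
        "m.in_reply_to": {"event_id": "$latest_event"},  # rich-reply fallback
        "is_falling_back": True,  # marks the reply part as fallback only
    },
}
```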
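The persistence scheme in src/gptbot.py is deliberately simple: one row per conversation, keyed by convo_id, with the message list JSON-encoded, and INSERT OR REPLACE overwriting that row on every save. A self-contained sketch of the same pattern; demo.db and the sample messages are illustrative:

```python
import json
import sqlite3

# One row per conversation, keyed by convo_id, messages JSON-encoded,
# mirroring Chatbot._create_tables()/_save_conversation()/_load_conversation().
conn = sqlite3.connect("demo.db")
conn.execute(
    """
    CREATE TABLE IF NOT EXISTS conversations(
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        convo_id TEXT UNIQUE,
        messages TEXT
    )
    """
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "hello"},
]
# INSERT OR REPLACE keeps exactly one row per convo_id, so saving after
# every append overwrites the previous snapshot of the thread.
conn.execute(
    "INSERT OR REPLACE INTO conversations (convo_id, messages) VALUES (?, ?)",
    ("$thread_root_event_id", json.dumps(messages)),
)
conn.commit()

# Reloading mirrors Chatbot._load_conversation(): decode each row back
# into a {convo_id: [messages]} mapping.
for convo_id, raw in conn.execute("SELECT convo_id, messages FROM conversations"):
    print(convo_id, json.loads(raw))
conn.close()
```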