From 0f93169fe9b9e862a5431acaac1401eb56d70833 Mon Sep 17 00:00:00 2001 From: sudoskys Date: Wed, 17 Apr 2024 01:12:18 +0800 Subject: [PATCH] update telegram --- app/components/__init__.py | 0 app/components/credential.py | 64 +++++++++++++++++ app/components/user_manager/__init__.py | 93 +++++++++++++++++++++++++ {llmkira => app}/middleware/__init__.py | 0 {llmkira => app}/middleware/llm_task.py | 26 +++---- app/receiver/receiver_client.py | 67 +++++++++++------- app/receiver/telegram/__init__.py | 2 +- app/sender/app.py | 3 +- app/sender/telegram/__init__.py | 92 +++++++++++++++++------- app/sender/util_func.py | 37 +++++++--- 10 files changed, 307 insertions(+), 77 deletions(-) create mode 100644 app/components/__init__.py create mode 100644 app/components/credential.py create mode 100644 app/components/user_manager/__init__.py rename {llmkira => app}/middleware/__init__.py (100%) rename {llmkira => app}/middleware/llm_task.py (87%) diff --git a/app/components/__init__.py b/app/components/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/components/credential.py b/app/components/credential.py new file mode 100644 index 000000000..22961a5de --- /dev/null +++ b/app/components/credential.py @@ -0,0 +1,64 @@ +from urllib.parse import urlparse + +import requests +from pydantic import BaseModel + + +class ProviderError(Exception): + pass + + +class Credential(BaseModel): + api_key: str + api_endpoint: str + api_model: str + + @classmethod + def from_provider(cls, token, provider_url): + """ + 使用 token POST 请求 provider_url 获取用户信息 + :param token: 用户 token + :param provider_url: provider url + :return: 用户信息 + :raises HTTPError: 请求失败 + :raises JSONDecodeError: 返回数据解析失败 + :raises ProviderError: provider 返回错误信息 + """ + response = requests.post(provider_url, data={"token": token}) + response.raise_for_status() + user_data = response.json() + if user_data.get("error"): + raise ProviderError(user_data["error"]) + return cls( + api_key=user_data["api_key"], + api_endpoint=user_data["api_endpoint"], + api_model=user_data["api_model"], + ) + + +def split_setting_string(input_string): + if not isinstance(input_string, str): + return None + segments = input_string.split("$") + + # 检查链接的有效性 + def is_valid_url(url): + try: + result = urlparse(url) + return all([result.scheme, result.netloc]) + except ValueError: + return False + + # 开头为链接的情况 + if is_valid_url(segments[0]) and len(segments) >= 3: + return segments[:3] + # 第二个元素为链接,第一个元素为字符串的情况 + elif ( + len(segments) == 2 + and not is_valid_url(segments[0]) + and is_valid_url(segments[1]) + ): + return segments + # 其他情况 + else: + return None diff --git a/app/components/user_manager/__init__.py b/app/components/user_manager/__init__.py new file mode 100644 index 000000000..8969428a9 --- /dev/null +++ b/app/components/user_manager/__init__.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/2/8 下午10:56 +# @Author : sudoskys +# @File : __init__.py.py +# @Software: PyCharm +import time +from typing import Optional + +from loguru import logger +from pydantic import BaseModel + +from app.components.credential import Credential +from app.const import DBNAME +from llmkira.doc_manager import global_doc_client + + +class ChatCost(BaseModel): + user_id: str + cost_token: int = 0 + endpoint: str = "" + cost_model: str = "" + produce_time: int = time.time() + + +class GenerateHistory(object): + def __init__(self, db_name: str = DBNAME, collection: str = "cost_history"): + """ """ + self.client = global_doc_client.update_db_collection( + db_name=db_name, collection_name=collection + ) + + async def save(self, history: ChatCost): + return self.client.insert_one(history.model_dump(mode="json")) + + +class User(BaseModel): + user_id: str + last_use_time: int = time.time() + credential: Optional[Credential] = None + + +class UserManager(object): + def __init__(self, db_name: str = DBNAME, collection: str = "user"): + """ """ + self.client = global_doc_client.update_db_collection( + db_name=db_name, collection_name=collection + ) + + async def read(self, user_id: str) -> User: + user_id = str(user_id) + database_read = self.client.find_one({"user_id": user_id}) + if not database_read: + logger.info(f"Create new user: {user_id}") + return User(user_id=user_id) + # database_read.update({"user_id": user_id}) + return User.model_validate(database_read) + + async def save(self, user_model: User): + user_model = user_model.model_copy(update={"last_use_time": int(time.time())}) + # 如果存在记录则更新 + if self.client.find_one({"user_id": user_model.user_id}): + return self.client.update_one( + {"user_id": user_model.user_id}, + {"$set": user_model.model_dump(mode="json")}, + ) + # 如果不存在记录则插入 + else: + return self.client.insert_one(user_model.model_dump(mode="json")) + + +COST_MANAGER = GenerateHistory() +USER_MANAGER = UserManager() + + +async def record_cost( + user_id: str, cost_token: int, endpoint: str, cost_model: str, success: bool = True +): + try: + await COST_MANAGER.save( + ChatCost( + user_id=user_id, + produce_time=int(time.time()), + endpoint=endpoint, + cost_model=cost_model, + cost_token=cost_token if success else 0, + ) + ) + except Exception as exc: + logger.error(f"🔥 record_cost error: {exc}") + + +if __name__ == "__main__": + pass diff --git a/llmkira/middleware/__init__.py b/app/middleware/__init__.py similarity index 100% rename from llmkira/middleware/__init__.py rename to app/middleware/__init__.py diff --git a/llmkira/middleware/llm_task.py b/app/middleware/llm_task.py similarity index 87% rename from llmkira/middleware/llm_task.py rename to app/middleware/llm_task.py index a20592c98..5a66c88c9 100644 --- a/llmkira/middleware/llm_task.py +++ b/app/middleware/llm_task.py @@ -2,12 +2,13 @@ # @Time : 2023/8/18 上午9:37 # @Author : sudoskys # @File : llm_task.py -import os from typing import List, Optional from loguru import logger from pydantic import SecretStr +from app.components.credential import Credential +from app.components.user_manager import record_cost from llmkira.kv_manager.instruction import InstructionManager from llmkira.memory import global_message_runtime from llmkira.openai.cell import Tool, Message, active_cell_string, SystemMessage @@ -48,14 +49,13 @@ def __init__( self.message_history = global_message_runtime.update_session( session_id=session_uid ) - # TODO:实现用户配置读取 async def remember(self, *, message: Optional[Message] = None): """ 写回消息到历史消息 """ if message: - await self.message_history.append(message=message) + await self.message_history.append(messages=[message]) async def build_message(self, remember=True): """ @@ -87,18 +87,20 @@ async def build_message(self, remember=True): user_message = message.format_user_message() message_run.append(user_message) if remember: - await self.message_history.append(message=user_message) + await self.message_history.append(messages=[user_message]) return message_run async def request_openai( self, remember: bool, + credential: Credential, disable_tool: bool = False, ) -> OpenAIResult: """ 处理消息转换和调用工具 :param remember: 是否自动写回 :param disable_tool: 禁用函数 + :param credential: 凭证 :return: OpenaiResult 返回结果 :raise RuntimeError: 无法处理消息 :raise AssertionError: 无法处理消息 @@ -113,7 +115,6 @@ async def request_openai( messages.append(SystemMessage(content=self.task.task_sign.instruction)) messages.extend(await self.build_message(remember=remember)) # TODO:实现消息时序切片 - # 日志 logger.info( f"[x] Openai request" f"\n--message {messages} " f"\n--tools {tools}" @@ -125,21 +126,22 @@ async def request_openai( # 根据模型选择不同的驱动a assert messages, RuntimeError("llm_task:message cant be none...") endpoint: OpenAI = OpenAI( - messages=messages, - tools=tools, - model="gpt-3.5-turbo", # FIXME:从用户配置中获取 + messages=messages, tools=tools, model=credential.api_model ) # 调用Openai result: OpenAIResult = await endpoint.request( session=OpenAICredential( - api_key=SecretStr( - os.getenv("OPENAI_API_KEY", None) - ), # FIXME:从用户配置中获取 - base_url=os.getenv("OPENAI_API_ENDPOINT"), # FIXME:从用户配置中获取 + api_key=SecretStr(credential.api_key), base_url=credential.api_endpoint ) ) _message = result.default_message _usage = result.usage.total_tokens + await record_cost( + cost_model=credential.api_model, + cost_token=_usage, + endpoint=credential.api_endpoint, + user_id=self.session_uid, + ) # 写回数据库 await self.remember(message=_message) return result diff --git a/app/receiver/receiver_client.py b/app/receiver/receiver_client.py index 64827d18a..e7123ec12 100644 --- a/app/receiver/receiver_client.py +++ b/app/receiver/receiver_client.py @@ -18,8 +18,10 @@ from loguru import logger from telebot import formatting +from app.components.credential import Credential +from app.components.user_manager import USER_MANAGER +from app.middleware.llm_task import OpenaiMiddleware from llmkira.kv_manager.env import EnvManager -from llmkira.middleware.llm_task import OpenaiMiddleware from llmkira.openai import OpenaiError from llmkira.openai.cell import ToolCall, Message, Tool from llmkira.openai.request import OpenAIResult @@ -31,6 +33,11 @@ from llmkira.task.snapshot import global_snapshot_storage +async def read_user_credential(user_id: str) -> Optional[Credential]: + user = await USER_MANAGER.read(user_id=user_id) + return user.credential + + async def generate_authorization( secrets: Dict, tool_invocation: ToolCall ) -> Tuple[dict, list, bool]: @@ -147,7 +154,7 @@ async def forward(self, receiver: Location, message: list): @abstractmethod async def reply( - self, receiver: Location, message: Message, reply_to_message: bool = True + self, receiver: Location, messages: List[Message], reply_to_message: bool = True ): """ 模型直转发,Message是Openai的类型 @@ -232,21 +239,28 @@ async def _flash( """ try: try: + credentials = await read_user_credential(user_id=task.receiver.uid) + assert credentials, "You need to /login first" llm_result = await llm.request_openai( remember=remember, disable_tool=disable_tool, + credential=credentials, ) assistant_message = llm_result.default_message logger.debug(f"Assistant:{assistant_message}") except OpenaiError as exc: await self.sender.error(receiver=task.receiver, text=exc.message) return exc - except (RuntimeError, AssertionError) as exc: + except RuntimeError as exc: + logger.exception(exc) await self.sender.error( receiver=task.receiver, text="Can't get message validate from your history", ) return exc + except AssertionError as exc: + await self.sender.error(receiver=task.receiver, text=str(exc)) + return exc except Exception as exc: logger.exception(exc) await self.sender.error( @@ -269,7 +283,7 @@ async def _flash( ) return logger.debug("Function loop ended") return await self.sender.reply( - receiver=task.receiver, message=assistant_message + receiver=task.receiver, messages=[assistant_message] ) except Exception as e: raise e @@ -364,30 +378,31 @@ async def on_message(self, message: AbstractIncomingMessage): snap_data = await global_snapshot_storage.read( user_id=task_head.receiver.uid ) - data = snap_data.data - renew_snap_data = [] - for task in data: - if not task.snapshot_credential and not task.processed: - try: - await Task.create_and_send( - queue_name=task.channel, task=task.snapshot_data - ) - except Exception as e: - logger.exception(f"Response to snapshot error {e}") + if snap_data is not None: + data = snap_data.data + renew_snap_data = [] + for task in data: + if not task.snapshot_credential and not task.processed: + try: + await Task.create_and_send( + queue_name=task.channel, task=task.snapshot_data + ) + except Exception as e: + logger.exception(f"Response to snapshot error {e}") + else: + logger.info( + f"🧀 Response to snapshot {task.snap_uuid} at {router}" + ) + finally: + task.processed_at = int(time.time()) + renew_snap_data.append(task) else: - logger.info( - f"🧀 Response to snapshot {task.snap_uuid} at {router}" - ) - finally: - task.processed_at = int(time.time()) + task.processed_at = None renew_snap_data.append(task) - else: - task.processed_at = None - renew_snap_data.append(task) - snap_data.data = renew_snap_data - await global_snapshot_storage.write( - user_id=task_head.receiver.uid, snapshot=snap_data - ) + snap_data.data = renew_snap_data + await global_snapshot_storage.write( + user_id=task_head.receiver.uid, snapshot=snap_data + ) except Exception as e: logger.exception(e) await message.reject(requeue=False) diff --git a/app/receiver/telegram/__init__.py b/app/receiver/telegram/__init__.py index f6e7a36dc..c3589239f 100644 --- a/app/receiver/telegram/__init__.py +++ b/app/receiver/telegram/__init__.py @@ -14,7 +14,7 @@ from app.receiver.receiver_client import BaseReceiver, BaseSender from app.setting.telegram import BotSetting from llmkira.kv_manager.file import File -from llmkira.middleware.llm_task import OpenaiMiddleware +from app.middleware.llm_task import OpenaiMiddleware from llmkira.openai.cell import Message from llmkira.openai.request import OpenAIResult from llmkira.task import Task, TaskHeader diff --git a/app/sender/app.py b/app/sender/app.py index 027cce69c..c13d906b6 100644 --- a/app/sender/app.py +++ b/app/sender/app.py @@ -7,6 +7,8 @@ from dotenv import load_dotenv from loguru import logger +from llmkira import load_from_entrypoint, get_entrypoint_plugins + load_dotenv() __area__ = "sender" @@ -15,7 +17,6 @@ def run(): import asyncio from llmkira import load_plugins - from llmkira.sdk import load_from_entrypoint, get_entrypoint_plugins from app.setting import PlatformSetting start_setting = PlatformSetting.from_subdir() diff --git a/app/sender/telegram/__init__.py b/app/sender/telegram/__init__.py index e0970452b..47bb14fb5 100644 --- a/app/sender/telegram/__init__.py +++ b/app/sender/telegram/__init__.py @@ -12,12 +12,15 @@ from telebot.async_telebot import AsyncTeleBot from telebot.asyncio_storage import StateMemoryStorage from telebot.formatting import escape_markdown +from telegramify_markdown import convert from app.sender.util_func import ( parse_command, is_command, is_empty_command, auth_reloader, + uid_make, + login, ) from app.setting.telegram import BotSetting from llmkira.kv_manager.env import EnvManager @@ -30,13 +33,12 @@ __sender__ = "telegram" __default_disable_tool_action__ = False -StepCache = StateMemoryStorage() -TelegramTask = Task(queue=__sender__) +from app.components.credential import split_setting_string, Credential, ProviderError +StepCache = StateMemoryStorage() -def uid_make(platform: str, user_id: int): - return f"{platform}:{user_id}" +TelegramTask = Task(queue=__sender__) class TelegramBotRunner(Runner): @@ -206,8 +208,53 @@ async def create_task(message: types.Message, disable_tool_action: bool = True): @bot.message_handler(commands="login", chat_types=["private"]) async def listen_login_command(message: types.Message): - logger.debug(f"Debug:login command {message}") - pass + logger.debug("Debug:login command") + _cmd, _arg = parse_command(command=message.text) + settings = split_setting_string(_arg) + if not settings: + return await bot.reply_to( + message, + text=convert( + "🔑 **Incorrect format.**\n" + "You can set it via `https://api.com/v1$key$model` format, " + "or you can log in via URL using `token$https://provider.com`." + ), + ) + if len(settings) == 2: + try: + credential = Credential.from_provider( + token=settings[0], provider_url=settings[1] + ) + except ProviderError as e: + return await bot.reply_to( + message, text=f"Login failed, website return {e}" + ) + except Exception as e: + logger.error(f"Login failed {e}") + return await bot.reply_to( + message, text=f"Login failed, because {type(e)}" + ) + else: + await login( + uid=uid_make(__sender__, message.from_user.id), + credential=credential, + ) + return await bot.reply_to( + message, text="Login success as provider! Welcome master!" + ) + elif len(settings) == 3: + credential = Credential( + api_endpoint=settings[0], api_key=settings[1], api_model=settings[2] + ) + await login( + uid=uid_make(__sender__, message.from_user.id), + credential=credential, + ) + return await bot.reply_to( + message, text=f"Login success as {settings[2]}! Welcome master! " + ) + else: + return logger.trace(f"Login failed {settings}") @bot.message_handler(commands="env", chat_types=["private"]) async def listen_env_command(message: types.Message): @@ -267,25 +314,14 @@ async def listen_help_command(message: types.Message): async def listen_tool_command(message: types.Message): _tool = ToolRegister().get_plugins_meta _paper = [ - [tool_item.name, tool_item.get_function_string, tool_item.usage] + f"# {tool_item.name}\n{tool_item.get_function_string}\n```{tool_item.usage}```" for tool_item in _tool ] - arg = [ - formatting.mbold(item[0]) - + "\n" - + formatting.mcode(item[1]) - + "\n" - + formatting.mitalic(item[2]) - + "\n" - for item in _paper - ] - reply_message_text = formatting.format_text( - formatting.mbold("🔧 Tool List"), *arg, separator="\n" - ) + reply_message_text = "\n".join(_paper) if len(reply_message_text) > 4096: reply_message_text = reply_message_text[:4096] return await bot.reply_to( - message, text=reply_message_text, parse_mode="MarkdownV2" + message, text=convert(reply_message_text), parse_mode="MarkdownV2" ) @bot.message_handler( @@ -297,16 +333,18 @@ async def listen_auth_command(message: types.Message): return None try: await auth_reloader( - uuid=_arg, user_id=f"{message.from_user.id}", platform=__sender__ + snapshot_credential=_arg, + user_id=f"{message.from_user.id}", + platform=__sender__, ) except Exception as e: auth_result = ( "❌ Auth failed,You dont have permission or the task do not exist" ) - logger.error(f"[270563]auth_reloader failed {e}") + logger.info(f"Auth failed {e}") else: - auth_result = "🪄 Auth Pass" - return await bot.reply_to(message, text=auth_result) + auth_result = "🪄 Snapshot released" + return await bot.reply_to(message, text=convert(auth_result)) @bot.message_handler( content_types=["text", "photo", "document"], chat_types=["private"] @@ -376,7 +414,7 @@ async def handle_group_msg(message: types.Message): at_bot_username=BotSetting.bot_username, ): if is_empty_command(text=message.text): - return await bot.reply_to(message, text="?") + return await bot.reply_to(message, text="Say something?") return await create_task( message, disable_tool_action=__default_disable_tool_action__ ) @@ -386,7 +424,7 @@ async def handle_group_msg(message: types.Message): at_bot_username=BotSetting.bot_username, ): if is_empty_command(text=message.text): - return await bot.reply_to(message, text="?") + return await bot.reply_to(message, text="Say something?") return await create_task(message, disable_tool_action=False) if is_command( text=message.text, @@ -394,7 +432,7 @@ async def handle_group_msg(message: types.Message): at_bot_username=BotSetting.bot_username, ): if is_empty_command(text=message.text): - return await bot.reply_to(message, text="?") + return await bot.reply_to(message, text="Say something?") return await create_task(message, disable_tool_action=True) if f"@{BotSetting.bot_username} " in message.text or message.text.endswith( f" @{BotSetting.bot_username}" diff --git a/app/sender/util_func.py b/app/sender/util_func.py index 38032f42c..52cedaada 100644 --- a/app/sender/util_func.py +++ b/app/sender/util_func.py @@ -3,13 +3,25 @@ # @Author : sudoskys # @File : util_func.py # @Software: PyCharm -from typing import Tuple, Optional +from typing import Tuple, Optional, Union from urllib.parse import urlparse from loguru import logger -from llmkira.middleware.chain_box import Chain, AuthReloader +from app.components.credential import Credential +from app.components.user_manager import USER_MANAGER from llmkira.task import Task +from llmkira.task.snapshot import SnapData, global_snapshot_storage + + +def uid_make(platform: str, user_id: Union[int, str]): + return f"{platform}:{user_id}" + + +async def login(uid, credential: Credential): + user = await USER_MANAGER.read(user_id=uid) + user.credential = credential + await USER_MANAGER.save(user_model=user) def parse_command(command: str) -> Tuple[Optional[str], Optional[str]]: @@ -74,7 +86,7 @@ def is_empty_command(text: str) -> bool: return False -async def auth_reloader(uuid: str, platform: str, user_id: str) -> None: +async def auth_reloader(snapshot_credential: str, platform: str, user_id: str) -> None: """ :param uuid: verify id :param platform: message channel @@ -82,13 +94,18 @@ async def auth_reloader(uuid: str, platform: str, user_id: str) -> None: :raise LookupError Not Found :return None """ - assert isinstance(uuid, str), "`uuid` Must Be Str" + assert isinstance(snapshot_credential, str), "`uuid` Must Be Str" assert isinstance(platform, str), "`platform` Must Be Str" assert isinstance(user_id, str), "`user_id` Must Be Str" - chain: Chain = await AuthReloader.from_form( - platform=platform, user_id=user_id - ).read_auth(uuid=uuid) - if not chain: + snap_shot: SnapData = await global_snapshot_storage.read( + user_id=uid_make(platform, user_id) + ) + if not snap_shot.data: raise LookupError("Auth Task Not Found") - logger.info(f"Auth Task Sent --task uuid {uuid} --user {user_id}") - await Task(queue=chain.channel).send_task(task=chain.arg) + logger.info(f"Snap Auth:{snapshot_credential},by user {user_id}") + for snap in snap_shot.data: + if snap.snapshot_credential == snapshot_credential: + return await Task.create_and_send( + queue_name=snap.channel, + task=snap.snapshot_data, + )