Skip to content

Commit

Permalink
update telegram
Browse files Browse the repository at this point in the history
  • Loading branch information
sudoskys committed Apr 16, 2024
1 parent 6b5a535 commit 0f93169
Show file tree
Hide file tree
Showing 10 changed files with 307 additions and 77 deletions.
Empty file added app/components/__init__.py
Empty file.
64 changes: 64 additions & 0 deletions app/components/credential.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from urllib.parse import urlparse

import requests
from pydantic import BaseModel


class ProviderError(Exception):
pass


class Credential(BaseModel):
api_key: str
api_endpoint: str
api_model: str

@classmethod
def from_provider(cls, token, provider_url):
"""
使用 token POST 请求 provider_url 获取用户信息
:param token: 用户 token
:param provider_url: provider url
:return: 用户信息
:raises HTTPError: 请求失败
:raises JSONDecodeError: 返回数据解析失败
:raises ProviderError: provider 返回错误信息
"""
response = requests.post(provider_url, data={"token": token})
response.raise_for_status()
user_data = response.json()
if user_data.get("error"):
raise ProviderError(user_data["error"])
return cls(
api_key=user_data["api_key"],
api_endpoint=user_data["api_endpoint"],
api_model=user_data["api_model"],
)


def split_setting_string(input_string):
if not isinstance(input_string, str):
return None
segments = input_string.split("$")

# 检查链接的有效性
def is_valid_url(url):
try:
result = urlparse(url)
return all([result.scheme, result.netloc])
except ValueError:
return False

# 开头为链接的情况
if is_valid_url(segments[0]) and len(segments) >= 3:
return segments[:3]
# 第二个元素为链接,第一个元素为字符串的情况
elif (
len(segments) == 2
and not is_valid_url(segments[0])
and is_valid_url(segments[1])
):
return segments
# 其他情况
else:
return None
93 changes: 93 additions & 0 deletions app/components/user_manager/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# @Time : 2024/2/8 下午10:56
# @Author : sudoskys
# @File : __init__.py.py
# @Software: PyCharm
import time
from typing import Optional

from loguru import logger
from pydantic import BaseModel

from app.components.credential import Credential
from app.const import DBNAME
from llmkira.doc_manager import global_doc_client


class ChatCost(BaseModel):
user_id: str
cost_token: int = 0
endpoint: str = ""
cost_model: str = ""
produce_time: int = time.time()


class GenerateHistory(object):
def __init__(self, db_name: str = DBNAME, collection: str = "cost_history"):
""" """
self.client = global_doc_client.update_db_collection(
db_name=db_name, collection_name=collection
)

async def save(self, history: ChatCost):
return self.client.insert_one(history.model_dump(mode="json"))


class User(BaseModel):
user_id: str
last_use_time: int = time.time()
credential: Optional[Credential] = None


class UserManager(object):
def __init__(self, db_name: str = DBNAME, collection: str = "user"):
""" """
self.client = global_doc_client.update_db_collection(
db_name=db_name, collection_name=collection
)

async def read(self, user_id: str) -> User:
user_id = str(user_id)
database_read = self.client.find_one({"user_id": user_id})
if not database_read:
logger.info(f"Create new user: {user_id}")
return User(user_id=user_id)
# database_read.update({"user_id": user_id})
return User.model_validate(database_read)

async def save(self, user_model: User):
user_model = user_model.model_copy(update={"last_use_time": int(time.time())})
# 如果存在记录则更新
if self.client.find_one({"user_id": user_model.user_id}):
return self.client.update_one(
{"user_id": user_model.user_id},
{"$set": user_model.model_dump(mode="json")},
)
# 如果不存在记录则插入
else:
return self.client.insert_one(user_model.model_dump(mode="json"))


COST_MANAGER = GenerateHistory()
USER_MANAGER = UserManager()


async def record_cost(
user_id: str, cost_token: int, endpoint: str, cost_model: str, success: bool = True
):
try:
await COST_MANAGER.save(
ChatCost(
user_id=user_id,
produce_time=int(time.time()),
endpoint=endpoint,
cost_model=cost_model,
cost_token=cost_token if success else 0,
)
)
except Exception as exc:
logger.error(f"🔥 record_cost error: {exc}")


if __name__ == "__main__":
pass
File renamed without changes.
26 changes: 14 additions & 12 deletions llmkira/middleware/llm_task.py → app/middleware/llm_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
# @Time : 2023/8/18 上午9:37
# @Author : sudoskys
# @File : llm_task.py
import os
from typing import List, Optional

from loguru import logger
from pydantic import SecretStr

from app.components.credential import Credential
from app.components.user_manager import record_cost
from llmkira.kv_manager.instruction import InstructionManager
from llmkira.memory import global_message_runtime
from llmkira.openai.cell import Tool, Message, active_cell_string, SystemMessage
Expand Down Expand Up @@ -48,14 +49,13 @@ def __init__(
self.message_history = global_message_runtime.update_session(
session_id=session_uid
)
# TODO:实现用户配置读取

async def remember(self, *, message: Optional[Message] = None):
"""
写回消息到历史消息
"""
if message:
await self.message_history.append(message=message)
await self.message_history.append(messages=[message])

async def build_message(self, remember=True):
"""
Expand Down Expand Up @@ -87,18 +87,20 @@ async def build_message(self, remember=True):
user_message = message.format_user_message()
message_run.append(user_message)
if remember:
await self.message_history.append(message=user_message)
await self.message_history.append(messages=[user_message])
return message_run

async def request_openai(
self,
remember: bool,
credential: Credential,
disable_tool: bool = False,
) -> OpenAIResult:
"""
处理消息转换和调用工具
:param remember: 是否自动写回
:param disable_tool: 禁用函数
:param credential: 凭证
:return: OpenaiResult 返回结果
:raise RuntimeError: 无法处理消息
:raise AssertionError: 无法处理消息
Expand All @@ -113,7 +115,6 @@ async def request_openai(
messages.append(SystemMessage(content=self.task.task_sign.instruction))
messages.extend(await self.build_message(remember=remember))
# TODO:实现消息时序切片

# 日志
logger.info(
f"[x] Openai request" f"\n--message {messages} " f"\n--tools {tools}"
Expand All @@ -125,21 +126,22 @@ async def request_openai(
# 根据模型选择不同的驱动a
assert messages, RuntimeError("llm_task:message cant be none...")
endpoint: OpenAI = OpenAI(
messages=messages,
tools=tools,
model="gpt-3.5-turbo", # FIXME:从用户配置中获取
messages=messages, tools=tools, model=credential.api_model
)
# 调用Openai
result: OpenAIResult = await endpoint.request(
session=OpenAICredential(
api_key=SecretStr(
os.getenv("OPENAI_API_KEY", None)
), # FIXME:从用户配置中获取
base_url=os.getenv("OPENAI_API_ENDPOINT"), # FIXME:从用户配置中获取
api_key=SecretStr(credential.api_key), base_url=credential.api_endpoint
)
)
_message = result.default_message
_usage = result.usage.total_tokens
await record_cost(
cost_model=credential.api_model,
cost_token=_usage,
endpoint=credential.api_endpoint,
user_id=self.session_uid,
)
# 写回数据库
await self.remember(message=_message)
return result
67 changes: 41 additions & 26 deletions app/receiver/receiver_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from loguru import logger
from telebot import formatting

from app.components.credential import Credential
from app.components.user_manager import USER_MANAGER
from app.middleware.llm_task import OpenaiMiddleware
from llmkira.kv_manager.env import EnvManager
from llmkira.middleware.llm_task import OpenaiMiddleware
from llmkira.openai import OpenaiError
from llmkira.openai.cell import ToolCall, Message, Tool
from llmkira.openai.request import OpenAIResult
Expand All @@ -31,6 +33,11 @@
from llmkira.task.snapshot import global_snapshot_storage


async def read_user_credential(user_id: str) -> Optional[Credential]:
user = await USER_MANAGER.read(user_id=user_id)
return user.credential


async def generate_authorization(
secrets: Dict, tool_invocation: ToolCall
) -> Tuple[dict, list, bool]:
Expand Down Expand Up @@ -147,7 +154,7 @@ async def forward(self, receiver: Location, message: list):

@abstractmethod
async def reply(
self, receiver: Location, message: Message, reply_to_message: bool = True
self, receiver: Location, messages: List[Message], reply_to_message: bool = True
):
"""
模型直转发,Message是Openai的类型
Expand Down Expand Up @@ -232,21 +239,28 @@ async def _flash(
"""
try:
try:
credentials = await read_user_credential(user_id=task.receiver.uid)
assert credentials, "You need to /login first"
llm_result = await llm.request_openai(
remember=remember,
disable_tool=disable_tool,
credential=credentials,
)
assistant_message = llm_result.default_message
logger.debug(f"Assistant:{assistant_message}")
except OpenaiError as exc:
await self.sender.error(receiver=task.receiver, text=exc.message)
return exc
except (RuntimeError, AssertionError) as exc:
except RuntimeError as exc:
logger.exception(exc)
await self.sender.error(
receiver=task.receiver,
text="Can't get message validate from your history",
)
return exc
except AssertionError as exc:
await self.sender.error(receiver=task.receiver, text=str(exc))
return exc
except Exception as exc:
logger.exception(exc)
await self.sender.error(
Expand All @@ -269,7 +283,7 @@ async def _flash(
)
return logger.debug("Function loop ended")
return await self.sender.reply(
receiver=task.receiver, message=assistant_message
receiver=task.receiver, messages=[assistant_message]
)
except Exception as e:
raise e
Expand Down Expand Up @@ -364,30 +378,31 @@ async def on_message(self, message: AbstractIncomingMessage):
snap_data = await global_snapshot_storage.read(
user_id=task_head.receiver.uid
)
data = snap_data.data
renew_snap_data = []
for task in data:
if not task.snapshot_credential and not task.processed:
try:
await Task.create_and_send(
queue_name=task.channel, task=task.snapshot_data
)
except Exception as e:
logger.exception(f"Response to snapshot error {e}")
if snap_data is not None:
data = snap_data.data
renew_snap_data = []
for task in data:
if not task.snapshot_credential and not task.processed:
try:
await Task.create_and_send(
queue_name=task.channel, task=task.snapshot_data
)
except Exception as e:
logger.exception(f"Response to snapshot error {e}")
else:
logger.info(
f"🧀 Response to snapshot {task.snap_uuid} at {router}"
)
finally:
task.processed_at = int(time.time())
renew_snap_data.append(task)
else:
logger.info(
f"🧀 Response to snapshot {task.snap_uuid} at {router}"
)
finally:
task.processed_at = int(time.time())
task.processed_at = None
renew_snap_data.append(task)
else:
task.processed_at = None
renew_snap_data.append(task)
snap_data.data = renew_snap_data
await global_snapshot_storage.write(
user_id=task_head.receiver.uid, snapshot=snap_data
)
snap_data.data = renew_snap_data
await global_snapshot_storage.write(
user_id=task_head.receiver.uid, snapshot=snap_data
)
except Exception as e:
logger.exception(e)
await message.reject(requeue=False)
Expand Down
2 changes: 1 addition & 1 deletion app/receiver/telegram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from app.receiver.receiver_client import BaseReceiver, BaseSender
from app.setting.telegram import BotSetting
from llmkira.kv_manager.file import File
from llmkira.middleware.llm_task import OpenaiMiddleware
from app.middleware.llm_task import OpenaiMiddleware
from llmkira.openai.cell import Message
from llmkira.openai.request import OpenAIResult
from llmkira.task import Task, TaskHeader
Expand Down
Loading

0 comments on commit 0f93169

Please sign in to comment.