Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove strict validate and make claude model available #372

Merged
merged 11 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 15 additions & 33 deletions llmkira/extra/user/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from llmkira.sdk.func_calling import ToolRegister

from .client import UserCostClient, UserConfigClient, UserCost, UserConfig
from .schema import UserDriverMode
from ...sdk.endpoint import Driver


Expand All @@ -23,7 +22,6 @@ def is_valid_url(url):


class CostControl(object):

@staticmethod
async def add_cost(cost: UserCost):
"""
Expand All @@ -44,9 +42,7 @@ def get_model():
return SCHEMA_GROUP.get_model_list()

@staticmethod
async def get_driver_config(
uid: str
) -> UserConfig.LlmConfig:
async def get_driver_config(uid: str) -> UserConfig.LlmConfig:
"""
:param uid: user id
:return: UserConfig.LlmConfig(Token/Driver)
Expand All @@ -57,10 +53,7 @@ async def get_driver_config(
return _user_data.llm_driver

@staticmethod
def uid_make(
platform: str,
user_id: Union[str, int]
):
def uid_make(platform: str, user_id: Union[str, int]):
"""
:param platform: platform.
:param user_id: user id.
Expand All @@ -74,9 +67,7 @@ def uid_make(
return f"{platform}:{user_id}"

@staticmethod
async def clear_endpoint(
uid: str
):
async def clear_endpoint(uid: str):
"""
:param uid: user id
:return: bool
Expand All @@ -90,11 +81,11 @@ async def clear_endpoint(

@staticmethod
async def set_endpoint(
uid: str,
api_key: str,
endpoint: str = None,
model: str = None,
org_id: str = None
uid: str,
api_key: str,
endpoint: str = None,
model: str = None,
org_id: str = None,
) -> Driver:
"""
:param uid: user id
Expand All @@ -106,26 +97,23 @@ async def set_endpoint(
:return: new_driver
"""
# assert model in MODEL.__args__, f"openai model is not valid,must be one of {MODEL.__args__}"
if model not in UserControl.get_model():
model = UserControl.get_model()[0]
_user_data = await UserConfigClient().read_by_uid(uid=uid)
_user_data = _user_data or UserConfig(uid=uid)
new_driver = Driver(endpoint=endpoint, api_key=api_key, model=model, org_id=org_id)
new_driver = Driver(
endpoint=endpoint, api_key=api_key, model=model, org_id=org_id
)
_user_data.llm_driver.driver = new_driver
await UserConfigClient().update(uid=uid, data=_user_data)
return new_driver

@staticmethod
async def block_plugin(
uid: str,
plugin_name: str
) -> list:
async def block_plugin(uid: str, plugin_name: str) -> list:
"""
:param uid: user id
:param plugin_name: plugin name
:return: list
"""
if not (plugin_name in ToolRegister().functions):
if plugin_name not in ToolRegister().functions:
raise ValueError(f"plugin {plugin_name} is not exist :(")
_user_data = await UserConfigClient().read_by_uid(uid=uid)
_user_data = _user_data or UserConfig(uid=uid)
Expand All @@ -134,10 +122,7 @@ async def block_plugin(
return _user_data.plugin_subs.block_list

@staticmethod
async def unblock_plugin(
uid: str,
plugin_name: str
) -> list:
async def unblock_plugin(uid: str, plugin_name: str) -> list:
"""
:param uid: user id
:param plugin_name: plugin name
Expand All @@ -150,10 +135,7 @@ async def unblock_plugin(
return _user_data.plugin_subs.block_list

@staticmethod
async def set_token(
uid: str,
token: Optional[str] = None
) -> Optional[str]:
async def set_token(uid: str, token: Optional[str] = None) -> Optional[str]:
"""
:param uid: user id
:param token: bind token
Expand Down
98 changes: 54 additions & 44 deletions llmkira/extra/user/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,44 +25,51 @@ class UserDriverMode(Enum):
"""代理公共环境变量,也就是额外的token计费系统控制了公共环境变量的使用"""


class Cost(BaseModel):
"""消费记录细节"""

cost_by: str = Field("chat", description="环节")
token_usage: int = Field(0)
token_uuid: Optional[str] = Field(None, description="Api Key 的 hash")
llm_model: Optional[str] = Field(None, description="Model Name")
provide_type: int = Field(None, description="认证模式")

@classmethod
def by_function(
cls,
function_name: str,
token_usage: int,
token_uuid: str,
llm_model: str,
):
return cls(
cost_by=function_name,
token_usage=token_usage,
token_uuid=token_uuid,
model_name=llm_model,
)


# 基本单元

class UserCost(BaseModel):
"""用户消费记录
"""

class Cost(BaseModel):
"""消费记录细节
"""
cost_by: str = Field("chat", description="环节")
token_usage: int = Field(0)
token_uuid: Optional[str] = Field(None, description="Api Key 的 hash")
llm_model: Optional[str] = Field(None, description="Model Name")
provide_type: int = Field(None, description="认证模式")
class UserCost(BaseModel):
"""用户消费记录"""

@classmethod
def by_function(cls, function_name: str,
token_usage: int,
token_uuid: str,
llm_model: str,
):
return cls(cost_by=function_name, token_usage=token_usage, token_uuid=token_uuid, model_name=llm_model)

request_id: str = Field(default=None, description="请求 UUID")
uid: str = Field(default=None, description="用户 UID ,注意是平台+用户")
cost: Cost = Field(default=None, description="消费记录")
cost_time: int = Field(default=None, description="消费时间")
meta: dict = Field(default={}, description="元数据")

@classmethod
def create_from_function(
cls,
uid: str,
request_id: str,
cost_by: str,
token_usage: int,
token_uuid: str,
model_name: str,
cls,
uid: str,
request_id: str,
cost_by: str,
token_usage: int,
token_uuid: str,
model_name: str,
):
return cls(
request_id=request_id,
Expand All @@ -78,23 +85,22 @@ def create_from_function(

@classmethod
def create_from_task(
cls,
uid: str,
request_id: str,
cost: Cost,
cls,
uid: str,
cost: Cost,
):
return cls(
request_id=request_id,
uid=uid,
cost=cost,
cost_time=int(time.time()),
)

model_config = ConfigDict(extra="ignore",
arbitrary_types_allowed=True,
validate_assignment=True,
validate_default=True
)
model_config = ConfigDict(
extra="ignore",
arbitrary_types_allowed=True,
validate_assignment=True,
validate_default=True,
)


class UserConfig(BaseSettings):
Expand All @@ -107,6 +113,7 @@ class LlmConfig(BaseModel):
driver 作为一个单例模式
其他 `公共授权` 组件!
"""

driver: Optional[Driver] = Field(None, description="私有端点配置")
token: Optional[str] = Field(None, description="代理认证系统的token")
provider: Optional[str] = Field(None, description="认证平台")
Expand Down Expand Up @@ -158,7 +165,9 @@ def unblock(self, plugin_name: str) -> "UserConfig.PluginConfig":
created_time: int = Field(default=int(time.time()), description="创建时间")
last_use_time: int = Field(default=int(time.time()), description="最后使用时间")
uid: Union[str, int] = Field(None, description="用户UID")
plugin_subs: PluginConfig = Field(default_factory=PluginConfig.default, description="插件订阅")
plugin_subs: PluginConfig = Field(
default_factory=PluginConfig.default, description="插件订阅"
)
llm_driver: LlmConfig = Field(default_factory=LlmConfig.default, description="驱动")

@field_validator("uid")
Expand All @@ -167,9 +176,10 @@ def check_user_id(cls, v):
return str(v)
return v

model_config = SettingsConfigDict(extra="ignore",
frozen=True,
arbitrary_types_allowed=True,
validate_assignment=True,
validate_default=True
)
model_config = SettingsConfigDict(
extra="ignore",
frozen=True,
arbitrary_types_allowed=True,
validate_assignment=True,
validate_default=True,
)
35 changes: 20 additions & 15 deletions llmkira/middleware/llm_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

from .llm_provider import GetAuthDriver
from ..extra.user import UserCost, CostControl
from ..extra.user.schema import Cost
from ..schema import RawMessage, Scraper
from ..sdk.adapter import SCHEMA_GROUP
from ..sdk.adapter import SCHEMA_GROUP, SingleModel
from ..sdk.endpoint import Driver
from ..sdk.endpoint.openai import Message
from ..sdk.endpoint.openai import Message, Openai, OpenaiResult
from ..sdk.endpoint.schema import LlmRequest, LlmResult
from ..sdk.memory.redis import RedisChatMessageHistory
from ..sdk.schema import ToolCallCompletion, SystemMessage, Function, Tool
Expand All @@ -23,6 +24,7 @@


class SystemPrompt(BaseModel):
# FIXME Deprecated
"""
系统提示
"""
Expand Down Expand Up @@ -153,6 +155,7 @@ def _append_function_tools(self, functions: List[Function]):
return self.tools

def scraper_create_message(self, write_back=True, system_prompt=True):
# FIXME Deprecated
"""
从人类消息和历史消息中构建请求所用消息
:param write_back: 是否写回,如果是 False,那么就不会写回到 Redis 数据库中,也就是重新请求
Expand Down Expand Up @@ -188,8 +191,6 @@ def scraper_create_message(self, write_back=True, system_prompt=True):
role="user",
)
_buffer.append(user_message)
# 装样子添加评分
# TODO 评分机制
for i, _msg in enumerate(_buffer):
self.scraper.add_message(_msg, score=len(str(_msg)) + 50)
# database:save redis
Expand All @@ -199,20 +200,29 @@ def scraper_create_message(self, write_back=True, system_prompt=True):
async def request_openai(
self,
auto_write_back: bool,
retrieve_mode: bool = False,
disable_function: bool = False,
) -> LlmResult:
"""
处理消息转换和调用工具
:param auto_write_back: 是否自动写回
:param disable_function: 禁用函数
:param retrieve_mode: 是否为检索模式,当我们需要重新处理超长消息时候,需要设定为 True
:return: OpenaiResult
"""
run_driver_model = (
self.driver.model if not retrieve_mode else self.driver.retrieve_model
)
run_driver_model = self.driver.model
endpoint_schema = self.get_schema(model_name=run_driver_model)
if not endpoint_schema:
logger.warning(
f"Openai model {run_driver_model} not found, use {run_driver_model} instead"
)
endpoint_schema = SingleModel(
llm_model=run_driver_model,
token_limit=8192,
request=Openai,
response=OpenaiResult,
schema_type="openai",
func_executor="tool_call",
exception=None,
)
# 添加函数定义的系统提示
if not disable_function:
for function_item in self.functions:
Expand All @@ -232,9 +242,6 @@ async def request_openai(
]
# 构建消息列表
self.scraper_create_message(write_back=auto_write_back)
# 折叠消息列表
if retrieve_mode:
self.scraper.fold_message()
# 削减消息列表
self.scraper.reduce_messages(
limit=endpoint_schema.token_limit, model_name=run_driver_model
Expand All @@ -253,7 +260,6 @@ async def request_openai(
f"\n--org {self.driver.org_id} "
f"\n--model {run_driver_model} "
f"\n--function {functions}"
f"\n--retrieve_mode {retrieve_mode}"
)
# 禁用函数?
if disable_function or not functions:
Expand Down Expand Up @@ -284,8 +290,7 @@ async def request_openai(
await CostControl.add_cost(
cost=UserCost.create_from_task(
uid=self.session_uid,
request_id=result.id,
cost=UserCost.Cost(
cost=Cost(
cost_by=self.task.receiver.platform,
token_usage=_usage,
token_uuid=self.driver.uuid,
Expand Down
Loading
Loading