Skip to content

Commit

Permalink
Merge pull request #372 from LlmKira/dev
Browse files Browse the repository at this point in the history
Remove strict validate and make claude model available
  • Loading branch information
sudoskys authored Apr 13, 2024
2 parents c594aa2 + 345e57e commit 3c3709c
Show file tree
Hide file tree
Showing 14 changed files with 1,503 additions and 1,789 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/python_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: [ '3.8', '3.9', '3.10', '3.11' ]
python-version: [ '3.9', '3.10', '3.11' ]
os: [ ubuntu-latest ] # , windows-latest, macos-latest ]

steps:
Expand All @@ -26,7 +26,7 @@ jobs:

- name: Install dependencies
run: |
pdm install --no-lock -G testing
pdm install --frozen-lockfile -G bot
- name: Run Tests
run: |
pdm run -v pytest tests
pdm run -v pytest tests
48 changes: 15 additions & 33 deletions llmkira/extra/user/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from llmkira.sdk.func_calling import ToolRegister

from .client import UserCostClient, UserConfigClient, UserCost, UserConfig
from .schema import UserDriverMode
from ...sdk.endpoint import Driver


Expand All @@ -23,7 +22,6 @@ def is_valid_url(url):


class CostControl(object):

@staticmethod
async def add_cost(cost: UserCost):
"""
Expand All @@ -44,9 +42,7 @@ def get_model():
return SCHEMA_GROUP.get_model_list()

@staticmethod
async def get_driver_config(
uid: str
) -> UserConfig.LlmConfig:
async def get_driver_config(uid: str) -> UserConfig.LlmConfig:
"""
:param uid: user id
:return: UserConfig.LlmConfig(Token/Driver)
Expand All @@ -57,10 +53,7 @@ async def get_driver_config(
return _user_data.llm_driver

@staticmethod
def uid_make(
platform: str,
user_id: Union[str, int]
):
def uid_make(platform: str, user_id: Union[str, int]):
"""
:param platform: platform.
:param user_id: user id.
Expand All @@ -74,9 +67,7 @@ def uid_make(
return f"{platform}:{user_id}"

@staticmethod
async def clear_endpoint(
uid: str
):
async def clear_endpoint(uid: str):
"""
:param uid: user id
:return: bool
Expand All @@ -90,11 +81,11 @@ async def clear_endpoint(

@staticmethod
async def set_endpoint(
uid: str,
api_key: str,
endpoint: str = None,
model: str = None,
org_id: str = None
uid: str,
api_key: str,
endpoint: str = None,
model: str = None,
org_id: str = None,
) -> Driver:
"""
:param uid: user id
Expand All @@ -106,26 +97,23 @@ async def set_endpoint(
:return: new_driver
"""
# assert model in MODEL.__args__, f"openai model is not valid,must be one of {MODEL.__args__}"
if model not in UserControl.get_model():
model = UserControl.get_model()[0]
_user_data = await UserConfigClient().read_by_uid(uid=uid)
_user_data = _user_data or UserConfig(uid=uid)
new_driver = Driver(endpoint=endpoint, api_key=api_key, model=model, org_id=org_id)
new_driver = Driver(
endpoint=endpoint, api_key=api_key, model=model, org_id=org_id
)
_user_data.llm_driver.driver = new_driver
await UserConfigClient().update(uid=uid, data=_user_data)
return new_driver

@staticmethod
async def block_plugin(
uid: str,
plugin_name: str
) -> list:
async def block_plugin(uid: str, plugin_name: str) -> list:
"""
:param uid: user id
:param plugin_name: plugin name
:return: list
"""
if not (plugin_name in ToolRegister().functions):
if plugin_name not in ToolRegister().functions:
raise ValueError(f"plugin {plugin_name} is not exist :(")
_user_data = await UserConfigClient().read_by_uid(uid=uid)
_user_data = _user_data or UserConfig(uid=uid)
Expand All @@ -134,10 +122,7 @@ async def block_plugin(
return _user_data.plugin_subs.block_list

@staticmethod
async def unblock_plugin(
uid: str,
plugin_name: str
) -> list:
async def unblock_plugin(uid: str, plugin_name: str) -> list:
"""
:param uid: user id
:param plugin_name: plugin name
Expand All @@ -150,10 +135,7 @@ async def unblock_plugin(
return _user_data.plugin_subs.block_list

@staticmethod
async def set_token(
uid: str,
token: Optional[str] = None
) -> Optional[str]:
async def set_token(uid: str, token: Optional[str] = None) -> Optional[str]:
"""
:param uid: user id
:param token: bind token
Expand Down
98 changes: 54 additions & 44 deletions llmkira/extra/user/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,44 +25,51 @@ class UserDriverMode(Enum):
"""代理公共环境变量,也就是额外的token计费系统控制了公共环境变量的使用"""


class Cost(BaseModel):
"""消费记录细节"""

cost_by: str = Field("chat", description="环节")
token_usage: int = Field(0)
token_uuid: Optional[str] = Field(None, description="Api Key 的 hash")
llm_model: Optional[str] = Field(None, description="Model Name")
provide_type: int = Field(None, description="认证模式")

@classmethod
def by_function(
cls,
function_name: str,
token_usage: int,
token_uuid: str,
llm_model: str,
):
return cls(
cost_by=function_name,
token_usage=token_usage,
token_uuid=token_uuid,
model_name=llm_model,
)


# 基本单元

class UserCost(BaseModel):
"""用户消费记录
"""

class Cost(BaseModel):
"""消费记录细节
"""
cost_by: str = Field("chat", description="环节")
token_usage: int = Field(0)
token_uuid: Optional[str] = Field(None, description="Api Key 的 hash")
llm_model: Optional[str] = Field(None, description="Model Name")
provide_type: int = Field(None, description="认证模式")
class UserCost(BaseModel):
"""用户消费记录"""

@classmethod
def by_function(cls, function_name: str,
token_usage: int,
token_uuid: str,
llm_model: str,
):
return cls(cost_by=function_name, token_usage=token_usage, token_uuid=token_uuid, model_name=llm_model)

request_id: str = Field(default=None, description="请求 UUID")
uid: str = Field(default=None, description="用户 UID ,注意是平台+用户")
cost: Cost = Field(default=None, description="消费记录")
cost_time: int = Field(default=None, description="消费时间")
meta: dict = Field(default={}, description="元数据")

@classmethod
def create_from_function(
cls,
uid: str,
request_id: str,
cost_by: str,
token_usage: int,
token_uuid: str,
model_name: str,
cls,
uid: str,
request_id: str,
cost_by: str,
token_usage: int,
token_uuid: str,
model_name: str,
):
return cls(
request_id=request_id,
Expand All @@ -78,23 +85,22 @@ def create_from_function(

@classmethod
def create_from_task(
cls,
uid: str,
request_id: str,
cost: Cost,
cls,
uid: str,
cost: Cost,
):
return cls(
request_id=request_id,
uid=uid,
cost=cost,
cost_time=int(time.time()),
)

model_config = ConfigDict(extra="ignore",
arbitrary_types_allowed=True,
validate_assignment=True,
validate_default=True
)
model_config = ConfigDict(
extra="ignore",
arbitrary_types_allowed=True,
validate_assignment=True,
validate_default=True,
)


class UserConfig(BaseSettings):
Expand All @@ -107,6 +113,7 @@ class LlmConfig(BaseModel):
driver 作为一个单例模式
其他 `公共授权` 组件!
"""

driver: Optional[Driver] = Field(None, description="私有端点配置")
token: Optional[str] = Field(None, description="代理认证系统的token")
provider: Optional[str] = Field(None, description="认证平台")
Expand Down Expand Up @@ -158,7 +165,9 @@ def unblock(self, plugin_name: str) -> "UserConfig.PluginConfig":
created_time: int = Field(default=int(time.time()), description="创建时间")
last_use_time: int = Field(default=int(time.time()), description="最后使用时间")
uid: Union[str, int] = Field(None, description="用户UID")
plugin_subs: PluginConfig = Field(default_factory=PluginConfig.default, description="插件订阅")
plugin_subs: PluginConfig = Field(
default_factory=PluginConfig.default, description="插件订阅"
)
llm_driver: LlmConfig = Field(default_factory=LlmConfig.default, description="驱动")

@field_validator("uid")
Expand All @@ -167,9 +176,10 @@ def check_user_id(cls, v):
return str(v)
return v

model_config = SettingsConfigDict(extra="ignore",
frozen=True,
arbitrary_types_allowed=True,
validate_assignment=True,
validate_default=True
)
model_config = SettingsConfigDict(
extra="ignore",
frozen=True,
arbitrary_types_allowed=True,
validate_assignment=True,
validate_default=True,
)
35 changes: 20 additions & 15 deletions llmkira/middleware/llm_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

from .llm_provider import GetAuthDriver
from ..extra.user import UserCost, CostControl
from ..extra.user.schema import Cost
from ..schema import RawMessage, Scraper
from ..sdk.adapter import SCHEMA_GROUP
from ..sdk.adapter import SCHEMA_GROUP, SingleModel
from ..sdk.endpoint import Driver
from ..sdk.endpoint.openai import Message
from ..sdk.endpoint.openai import Message, Openai, OpenaiResult
from ..sdk.endpoint.schema import LlmRequest, LlmResult
from ..sdk.memory.redis import RedisChatMessageHistory
from ..sdk.schema import ToolCallCompletion, SystemMessage, Function, Tool
Expand All @@ -23,6 +24,7 @@


class SystemPrompt(BaseModel):
# FIXME Deprecated
"""
系统提示
"""
Expand Down Expand Up @@ -153,6 +155,7 @@ def _append_function_tools(self, functions: List[Function]):
return self.tools

def scraper_create_message(self, write_back=True, system_prompt=True):
# FIXME Deprecated
"""
从人类消息和历史消息中构建请求所用消息
:param write_back: 是否写回,如果是 False,那么就不会写回到 Redis 数据库中,也就是重新请求
Expand Down Expand Up @@ -188,8 +191,6 @@ def scraper_create_message(self, write_back=True, system_prompt=True):
role="user",
)
_buffer.append(user_message)
# 装样子添加评分
# TODO 评分机制
for i, _msg in enumerate(_buffer):
self.scraper.add_message(_msg, score=len(str(_msg)) + 50)
# database:save redis
Expand All @@ -199,20 +200,29 @@ def scraper_create_message(self, write_back=True, system_prompt=True):
async def request_openai(
self,
auto_write_back: bool,
retrieve_mode: bool = False,
disable_function: bool = False,
) -> LlmResult:
"""
处理消息转换和调用工具
:param auto_write_back: 是否自动写回
:param disable_function: 禁用函数
:param retrieve_mode: 是否为检索模式,当我们需要重新处理超长消息时候,需要设定为 True
:return: OpenaiResult
"""
run_driver_model = (
self.driver.model if not retrieve_mode else self.driver.retrieve_model
)
run_driver_model = self.driver.model
endpoint_schema = self.get_schema(model_name=run_driver_model)
if not endpoint_schema:
logger.warning(
f"Openai model {run_driver_model} not found, use {run_driver_model} instead"
)
endpoint_schema = SingleModel(
llm_model=run_driver_model,
token_limit=8192,
request=Openai,
response=OpenaiResult,
schema_type="openai",
func_executor="tool_call",
exception=None,
)
# 添加函数定义的系统提示
if not disable_function:
for function_item in self.functions:
Expand All @@ -232,9 +242,6 @@ async def request_openai(
]
# 构建消息列表
self.scraper_create_message(write_back=auto_write_back)
# 折叠消息列表
if retrieve_mode:
self.scraper.fold_message()
# 削减消息列表
self.scraper.reduce_messages(
limit=endpoint_schema.token_limit, model_name=run_driver_model
Expand All @@ -253,7 +260,6 @@ async def request_openai(
f"\n--org {self.driver.org_id} "
f"\n--model {run_driver_model} "
f"\n--function {functions}"
f"\n--retrieve_mode {retrieve_mode}"
)
# 禁用函数?
if disable_function or not functions:
Expand Down Expand Up @@ -284,8 +290,7 @@ async def request_openai(
await CostControl.add_cost(
cost=UserCost.create_from_task(
uid=self.session_uid,
request_id=result.id,
cost=UserCost.Cost(
cost=Cost(
cost_by=self.task.receiver.platform,
token_usage=_usage,
token_uuid=self.driver.uuid,
Expand Down
Loading

0 comments on commit 3c3709c

Please sign in to comment.