-
Notifications
You must be signed in to change notification settings - Fork 229
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
307 additions
and
77 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from urllib.parse import urlparse | ||
|
||
import requests | ||
from pydantic import BaseModel | ||
|
||
|
||
class ProviderError(Exception): | ||
pass | ||
|
||
|
||
class Credential(BaseModel): | ||
api_key: str | ||
api_endpoint: str | ||
api_model: str | ||
|
||
@classmethod | ||
def from_provider(cls, token, provider_url): | ||
""" | ||
使用 token POST 请求 provider_url 获取用户信息 | ||
:param token: 用户 token | ||
:param provider_url: provider url | ||
:return: 用户信息 | ||
:raises HTTPError: 请求失败 | ||
:raises JSONDecodeError: 返回数据解析失败 | ||
:raises ProviderError: provider 返回错误信息 | ||
""" | ||
response = requests.post(provider_url, data={"token": token}) | ||
response.raise_for_status() | ||
user_data = response.json() | ||
if user_data.get("error"): | ||
raise ProviderError(user_data["error"]) | ||
return cls( | ||
api_key=user_data["api_key"], | ||
api_endpoint=user_data["api_endpoint"], | ||
api_model=user_data["api_model"], | ||
) | ||
|
||
|
||
def split_setting_string(input_string): | ||
if not isinstance(input_string, str): | ||
return None | ||
segments = input_string.split("$") | ||
|
||
# 检查链接的有效性 | ||
def is_valid_url(url): | ||
try: | ||
result = urlparse(url) | ||
return all([result.scheme, result.netloc]) | ||
except ValueError: | ||
return False | ||
|
||
# 开头为链接的情况 | ||
if is_valid_url(segments[0]) and len(segments) >= 3: | ||
return segments[:3] | ||
# 第二个元素为链接,第一个元素为字符串的情况 | ||
elif ( | ||
len(segments) == 2 | ||
and not is_valid_url(segments[0]) | ||
and is_valid_url(segments[1]) | ||
): | ||
return segments | ||
# 其他情况 | ||
else: | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# -*- coding: utf-8 -*- | ||
# @Time : 2024/2/8 下午10:56 | ||
# @Author : sudoskys | ||
# @File : __init__.py.py | ||
# @Software: PyCharm | ||
import time | ||
from typing import Optional | ||
|
||
from loguru import logger | ||
from pydantic import BaseModel | ||
|
||
from app.components.credential import Credential | ||
from app.const import DBNAME | ||
from llmkira.doc_manager import global_doc_client | ||
|
||
|
||
class ChatCost(BaseModel): | ||
user_id: str | ||
cost_token: int = 0 | ||
endpoint: str = "" | ||
cost_model: str = "" | ||
produce_time: int = time.time() | ||
|
||
|
||
class GenerateHistory(object): | ||
def __init__(self, db_name: str = DBNAME, collection: str = "cost_history"): | ||
""" """ | ||
self.client = global_doc_client.update_db_collection( | ||
db_name=db_name, collection_name=collection | ||
) | ||
|
||
async def save(self, history: ChatCost): | ||
return self.client.insert_one(history.model_dump(mode="json")) | ||
|
||
|
||
class User(BaseModel): | ||
user_id: str | ||
last_use_time: int = time.time() | ||
credential: Optional[Credential] = None | ||
|
||
|
||
class UserManager(object): | ||
def __init__(self, db_name: str = DBNAME, collection: str = "user"): | ||
""" """ | ||
self.client = global_doc_client.update_db_collection( | ||
db_name=db_name, collection_name=collection | ||
) | ||
|
||
async def read(self, user_id: str) -> User: | ||
user_id = str(user_id) | ||
database_read = self.client.find_one({"user_id": user_id}) | ||
if not database_read: | ||
logger.info(f"Create new user: {user_id}") | ||
return User(user_id=user_id) | ||
# database_read.update({"user_id": user_id}) | ||
return User.model_validate(database_read) | ||
|
||
async def save(self, user_model: User): | ||
user_model = user_model.model_copy(update={"last_use_time": int(time.time())}) | ||
# 如果存在记录则更新 | ||
if self.client.find_one({"user_id": user_model.user_id}): | ||
return self.client.update_one( | ||
{"user_id": user_model.user_id}, | ||
{"$set": user_model.model_dump(mode="json")}, | ||
) | ||
# 如果不存在记录则插入 | ||
else: | ||
return self.client.insert_one(user_model.model_dump(mode="json")) | ||
|
||
|
||
COST_MANAGER = GenerateHistory() | ||
USER_MANAGER = UserManager() | ||
|
||
|
||
async def record_cost( | ||
user_id: str, cost_token: int, endpoint: str, cost_model: str, success: bool = True | ||
): | ||
try: | ||
await COST_MANAGER.save( | ||
ChatCost( | ||
user_id=user_id, | ||
produce_time=int(time.time()), | ||
endpoint=endpoint, | ||
cost_model=cost_model, | ||
cost_token=cost_token if success else 0, | ||
) | ||
) | ||
except Exception as exc: | ||
logger.error(f"🔥 record_cost error: {exc}") | ||
|
||
|
||
if __name__ == "__main__": | ||
pass |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.