Skip to content

Commit

Permalink
doc: 添加部分注释
Browse files Browse the repository at this point in the history
  • Loading branch information
RockChinQ committed Mar 5, 2023
1 parent e4b581f commit 651b291
Show file tree
Hide file tree
Showing 13 changed files with 172 additions and 70 deletions.
3 changes: 3 additions & 0 deletions pkg/audit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
审计相关操作
"""
26 changes: 23 additions & 3 deletions pkg/audit/gatherer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
使用量统计以及数据上报功能实现
"""

import hashlib
import json
import logging
Expand All @@ -10,8 +14,11 @@

class DataGatherer:
"""数据收集器"""

usage = {}
"""以key值md5为key,{
"""各api-key的使用量
以key值md5为key,{
"text": {
"text-davinci-003": 文字量:int,
},
Expand All @@ -25,11 +32,16 @@ class DataGatherer:
def __init__(self):
self.load_from_db()
try:
self.version_str = pkg.utils.updater.get_current_tag()
self.version_str = pkg.utils.updater.get_current_tag() # 从updater模块获取版本号
except:
pass

def report_to_server(self, subservice_name: str, count: int):
"""向中央服务器报告使用量
只会报告此次请求的使用量,不会报告总量。
不包含除版本号、使用类型、使用量以外的任何信息,仅供开发者分析使用情况。
"""
try:
config = pkg.utils.context.get_config()
if hasattr(config, "report_usage") and not config.report_usage:
Expand All @@ -44,7 +56,9 @@ def get_usage(self, key_md5):
return self.usage[key_md5] if key_md5 in self.usage else {}

def report_text_model_usage(self, model, total_tokens):
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()
"""调用方报告文字模型请求文字使用量"""

key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存

if key_md5 not in self.usage:
self.usage[key_md5] = {}
Expand All @@ -62,6 +76,8 @@ def report_text_model_usage(self, model, total_tokens):
self.report_to_server("text", length)

def report_image_model_usage(self, size):
"""调用方报告图片模型请求图片使用量"""

key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()

if key_md5 not in self.usage:
Expand All @@ -79,6 +95,7 @@ def report_image_model_usage(self, size):
self.report_to_server("image", 1)

def get_text_length_of_key(self, key):
"""获取指定api-key (明文) 的文字总使用量(本地记录)"""
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
if key_md5 not in self.usage:
return 0
Expand All @@ -88,6 +105,8 @@ def get_text_length_of_key(self, key):
return sum(self.usage[key_md5]["text"].values())

def get_image_count_of_key(self, key):
"""获取指定api-key (明文) 的图片总使用量(本地记录)"""

key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
if key_md5 not in self.usage:
return 0
Expand All @@ -97,6 +116,7 @@ def get_image_count_of_key(self, key):
return sum(self.usage[key_md5]["image"].values())

def get_total_text_length(self):
"""获取所有api-key的文字总使用量(本地记录)"""
total = 0
for key in self.usage:
if "text" not in self.usage[key]:
Expand Down
3 changes: 3 additions & 0 deletions pkg/database/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
数据库操作封装
"""
61 changes: 35 additions & 26 deletions pkg/database/manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
数据库管理模块
"""
import hashlib
import json
import logging
Expand All @@ -9,9 +12,9 @@
import pkg.utils.context


# 数据库管理
# 为其他模块提供数据库操作接口
class DatabaseManager:
"""封装数据库底层操作,并提供方法给上层使用"""

conn = None
cursor = None

Expand All @@ -23,21 +26,24 @@ def __init__(self):

# 连接到数据库文件
def reconnect(self):
"""连接到数据库"""
self.conn = sqlite3.connect('database.db', check_same_thread=False)
self.cursor = self.conn.cursor()

def close(self):
self.conn.close()

def execute(self, *args, **kwargs) -> Cursor:
def __execute__(self, *args, **kwargs) -> Cursor:
# logging.debug('SQL: {}'.format(sql))
c = self.cursor.execute(*args, **kwargs)
self.conn.commit()
return c

# 初始化数据库的函数
def initialize_database(self):
self.execute("""
"""创建数据表"""

self.__execute__("""
create table if not exists `sessions` (
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`name` varchar(255) not null,
Expand All @@ -50,7 +56,7 @@ def initialize_database(self):
)
""")

self.execute("""
self.__execute__("""
create table if not exists `account_fee`(
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`key_md5` varchar(255) not null,
Expand All @@ -59,7 +65,7 @@ def initialize_database(self):
)
""")

self.execute("""
self.__execute__("""
create table if not exists `account_usage`(
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`json` text not null
Expand All @@ -70,10 +76,12 @@ def initialize_database(self):
# session持久化
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str):
"""持久化指定session"""

# 检查是否已经有了此name和create_timestamp的session
# 如果有,就更新prompt和last_interact_timestamp
# 如果没有,就插入一条新的记录
self.execute("""
self.__execute__("""
select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {}
""".format(subject_type, subject_number, create_timestamp))
count = self.cursor.fetchone()[0]
Expand All @@ -84,40 +92,40 @@ def persistence_session(self, subject_type: str, subject_number: int, create_tim
values (?, ?, ?, ?, ?, ?)
"""

self.execute(sql,
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
self.__execute__(sql,
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt))
else:
sql = """
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?
where `type` = ? and `number` = ? and `create_timestamp` = ?
"""

self.execute(sql, (last_interact_timestamp, prompt, subject_type,
subject_number, create_timestamp))
self.__execute__(sql, (last_interact_timestamp, prompt, subject_type,
subject_number, create_timestamp))

# 显式关闭一个session
def explicit_close_session(self, session_name: str, create_timestamp: int):
self.execute("""
self.__execute__("""
update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp))

def set_session_ongoing(self, session_name: str, create_timestamp: int):
self.execute("""
self.__execute__("""
update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp))

# 设置session为过期
def set_session_expired(self, session_name: str, create_timestamp: int):
self.execute("""
self.__execute__("""
update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp))

# 从数据库加载还没过期的session数据
def load_valid_sessions(self) -> dict:
# 从数据库中加载所有还没过期的session
config = pkg.utils.context.get_config()
self.execute("""
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config.session_expire_time))
Expand Down Expand Up @@ -150,7 +158,7 @@ def load_valid_sessions(self) -> dict:
# 获取此session_name前一个session的数据
def last_session(self, session_name: str, cursor_timestamp: int):

self.execute("""
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
limit 1
Expand Down Expand Up @@ -179,7 +187,7 @@ def last_session(self, session_name: str, cursor_timestamp: int):
# 获取此session_name后一个session的数据
def next_session(self, session_name: str, cursor_timestamp: int):

self.execute("""
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
limit 1
Expand Down Expand Up @@ -207,7 +215,7 @@ def next_session(self, session_name: str, cursor_timestamp: int):

# 列出与某个对象的所有对话session
def list_history(self, session_name: str, capacity: int, page: int):
self.execute("""
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
""".format(session_name, capacity, capacity * page))
Expand Down Expand Up @@ -246,22 +254,22 @@ def dump_api_key_usage(self, api_keys: dict, usage: dict):
usage_count = usage[key_md5]
# 将使用量存进数据库
# 先检查是否已存在
self.execute("""
self.__execute__("""
select count(*) from `api_key_usage` where `key_md5` = '{}'""".format(key_md5))
result = self.cursor.fetchone()
if result[0] == 0:
# 不存在则插入
self.execute("""
self.__execute__("""
insert into `api_key_usage` (`key_md5`, `usage`,`timestamp`) values ('{}', {}, {})
""".format(key_md5, usage_count, int(time.time())))
else:
# 存在则更新,timestamp设置为当前
self.execute("""
self.__execute__("""
update `api_key_usage` set `usage` = {}, `timestamp` = {} where `key_md5` = '{}'
""".format(usage_count, int(time.time()), key_md5))

def load_api_key_usage(self):
self.execute("""
self.__execute__("""
select `key_md5`, `usage` from `api_key_usage`
""")
results = self.cursor.fetchall()
Expand All @@ -273,23 +281,24 @@ def load_api_key_usage(self):
return usage

def dump_usage_json(self, usage: dict):

json_str = json.dumps(usage)
self.execute("""
self.__execute__("""
select count(*) from `account_usage`""")
result = self.cursor.fetchone()
if result[0] == 0:
# 不存在则插入
self.execute("""
self.__execute__("""
insert into `account_usage` (`json`) values ('{}')
""".format(json_str))
else:
# 存在则更新
self.execute("""
self.__execute__("""
update `account_usage` set `json` = '{}' where `id` = 1
""".format(json_str))

def load_usage_json(self):
self.execute("""
self.__execute__("""
select `json` from `account_usage` order by id desc limit 1
""")
result = self.cursor.fetchone()
Expand Down
2 changes: 2 additions & 0 deletions pkg/openai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""OpenAI 接口处理及会话管理相关
"""
5 changes: 5 additions & 0 deletions pkg/openai/dprompt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# 多情景预设值管理

__current__ = "default"
"""当前默认使用的情景预设的名称
由管理员使用`!default <名称>`指令切换
"""

__prompts_from_files__ = {}
"""从文件中读取的情景预设值"""


def read_prompt_from_file() -> str:
Expand Down
33 changes: 20 additions & 13 deletions pkg/openai/keymgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,26 @@
import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models


class KeysManager:
api_key = {}
"""所有api-key"""

# api-key的使用量
# 其中键为api-key的md5值,值为使用量
using_key = ""
"""当前使用的api-key
"""

alerted = []
"""已提示过超额的key
记录在此以避免重复提示
"""

# 在此list中的都是经超额报错标记过的api-key
# 记录的是key值,仅在运行时有效
exceeded = []
"""已超额的key
供自动切换功能识别
"""

def get_using_key(self):
return self.using_key
Expand All @@ -25,8 +33,6 @@ def get_using_key_md5(self):
return hashlib.md5(self.using_key.encode('utf-8')).hexdigest()

def __init__(self, api_key):
# if hasattr(config, 'api_key_usage_threshold'):
# self.api_key_usage_threshold = config.api_key_usage_threshold

if type(api_key) is dict:
self.api_key = api_key
Expand All @@ -42,9 +48,13 @@ def __init__(self, api_key):

self.auto_switch()

# 根据tested自动切换到可用的api-key
# 返回是否切换成功, 切换后的api-key的别名
def auto_switch(self) -> (bool, str):
"""尝试切换api-key
Returns:
是否切换成功, 切换后的api-key的别名
"""

for key_name in self.api_key:
if self.api_key[key_name] not in self.exceeded:
self.using_key = self.api_key[key_name]
Expand All @@ -68,12 +78,9 @@ def auto_switch(self) -> (bool, str):
def add(self, key_name, key):
self.api_key[key_name] = key

# 设置当前使用的api-key使用量超限
# 这是在尝试调用api时发生超限异常时调用的
def set_current_exceeded(self):
# md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
# self.usage[md5] = self.api_key_usage_threshold
# self.fee[md5] = self.api_key_fee_threshold
"""设置当前使用的api-key使用量超限
"""
self.exceeded.append(self.using_key)

def get_key_name(self, api_key):
Expand Down
Loading

0 comments on commit 651b291

Please sign in to comment.