Refactor: change the import style #598

Merged 1 commit on Nov 13, 2023
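The pattern applied throughout this PR is the same in every file: absolute pkg.-prefixed imports are replaced by explicit relative imports, and call sites are qualified with the short module name instead of the full dotted path. A minimal before/after sketch of the idea (illustrative only, using pkg/audit/gatherer.py as the example; the actual hunks follow):

# before: absolute import, fully qualified call site
import pkg.utils.updater

version = pkg.utils.updater.get_current_tag()

# after: relative import, call site qualified by the short module name
from ..utils import updater

version = updater.get_current_tag()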
16 changes: 8 additions & 8 deletions pkg/audit/gatherer.py
@@ -9,8 +9,8 @@

import requests

import pkg.utils.context
import pkg.utils.updater
from ..utils import context
from ..utils import updater


class DataGatherer:
@@ -33,7 +33,7 @@ class DataGatherer:
def __init__(self):
self.load_from_db()
try:
self.version_str = pkg.utils.updater.get_current_tag() # get the version number from the updater module
self.version_str = updater.get_current_tag() # get the version number from the updater module
except:
pass

@@ -47,7 +47,7 @@ def report_to_server(self, subservice_name: str, count: int):
def thread_func():

try:
config = pkg.utils.context.get_config()
config = context.get_config()
if not config.report_usage:
return
res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}&msg_source={}".format(subservice_name, self.version_str, count, config.msg_source_adapter))
@@ -64,7 +64,7 @@ def get_usage(self, key_md5):
def report_text_model_usage(self, model, total_tokens):
"""调用方报告文字模型请求文字使用量"""

key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() # stored under the MD5 of the key
key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5() # stored under the MD5 of the key

if key_md5 not in self.usage:
self.usage[key_md5] = {}
@@ -84,7 +84,7 @@ def report_text_model_usage(self, model, total_tokens):
def report_image_model_usage(self, size):
"""调用方报告图片模型请求图片使用量"""

key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()
key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5()

if key_md5 not in self.usage:
self.usage[key_md5] = {}
@@ -131,9 +131,9 @@ def get_total_text_length(self):
return total

def dump_to_db(self):
pkg.utils.context.get_database_manager().dump_usage_json(self.usage)
context.get_database_manager().dump_usage_json(self.usage)

def load_from_db(self):
json_str = pkg.utils.context.get_database_manager().load_usage_json()
json_str = context.get_database_manager().load_usage_json()
if json_str is not None:
self.usage = json.loads(json_str)
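For orientation, this is how the new relative imports resolve in this file (a sketch based on the package layout visible in the diff, not part of the commit): gatherer.py is the module pkg.audit.gatherer, so .. names the pkg package and ..utils names pkg.utils.

# inside pkg/audit/gatherer.py (module pkg.audit.gatherer) these two forms are equivalent:
from ..utils import context      # relative: ".." resolves to the pkg package
from pkg.utils import context    # absolute spelling of the same binding
# both bind the submodule pkg.utils.context to the local name "context",
# so calls read context.get_config() rather than pkg.utils.context.get_config()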
9 changes: 4 additions & 5 deletions pkg/database/manager.py
@@ -5,11 +5,10 @@
import json
import logging
import time
from sqlite3 import Cursor

import sqlite3

import pkg.utils.context
from ..utils import context


class DatabaseManager:
@@ -22,7 +21,7 @@ def __init__(self):

self.reconnect()

pkg.utils.context.set_database_manager(self)
context.set_database_manager(self)

# connect to the database file
def reconnect(self):
@@ -33,7 +32,7 @@ def reconnect(self):
def close(self):
self.conn.close()

def __execute__(self, *args, **kwargs) -> Cursor:
def __execute__(self, *args, **kwargs) -> sqlite3.Cursor:
# logging.debug('SQL: {}'.format(sql))
logging.debug('SQL: {}'.format(args))
c = self.cursor.execute(*args, **kwargs)
@@ -145,7 +144,7 @@ def set_session_expired(self, session_name: str, create_timestamp: int):
# load session data that has not yet expired from the database
def load_valid_sessions(self) -> dict:
# load all sessions that are not yet expired from the database
config = pkg.utils.context.get_config()
config = context.get_config()
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `last_interact_timestamp` > {}
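The sqlite3 change above follows the same idea: rather than importing Cursor by name, the module is imported once and the annotation is qualified, so the type's origin stays visible at the signature. A small self-contained sketch (not taken from the PR):

import sqlite3

def run_query(conn: sqlite3.Connection, sql: str) -> sqlite3.Cursor:
    # annotating the return type as sqlite3.Cursor removes the need for a name-only import
    return conn.execute(sql)

conn = sqlite3.connect(":memory:")
print(run_query(conn, "select 1").fetchone())  # (1,)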
17 changes: 8 additions & 9 deletions pkg/openai/api/chat_completion.py
@@ -1,11 +1,11 @@
import openai
from openai.types.chat import chat_completion_message
import json
import logging

from .model import RequestBase
import openai
from openai.types.chat import chat_completion_message

from ..funcmgr import get_func_schema_list, execute_function, get_func, get_func_schema, ContentFunctionNotFoundError
from .model import RequestBase
from .. import funcmgr


class ChatCompletionRequest(RequestBase):
@@ -81,7 +81,7 @@ def __next__(self) -> dict:
"messages": self.messages,
}

funcs = get_func_schema_list()
funcs = funcmgr.get_func_schema_list()

if len(funcs) > 0:
args['functions'] = funcs
@@ -171,7 +171,7 @@ def __next__(self) -> dict:
# handle the case where the arguments are not valid JSON
except json.decoder.JSONDecodeError:
# get the function's parameter list
func_schema = get_func_schema(func_name)
func_schema = funcmgr.get_func_schema(func_name)

arguments = {
func_schema['parameters']['required'][0]: cp_pending_func_call.arguments
@@ -182,7 +182,7 @@ def __next__(self) -> dict:
# execute the function call
ret = ""
try:
ret = execute_function(func_name, arguments)
ret = funcmgr.execute_function(func_name, arguments)

logging.info("函数执行完成。")
except Exception as e:
@@ -216,6 +216,5 @@ def __next__(self) -> dict:
}
}

except ContentFunctionNotFoundError:
except funcmgr.ContentFunctionNotFoundError:
raise Exception("没有找到函数: {}".format(func_name))

4 changes: 2 additions & 2 deletions pkg/openai/api/completion.py
@@ -1,10 +1,10 @@
import openai
from openai.types import completion, completion_choice

from .model import RequestBase
from . import model


class CompletionRequest(RequestBase):
class CompletionRequest(model.RequestBase):
"""调用Completion接口的请求类。

调用方可以一直next completion直到finish_reason为stop。
2 changes: 0 additions & 2 deletions pkg/openai/api/model.py
@@ -1,6 +1,4 @@
# defines the request models for the different APIs
import threading
import asyncio
import logging

import openai
3 changes: 1 addition & 2 deletions pkg/openai/funcmgr.py
@@ -1,8 +1,7 @@
# helper functions supporting function calling
import logging


from pkg.plugin import host
from ..plugin import host


class ContentFunctionNotFoundError(Exception):
4 changes: 2 additions & 2 deletions pkg/openai/keymgr.py
@@ -2,8 +2,8 @@
import hashlib
import logging

import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
from ..plugin import host as plugin_host
from ..plugin import models as plugin_models


class KeysManager:
29 changes: 14 additions & 15 deletions pkg/openai/manager.py
@@ -2,12 +2,11 @@

import openai

import pkg.openai.keymgr
import pkg.utils.context
import pkg.audit.gatherer
from pkg.openai.modelmgr import select_request_cls

from pkg.openai.api.model import RequestBase
from ..openai import keymgr
from ..utils import context
from ..audit import gatherer
from ..openai import modelmgr
from ..openai.api import model as api_model


class OpenAIInteract:
@@ -16,9 +15,9 @@ class OpenAIInteract:
Wraps the text and image APIs for callers to use
"""

key_mgr: pkg.openai.keymgr.KeysManager = None
key_mgr: keymgr.KeysManager = None

audit_mgr: pkg.audit.gatherer.DataGatherer = None
audit_mgr: gatherer.DataGatherer = None

default_image_api_params = {
"size": "256x256",
@@ -28,31 +27,31 @@ class OpenAIInteract:

def __init__(self, api_key: str):

self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
self.audit_mgr = pkg.audit.gatherer.DataGatherer()
self.key_mgr = keymgr.KeysManager(api_key)
self.audit_mgr = gatherer.DataGatherer()

# logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length())

self.client = openai.Client(
api_key=self.key_mgr.get_using_key()
)

pkg.utils.context.set_openai_manager(self)
context.set_openai_manager(self)

def request_completion(self, messages: list):
"""请求补全接口回复=
"""
# choose the request class for this API
config = pkg.utils.context.get_config()
config = context.get_config()

request: RequestBase
request: api_model.RequestBase

model: str = config.completion_api_params['model']

cp_parmas = config.completion_api_params.copy()
del cp_parmas['model']

request = select_request_cls(self.client, model, messages, cp_parmas)
request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas)

# call the API
for resp in request:
@@ -74,7 +73,7 @@ def request_image(self, prompt) -> dict:
Returns:
dict: the response
"""
config = pkg.utils.context.get_config()
config = context.get_config()
params = config.image_api_params

response = openai.Image.create(
12 changes: 6 additions & 6 deletions pkg/openai/modelmgr.py
@@ -8,9 +8,9 @@
import tiktoken
import openai

from pkg.openai.api.model import RequestBase
from pkg.openai.api.completion import CompletionRequest
from pkg.openai.api.chat_completion import ChatCompletionRequest
from ..openai.api import model as api_model
from ..openai.api import completion as api_completion
from ..openai.api import chat_completion as api_chat_completion

COMPLETION_MODELS = {
"text-davinci-003", # legacy
@@ -60,11 +60,11 @@
}


def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> RequestBase:
def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> api_model.RequestBase:
if model_name in CHAT_COMPLETION_MODELS:
return ChatCompletionRequest(client, model_name, messages, **args)
return api_chat_completion.ChatCompletionRequest(client, model_name, messages, **args)
elif model_name in COMPLETION_MODELS:
return CompletionRequest(client, model_name, messages, **args)
return api_completion.CompletionRequest(client, model_name, messages, **args)
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))


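For completeness, a hedged usage sketch of select_request_cls after the rename, mirroring the call made in OpenAIInteract.request_completion; the API key is a placeholder and "text-davinci-003" is taken from the COMPLETION_MODELS set shown above:

import openai
from pkg.openai import modelmgr

client = openai.Client(api_key="sk-...")          # placeholder key
messages = [{"role": "user", "content": "hello"}]

# modelmgr picks CompletionRequest or ChatCompletionRequest based on the model name
request = modelmgr.select_request_cls(client, "text-davinci-003", messages, {})

for resp in request:   # RequestBase subclasses are iterated by the caller, as manager.py does
    print(resp)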