diff --git a/pkg/audit/gatherer.py b/pkg/audit/gatherer.py index 699bbaee..acf6b368 100644 --- a/pkg/audit/gatherer.py +++ b/pkg/audit/gatherer.py @@ -9,8 +9,8 @@ import requests -import pkg.utils.context -import pkg.utils.updater +from ..utils import context +from ..utils import updater class DataGatherer: @@ -33,7 +33,7 @@ class DataGatherer: def __init__(self): self.load_from_db() try: - self.version_str = pkg.utils.updater.get_current_tag() # 从updater模块获取版本号 + self.version_str = updater.get_current_tag() # 从updater模块获取版本号 except: pass @@ -47,7 +47,7 @@ def report_to_server(self, subservice_name: str, count: int): def thread_func(): try: - config = pkg.utils.context.get_config() + config = context.get_config() if not config.report_usage: return res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}&msg_source={}".format(subservice_name, self.version_str, count, config.msg_source_adapter)) @@ -64,7 +64,7 @@ def get_usage(self, key_md5): def report_text_model_usage(self, model, total_tokens): """调用方报告文字模型请求文字使用量""" - key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存 + key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存 if key_md5 not in self.usage: self.usage[key_md5] = {} @@ -84,7 +84,7 @@ def report_text_model_usage(self, model, total_tokens): def report_image_model_usage(self, size): """调用方报告图片模型请求图片使用量""" - key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() + key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5() if key_md5 not in self.usage: self.usage[key_md5] = {} @@ -131,9 +131,9 @@ def get_total_text_length(self): return total def dump_to_db(self): - pkg.utils.context.get_database_manager().dump_usage_json(self.usage) + context.get_database_manager().dump_usage_json(self.usage) def load_from_db(self): - json_str = pkg.utils.context.get_database_manager().load_usage_json() + json_str = context.get_database_manager().load_usage_json() if json_str is not None: self.usage = json.loads(json_str) diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 33d6cfb8..15a00ae9 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -5,11 +5,10 @@ import json import logging import time -from sqlite3 import Cursor import sqlite3 -import pkg.utils.context +from ..utils import context class DatabaseManager: @@ -22,7 +21,7 @@ def __init__(self): self.reconnect() - pkg.utils.context.set_database_manager(self) + context.set_database_manager(self) # 连接到数据库文件 def reconnect(self): @@ -33,7 +32,7 @@ def reconnect(self): def close(self): self.conn.close() - def __execute__(self, *args, **kwargs) -> Cursor: + def __execute__(self, *args, **kwargs) -> sqlite3.Cursor: # logging.debug('SQL: {}'.format(sql)) logging.debug('SQL: {}'.format(args)) c = self.cursor.execute(*args, **kwargs) @@ -145,7 +144,7 @@ def set_session_expired(self, session_name: str, create_timestamp: int): # 从数据库加载还没过期的session数据 def load_valid_sessions(self) -> dict: # 从数据库中加载所有还没过期的session - config = pkg.utils.context.get_config() + config = context.get_config() self.__execute__(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` from `sessions` where `last_interact_timestamp` > {} diff --git a/pkg/openai/api/chat_completion.py b/pkg/openai/api/chat_completion.py index a8e5f175..e308f17d 100644 --- a/pkg/openai/api/chat_completion.py +++ b/pkg/openai/api/chat_completion.py @@ -1,11 +1,11 @@ -import openai -from openai.types.chat import chat_completion_message import json import logging -from .model import RequestBase +import openai +from openai.types.chat import chat_completion_message -from ..funcmgr import get_func_schema_list, execute_function, get_func, get_func_schema, ContentFunctionNotFoundError +from .model import RequestBase +from .. import funcmgr class ChatCompletionRequest(RequestBase): @@ -81,7 +81,7 @@ def __next__(self) -> dict: "messages": self.messages, } - funcs = get_func_schema_list() + funcs = funcmgr.get_func_schema_list() if len(funcs) > 0: args['functions'] = funcs @@ -171,7 +171,7 @@ def __next__(self) -> dict: # 若不是json格式的异常处理 except json.decoder.JSONDecodeError: # 获取函数的参数列表 - func_schema = get_func_schema(func_name) + func_schema = funcmgr.get_func_schema(func_name) arguments = { func_schema['parameters']['required'][0]: cp_pending_func_call.arguments @@ -182,7 +182,7 @@ def __next__(self) -> dict: # 执行函数调用 ret = "" try: - ret = execute_function(func_name, arguments) + ret = funcmgr.execute_function(func_name, arguments) logging.info("函数执行完成。") except Exception as e: @@ -216,6 +216,5 @@ def __next__(self) -> dict: } } - except ContentFunctionNotFoundError: + except funcmgr.ContentFunctionNotFoundError: raise Exception("没有找到函数: {}".format(func_name)) - diff --git a/pkg/openai/api/completion.py b/pkg/openai/api/completion.py index 2c74de36..d14e91f4 100644 --- a/pkg/openai/api/completion.py +++ b/pkg/openai/api/completion.py @@ -1,10 +1,10 @@ import openai from openai.types import completion, completion_choice -from .model import RequestBase +from . import model -class CompletionRequest(RequestBase): +class CompletionRequest(model.RequestBase): """调用Completion接口的请求类。 调用方可以一直next completion直到finish_reason为stop。 diff --git a/pkg/openai/api/model.py b/pkg/openai/api/model.py index 3a574cb3..3fd71fea 100644 --- a/pkg/openai/api/model.py +++ b/pkg/openai/api/model.py @@ -1,6 +1,4 @@ # 定义不同接口请求的模型 -import threading -import asyncio import logging import openai diff --git a/pkg/openai/funcmgr.py b/pkg/openai/funcmgr.py index 06c72e25..50932917 100644 --- a/pkg/openai/funcmgr.py +++ b/pkg/openai/funcmgr.py @@ -1,8 +1,7 @@ # 封装了function calling的一些支持函数 import logging - -from pkg.plugin import host +from ..plugin import host class ContentFunctionNotFoundError(Exception): diff --git a/pkg/openai/keymgr.py b/pkg/openai/keymgr.py index bed44330..ea9c292b 100644 --- a/pkg/openai/keymgr.py +++ b/pkg/openai/keymgr.py @@ -2,8 +2,8 @@ import hashlib import logging -import pkg.plugin.host as plugin_host -import pkg.plugin.models as plugin_models +from ..plugin import host as plugin_host +from ..plugin import models as plugin_models class KeysManager: diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index b99371f6..4236c630 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -2,12 +2,11 @@ import openai -import pkg.openai.keymgr -import pkg.utils.context -import pkg.audit.gatherer -from pkg.openai.modelmgr import select_request_cls - -from pkg.openai.api.model import RequestBase +from ..openai import keymgr +from ..utils import context +from ..audit import gatherer +from ..openai import modelmgr +from ..openai.api import model as api_model class OpenAIInteract: @@ -16,9 +15,9 @@ class OpenAIInteract: 将文字接口和图片接口封装供调用方使用 """ - key_mgr: pkg.openai.keymgr.KeysManager = None + key_mgr: keymgr.KeysManager = None - audit_mgr: pkg.audit.gatherer.DataGatherer = None + audit_mgr: gatherer.DataGatherer = None default_image_api_params = { "size": "256x256", @@ -28,8 +27,8 @@ class OpenAIInteract: def __init__(self, api_key: str): - self.key_mgr = pkg.openai.keymgr.KeysManager(api_key) - self.audit_mgr = pkg.audit.gatherer.DataGatherer() + self.key_mgr = keymgr.KeysManager(api_key) + self.audit_mgr = gatherer.DataGatherer() # logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length()) @@ -37,22 +36,22 @@ def __init__(self, api_key: str): api_key=self.key_mgr.get_using_key() ) - pkg.utils.context.set_openai_manager(self) + context.set_openai_manager(self) def request_completion(self, messages: list): """请求补全接口回复= """ # 选择接口请求类 - config = pkg.utils.context.get_config() + config = context.get_config() - request: RequestBase + request: api_model.RequestBase model: str = config.completion_api_params['model'] cp_parmas = config.completion_api_params.copy() del cp_parmas['model'] - request = select_request_cls(self.client, model, messages, cp_parmas) + request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas) # 请求接口 for resp in request: @@ -74,7 +73,7 @@ def request_image(self, prompt) -> dict: Returns: dict: 响应 """ - config = pkg.utils.context.get_config() + config = context.get_config() params = config.image_api_params response = openai.Image.create( diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index 44942d7f..6e7947bb 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -8,9 +8,9 @@ import tiktoken import openai -from pkg.openai.api.model import RequestBase -from pkg.openai.api.completion import CompletionRequest -from pkg.openai.api.chat_completion import ChatCompletionRequest +from ..openai.api import model as api_model +from ..openai.api import completion as api_completion +from ..openai.api import chat_completion as api_chat_completion COMPLETION_MODELS = { "text-davinci-003", # legacy @@ -60,11 +60,11 @@ } -def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> RequestBase: +def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> api_model.RequestBase: if model_name in CHAT_COMPLETION_MODELS: - return ChatCompletionRequest(client, model_name, messages, **args) + return api_chat_completion.ChatCompletionRequest(client, model_name, messages, **args) elif model_name in COMPLETION_MODELS: - return CompletionRequest(client, model_name, messages, **args) + return api_completion.CompletionRequest(client, model_name, messages, **args) raise ValueError("不支持模型[{}],请检查配置文件".format(model_name)) diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 6277f065..5c351478 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -8,15 +8,13 @@ import time import json -import pkg.openai.manager -import pkg.openai.modelmgr -import pkg.database.manager -import pkg.utils.context +from ..openai import manager as openai_manager +from ..openai import modelmgr as openai_modelmgr +from ..database import manager as database_manager +from ..utils import context as context -import pkg.plugin.host as plugin_host -import pkg.plugin.models as plugin_models - -from pkg.openai.modelmgr import count_tokens +from ..plugin import host as plugin_host +from ..plugin import models as plugin_models # 运行时保存的所有session sessions = {} @@ -38,7 +36,7 @@ def reset_session_prompt(session_name, prompt): f.write(prompt) f.close() # 生成新数据 - config = pkg.utils.context.get_config() + config = context.get_config() prompt = [ { 'role': 'system', @@ -61,7 +59,7 @@ def load_sessions(): global sessions - db_inst = pkg.utils.context.get_database_manager() + db_inst = context.get_database_manager() session_data = db_inst.load_valid_sessions() @@ -172,7 +170,7 @@ def expire_check_timer_loop(self, create_timestamp: int): if self.create_timestamp != create_timestamp or self not in sessions.values(): return - config = pkg.utils.context.get_config() + config = context.get_config() if int(time.time()) - self.last_interact_timestamp > config.session_expire_time: logging.info('session {} 已过期'.format(self.name)) @@ -182,7 +180,7 @@ def expire_check_timer_loop(self, create_timestamp: int): 'session': self, 'session_expire_time': config.session_expire_time } - event = pkg.plugin.host.emit(plugin_models.SessionExpired, **args) + event = plugin_host.emit(plugin_models.SessionExpired, **args) if event.is_prevented_default(): return @@ -214,11 +212,11 @@ def query(self, text: str=None) -> tuple[str, str, list[str]]: 'default_prompt': self.default_prompt, } - event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args) + event = plugin_host.emit(plugin_models.SessionFirstMessageReceived, **args) if event.is_prevented_default(): return None, None, None - config = pkg.utils.context.get_config() + config = context.get_config() max_length = config.prompt_submit_length local_default_prompt = self.default_prompt.copy() @@ -232,7 +230,7 @@ def query(self, text: str=None) -> tuple[str, str, list[str]]: 'text_message': text, } - event = pkg.plugin.host.emit(plugin_models.PromptPreProcessing, **args) + event = plugin_host.emit(plugin_models.PromptPreProcessing, **args) if event.get_return_value('default_prompt') is not None: local_default_prompt = event.get_return_value('default_prompt') @@ -256,14 +254,14 @@ def query(self, text: str=None) -> tuple[str, str, list[str]]: funcs = [] trace_func_calls = config.trace_function_calls - botmgr = pkg.utils.context.get_qqbot_manager() + botmgr = context.get_qqbot_manager() session_name_spt: list[str] = self.name.split("_") pending_res_text = "" # TODO 对不起,我知道这样非常非常屎山,但我之后会重构的 - for resp in pkg.utils.context.get_openai_manager().request_completion(prompts): + for resp in context.get_openai_manager().request_completion(prompts): if pending_res_text != "": botmgr.adapter.send_message( @@ -325,7 +323,6 @@ def query(self, text: str=None) -> tuple[str, str, list[str]]: ) pass - # 向API请求补全 # message, total_token = pkg.utils.context.get_openai_manager().request_completion( # prompts, @@ -383,13 +380,13 @@ def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) # 包装目前的对话回合内容 changable_prompts = [] - use_model = pkg.utils.context.get_config().completion_api_params['model'] + use_model = context.get_config().completion_api_params['model'] ptr = len(prompt) - 1 # 直接从后向前扫描拼接,不管是否是整回合 while ptr >= 0: - if count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens: + if openai_modelmgr.count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens: break changable_prompts.insert(0, prompt[ptr]) @@ -410,14 +407,14 @@ def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4))) - return result_prompt, count_tokens(changable_prompts, use_model) + return result_prompt, openai_modelmgr.count_tokens(changable_prompts, use_model) # 持久化session def persistence(self): if self.prompt == self.get_default_prompt(): return - db_inst = pkg.utils.context.get_database_manager() + db_inst = context.get_database_manager() name_spt = self.name.split('_') @@ -439,12 +436,12 @@ def reset(self, explicit: bool = False, expired: bool = False, schedule_new: boo } # 此事件不支持阻止默认行为 - _ = pkg.plugin.host.emit(plugin_models.SessionExplicitReset, **args) + _ = plugin_host.emit(plugin_models.SessionExplicitReset, **args) - pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp) + context.get_database_manager().explicit_close_session(self.name, self.create_timestamp) if expired: - pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp) + context.get_database_manager().set_session_expired(self.name, self.create_timestamp) if not persist: # 不要求保持default prompt self.default_prompt = self.get_default_prompt(use_prompt) @@ -461,11 +458,11 @@ def reset(self, explicit: bool = False, expired: bool = False, schedule_new: boo # 将本session的数据库状态设置为on_going def set_ongoing(self): - pkg.utils.context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp) + context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp) # 切换到上一个session def last_session(self): - last_one = pkg.utils.context.get_database_manager().last_session(self.name, self.last_interact_timestamp) + last_one = context.get_database_manager().last_session(self.name, self.last_interact_timestamp) if last_one is None: return None else: @@ -486,7 +483,7 @@ def last_session(self): # 切换到下一个session def next_session(self): - next_one = pkg.utils.context.get_database_manager().next_session(self.name, self.last_interact_timestamp) + next_one = context.get_database_manager().next_session(self.name, self.last_interact_timestamp) if next_one is None: return None else: @@ -506,13 +503,13 @@ def next_session(self): return self def list_history(self, capacity: int = 10, page: int = 0): - return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page) + return context.get_database_manager().list_history(self.name, capacity, page) def delete_history(self, index: int) -> bool: - return pkg.utils.context.get_database_manager().delete_history(self.name, index) + return context.get_database_manager().delete_history(self.name, index) def delete_all_history(self) -> bool: - return pkg.utils.context.get_database_manager().delete_all_history(self.name) + return context.get_database_manager().delete_all_history(self.name) def draw_image(self, prompt: str): - return pkg.utils.context.get_openai_manager().request_image(prompt) + return context.get_openai_manager().request_image(prompt) diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py index 27a6c1e2..27806845 100644 --- a/pkg/plugin/host.py +++ b/pkg/plugin/host.py @@ -10,13 +10,13 @@ import time import re -import pkg.utils.updater as updater -import pkg.utils.context as context -import pkg.plugin.switch as switch -import pkg.plugin.settings as settings -import pkg.qqbot.adapter as msadapter -import pkg.utils.network as network -import pkg.plugin.metadata as metadata +from ..utils import updater as updater +from ..utils import network as network +from ..utils import context as context +from ..plugin import switch as switch +from ..plugin import settings as settings +from ..qqbot import adapter as msadapter +from ..plugin import metadata as metadata from mirai import Mirai import requests @@ -147,6 +147,7 @@ def initialize_plugins(): successfully_initialized_plugins.append(plugin['name']) except: logging.error("插件{}初始化时发生错误: {}".format(plugin['name'], sys.exc_info())) + logging.debug(traceback.format_exc()) logging.info("以下插件已初始化: {}".format(", ".join(successfully_initialized_plugins))) diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py index b2280683..2e1ce459 100644 --- a/pkg/plugin/models.py +++ b/pkg/plugin/models.py @@ -1,7 +1,7 @@ import logging -import pkg.plugin.host as host -import pkg.utils.context +from ..plugin import host +from ..utils import context PersonMessageReceived = "person_message_received" """收到私聊消息时,在判断是否应该响应前触发 @@ -285,7 +285,7 @@ def wrapper(cls: Plugin): cls.description = description cls.version = version cls.author = author - cls.host = pkg.utils.context.get_plugin_host() + cls.host = context.get_plugin_host() cls.enabled = True cls.path = host.__current_module_path__ diff --git a/pkg/plugin/settings.py b/pkg/plugin/settings.py index 92fcfe77..6824906a 100644 --- a/pkg/plugin/settings.py +++ b/pkg/plugin/settings.py @@ -1,9 +1,9 @@ import json import os -import pkg.plugin.host as host import logging +from ..plugin import host def wrapper_dict_from_runtime_context() -> dict: """从变量中包装settings.json的数据字典""" diff --git a/pkg/plugin/switch.py b/pkg/plugin/switch.py index 041ec128..ccc96c8c 100644 --- a/pkg/plugin/switch.py +++ b/pkg/plugin/switch.py @@ -3,7 +3,7 @@ import logging import os -import pkg.plugin.host as host +from ..plugin import host def wrapper_dict_from_plugin_list() -> dict: diff --git a/pkg/qqbot/banlist.py b/pkg/qqbot/banlist.py index 2c7dcb12..949c541b 100644 --- a/pkg/qqbot/banlist.py +++ b/pkg/qqbot/banlist.py @@ -1,18 +1,18 @@ -import pkg.utils.context +from ..utils import context def is_banned(launcher_type: str, launcher_id: int, sender_id: int) -> bool: - if not pkg.utils.context.get_qqbot_manager().enable_banlist: + if not context.get_qqbot_manager().enable_banlist: return False result = False if launcher_type == 'group': # 检查是否显式声明发起人QQ要被person忽略 - if sender_id in pkg.utils.context.get_qqbot_manager().ban_person: + if sender_id in context.get_qqbot_manager().ban_person: result = True else: - for group_rule in pkg.utils.context.get_qqbot_manager().ban_group: + for group_rule in context.get_qqbot_manager().ban_group: if type(group_rule) == int: if group_rule == launcher_id: # 此群群号被禁用 result = True @@ -32,7 +32,7 @@ def is_banned(launcher_type: str, launcher_id: int, sender_id: int) -> bool: else: # ban_person, 与群规则相同 - for person_rule in pkg.utils.context.get_qqbot_manager().ban_person: + for person_rule in context.get_qqbot_manager().ban_person: if type(person_rule) == int: if person_rule == launcher_id: result = True diff --git a/pkg/qqbot/blob.py b/pkg/qqbot/blob.py index 099d8013..fd66a4bc 100644 --- a/pkg/qqbot/blob.py +++ b/pkg/qqbot/blob.py @@ -2,21 +2,21 @@ import os import time import base64 +import typing -import config from mirai.models.message import MessageComponent, MessageChain, Image from mirai.models.message import ForwardMessageNode from mirai.models.base import MiraiBaseModel -from typing import List -import pkg.utils.context as context -import pkg.utils.text2img as text2img + +from ..utils import text2img +import config class ForwardMessageDiaplay(MiraiBaseModel): title: str = "群聊的聊天记录" brief: str = "[聊天记录]" source: str = "聊天记录" - preview: List[str] = [] + preview: typing.List[str] = [] summary: str = "查看x条转发消息" @@ -26,7 +26,7 @@ class Forward(MessageComponent): """消息组件类型。""" display: ForwardMessageDiaplay """显示信息""" - node_list: List[ForwardMessageNode] + node_list: typing.List[ForwardMessageNode] """转发消息节点列表。""" def __init__(self, *args, **kwargs): if len(args) == 1: diff --git a/pkg/qqbot/cmds/aamgr.py b/pkg/qqbot/cmds/aamgr.py index cfc95b5a..27596c4c 100644 --- a/pkg/qqbot/cmds/aamgr.py +++ b/pkg/qqbot/cmds/aamgr.py @@ -1,10 +1,7 @@ -import importlib -import inspect import logging import copy import pkgutil import traceback -import types import json diff --git a/pkg/qqbot/cmds/funcs/draw.py b/pkg/qqbot/cmds/funcs/draw.py index fecd3a34..b9af92e9 100644 --- a/pkg/qqbot/cmds/funcs/draw.py +++ b/pkg/qqbot/cmds/funcs/draw.py @@ -1,11 +1,12 @@ -from ..aamgr import AbstractCommandNode, Context import logging -from mirai import Image +import mirai + +from .. import aamgr import config -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="draw", description="使用DALL·E生成图片", @@ -13,9 +14,9 @@ aliases=[], privilege=1 ) -class DrawCommand(AbstractCommandNode): +class DrawCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import pkg.openai.session reply = [] @@ -28,7 +29,7 @@ def process(cls, ctx: Context) -> tuple[bool, list]: res = session.draw_image(" ".join(ctx.params)) logging.debug("draw_image result:{}".format(res)) - reply = [Image(url=res['data'][0]['url'])] + reply = [mirai.Image(url=res['data'][0]['url'])] if not (hasattr(config, 'include_image_description') and not config.include_image_description): reply.append(" ".join(ctx.params)) diff --git a/pkg/qqbot/cmds/funcs/func.py b/pkg/qqbot/cmds/funcs/func.py index 93b31844..61675931 100644 --- a/pkg/qqbot/cmds/funcs/func.py +++ b/pkg/qqbot/cmds/funcs/func.py @@ -1,10 +1,9 @@ -from ..aamgr import AbstractCommandNode, Context import logging - import json +from .. import aamgr -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="func", description="管理内容函数", @@ -12,9 +11,9 @@ aliases=[], privilege=1 ) -class FuncCommand(AbstractCommandNode): +class FuncCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: from pkg.plugin.models import host reply = [] diff --git a/pkg/qqbot/cmds/plugin/plugin.py b/pkg/qqbot/cmds/plugin/plugin.py index 00ba44f0..21783fc7 100644 --- a/pkg/qqbot/cmds/plugin/plugin.py +++ b/pkg/qqbot/cmds/plugin/plugin.py @@ -1,12 +1,9 @@ -from ..aamgr import AbstractCommandNode, Context +from ....plugin import host as plugin_host +from ....utils import updater +from .. import aamgr -import os -import pkg.plugin.host as plugin_host -import pkg.utils.updater as updater - - -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="plugin", description="插件管理", @@ -14,9 +11,9 @@ aliases=[], privilege=1 ) -class PluginCommand(AbstractCommandNode): +class PluginCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: reply = [] plugin_list = plugin_host.__plugins__ if len(ctx.params) == 0: @@ -48,7 +45,7 @@ def process(cls, ctx: Context) -> tuple[bool, list]: return False, [] -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=PluginCommand, name="get", description="安装插件", @@ -56,9 +53,9 @@ def process(cls, ctx: Context) -> tuple[bool, list]: aliases=[], privilege=2 ) -class PluginGetCommand(AbstractCommandNode): +class PluginGetCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import threading import logging import pkg.utils.context @@ -81,7 +78,7 @@ def closure(): return True, reply -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=PluginCommand, name="update", description="更新指定插件或全部插件", @@ -89,9 +86,9 @@ def closure(): aliases=[], privilege=2 ) -class PluginUpdateCommand(AbstractCommandNode): +class PluginUpdateCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import threading import logging plugin_list = plugin_host.__plugins__ @@ -130,7 +127,7 @@ def closure(): return True, reply -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=PluginCommand, name="del", description="删除插件", @@ -138,9 +135,9 @@ def closure(): aliases=[], privilege=2 ) -class PluginDelCommand(AbstractCommandNode): +class PluginDelCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: plugin_list = plugin_host.__plugins__ reply = [] @@ -157,7 +154,7 @@ def process(cls, ctx: Context) -> tuple[bool, list]: return True, reply -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=PluginCommand, name="on", description="启用指定插件", @@ -165,7 +162,7 @@ def process(cls, ctx: Context) -> tuple[bool, list]: aliases=[], privilege=2 ) -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=PluginCommand, name="off", description="禁用指定插件", @@ -173,9 +170,9 @@ def process(cls, ctx: Context) -> tuple[bool, list]: aliases=[], privilege=2 ) -class PluginOnOffCommand(AbstractCommandNode): +class PluginOnOffCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import pkg.plugin.switch as plugin_switch plugin_list = plugin_host.__plugins__ diff --git a/pkg/qqbot/cmds/session/default.py b/pkg/qqbot/cmds/session/default.py index 1e094525..bb187123 100644 --- a/pkg/qqbot/cmds/session/default.py +++ b/pkg/qqbot/cmds/session/default.py @@ -1,7 +1,6 @@ -from ..aamgr import AbstractCommandNode, Context +from .. import aamgr - -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="default", description="操作情景预设", @@ -9,9 +8,9 @@ aliases=[], privilege=1 ) -class DefaultCommand(AbstractCommandNode): +class DefaultCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import pkg.openai.session session_name = ctx.session_name params = ctx.params @@ -45,7 +44,7 @@ def process(cls, ctx: Context) -> tuple[bool, list]: return True, reply -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=DefaultCommand, name="set", description="设置默认情景预设", @@ -53,9 +52,9 @@ def process(cls, ctx: Context) -> tuple[bool, list]: aliases=[], privilege=2 ) -class DefaultSetCommand(AbstractCommandNode): +class DefaultSetCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: reply = [] if len(ctx.crt_params) == 0: diff --git a/pkg/qqbot/cmds/session/del.py b/pkg/qqbot/cmds/session/del.py index a997f06b..45fdc4ee 100644 --- a/pkg/qqbot/cmds/session/del.py +++ b/pkg/qqbot/cmds/session/del.py @@ -1,8 +1,7 @@ -from ..aamgr import AbstractCommandNode, Context -import datetime +from .. import aamgr -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="del", description="删除当前会话的历史记录", @@ -10,9 +9,9 @@ aliases=[], privilege=1 ) -class DelCommand(AbstractCommandNode): +class DelCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import pkg.openai.session session_name = ctx.session_name params = ctx.params @@ -33,7 +32,7 @@ def process(cls, ctx: Context) -> tuple[bool, list]: return True, reply -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=DelCommand, name="all", description="删除当前会话的全部历史记录", @@ -41,9 +40,9 @@ def process(cls, ctx: Context) -> tuple[bool, list]: aliases=[], privilege=1 ) -class DelAllCommand(AbstractCommandNode): +class DelAllCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import pkg.openai.session session_name = ctx.session_name reply = [] diff --git a/pkg/qqbot/cmds/session/delhst.py b/pkg/qqbot/cmds/session/delhst.py index 3e0da85a..31791492 100644 --- a/pkg/qqbot/cmds/session/delhst.py +++ b/pkg/qqbot/cmds/session/delhst.py @@ -1,7 +1,7 @@ -from ..aamgr import AbstractCommandNode, Context +from .. import aamgr -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="delhst", description="删除指定会话的所有历史记录", @@ -9,9 +9,9 @@ aliases=[], privilege=2 ) -class DelHistoryCommand(AbstractCommandNode): +class DelHistoryCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import pkg.openai.session import pkg.utils.context params = ctx.params @@ -31,7 +31,7 @@ def process(cls, ctx: Context) -> tuple[bool, list]: return True, reply -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=DelHistoryCommand, name="all", description="删除所有会话的全部历史记录", @@ -39,9 +39,9 @@ def process(cls, ctx: Context) -> tuple[bool, list]: aliases=[], privilege=2 ) -class DelAllHistoryCommand(AbstractCommandNode): +class DelAllHistoryCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import pkg.utils.context reply = [] pkg.utils.context.get_database_manager().delete_all_session_history() diff --git a/pkg/qqbot/cmds/session/last.py b/pkg/qqbot/cmds/session/last.py index bdf456be..93459c44 100644 --- a/pkg/qqbot/cmds/session/last.py +++ b/pkg/qqbot/cmds/session/last.py @@ -1,8 +1,9 @@ -from ..aamgr import AbstractCommandNode, Context import datetime +from .. import aamgr -@AbstractCommandNode.register( + +@aamgr.AbstractCommandNode.register( parent=None, name="last", description="切换前一次对话", @@ -10,9 +11,9 @@ aliases=[], privilege=1 ) -class LastCommand(AbstractCommandNode): +class LastCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import pkg.openai.session session_name = ctx.session_name diff --git a/pkg/qqbot/cmds/session/list.py b/pkg/qqbot/cmds/session/list.py index 6c31a9be..e40fb7fb 100644 --- a/pkg/qqbot/cmds/session/list.py +++ b/pkg/qqbot/cmds/session/list.py @@ -1,9 +1,10 @@ -from ..aamgr import AbstractCommandNode, Context import datetime import json +from .. import aamgr -@AbstractCommandNode.register( + +@aamgr.AbstractCommandNode.register( parent=None, name='list', description='列出当前会话的所有历史记录', @@ -11,9 +12,9 @@ aliases=[], privilege=1 ) -class ListCommand(AbstractCommandNode): +class ListCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import pkg.openai.session session_name = ctx.session_name params = ctx.params diff --git a/pkg/qqbot/cmds/session/next.py b/pkg/qqbot/cmds/session/next.py index 94622b0d..7704acf6 100644 --- a/pkg/qqbot/cmds/session/next.py +++ b/pkg/qqbot/cmds/session/next.py @@ -1,8 +1,9 @@ -from ..aamgr import AbstractCommandNode, Context import datetime +from .. import aamgr -@AbstractCommandNode.register( + +@aamgr.AbstractCommandNode.register( parent=None, name="next", description="切换后一次对话", @@ -10,9 +11,9 @@ aliases=[], privilege=1 ) -class NextCommand(AbstractCommandNode): +class NextCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import pkg.openai.session session_name = ctx.session_name reply = [] diff --git a/pkg/qqbot/cmds/session/prompt.py b/pkg/qqbot/cmds/session/prompt.py index d2629bd3..adb2e583 100644 --- a/pkg/qqbot/cmds/session/prompt.py +++ b/pkg/qqbot/cmds/session/prompt.py @@ -1,8 +1,7 @@ -from ..aamgr import AbstractCommandNode, Context -import datetime +from .. import aamgr -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="prompt", description="获取当前会话的前文", @@ -10,9 +9,9 @@ aliases=[], privilege=1 ) -class PromptCommand(AbstractCommandNode): +class PromptCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import pkg.openai.session session_name = ctx.session_name params = ctx.params diff --git a/pkg/qqbot/cmds/session/resend.py b/pkg/qqbot/cmds/session/resend.py index aa8b8032..c71b624d 100644 --- a/pkg/qqbot/cmds/session/resend.py +++ b/pkg/qqbot/cmds/session/resend.py @@ -1,8 +1,7 @@ -from ..aamgr import AbstractCommandNode, Context -import datetime +from .. import aamgr -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="resend", description="重新获取上一次问题的回复", @@ -10,20 +9,22 @@ aliases=[], privilege=1 ) -class ResendCommand(AbstractCommandNode): +class ResendCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: - import pkg.openai.session + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: + from ....openai import session as openai_session + from ....utils import context + from ....qqbot import message import config session_name = ctx.session_name reply = [] - session = pkg.openai.session.get_session(session_name) + session = openai_session.get_session(session_name) to_send = session.undo() - mgr = pkg.utils.context.get_qqbot_manager() + mgr = context.get_qqbot_manager() - reply = pkg.qqbot.message.process_normal_message(to_send, mgr, config, + reply = message.process_normal_message(to_send, mgr, config, ctx.launcher_type, ctx.launcher_id, ctx.sender_id) diff --git a/pkg/qqbot/cmds/session/reset.py b/pkg/qqbot/cmds/session/reset.py index 87be5a9f..a93f3415 100644 --- a/pkg/qqbot/cmds/session/reset.py +++ b/pkg/qqbot/cmds/session/reset.py @@ -1,11 +1,11 @@ -from ..aamgr import AbstractCommandNode, Context import tips as tips_custom -import pkg.openai.session -import pkg.utils.context +from .. import aamgr +from ....openai import session +from ....utils import context -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name='reset', description='重置当前会话', @@ -13,21 +13,21 @@ aliases=[], privilege=1 ) -class ResetCommand(AbstractCommandNode): +class ResetCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: params = ctx.params session_name = ctx.session_name reply = "" if len(params) == 0: - pkg.openai.session.get_session(session_name).reset(explicit=True) + session.get_session(session_name).reset(explicit=True) reply = [tips_custom.command_reset_message] else: try: import pkg.openai.dprompt as dprompt - pkg.openai.session.get_session(session_name).reset(explicit=True, use_prompt=params[0]) + session.get_session(session_name).reset(explicit=True, use_prompt=params[0]) reply = [tips_custom.command_reset_name_message+"{}".format(dprompt.mode_inst().get_full_name(params[0]))] except Exception as e: reply = ["[bot]会话重置失败:{}".format(e)] diff --git a/pkg/qqbot/cmds/system/cconfig.py b/pkg/qqbot/cmds/system/cconfig.py index 149c3138..5b994262 100644 --- a/pkg/qqbot/cmds/system/cconfig.py +++ b/pkg/qqbot/cmds/system/cconfig.py @@ -1,6 +1,7 @@ -from ..aamgr import AbstractCommandNode, Context import json +from .. import aamgr + def config_operation(cmd, params): reply = [] @@ -85,7 +86,7 @@ def config_operation(cmd, params): return reply -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="cfg", description="配置项管理", @@ -93,8 +94,8 @@ def config_operation(cmd, params): aliases=[], privilege=2 ) -class CfgCommand(AbstractCommandNode): +class CfgCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: return True, config_operation(ctx.command, ctx.params) \ No newline at end of file diff --git a/pkg/qqbot/cmds/system/cmd.py b/pkg/qqbot/cmds/system/cmd.py index c3e3fafe..40007588 100644 --- a/pkg/qqbot/cmds/system/cmd.py +++ b/pkg/qqbot/cmds/system/cmd.py @@ -1,7 +1,7 @@ -from ..aamgr import AbstractCommandNode, Context, __command_list__ +from .. import aamgr -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="cmd", description="显示指令列表", @@ -9,10 +9,10 @@ aliases=[], privilege=1 ) -class CmdCommand(AbstractCommandNode): +class CmdCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: - command_list = __command_list__ + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: + command_list = aamgr.__command_list__ reply = [] diff --git a/pkg/qqbot/cmds/system/help.py b/pkg/qqbot/cmds/system/help.py index 798cbf58..e4580990 100644 --- a/pkg/qqbot/cmds/system/help.py +++ b/pkg/qqbot/cmds/system/help.py @@ -1,7 +1,7 @@ -from ..aamgr import AbstractCommandNode, Context +from .. import aamgr -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="help", description="显示自定义的帮助信息", @@ -9,9 +9,9 @@ aliases=[], privilege=1 ) -class HelpCommand(AbstractCommandNode): +class HelpCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import tips reply = ["[bot] "+tips.help_message + "\n请输入 !cmd 查看指令列表"] diff --git a/pkg/qqbot/cmds/system/reload.py b/pkg/qqbot/cmds/system/reload.py index b3836f2e..378dcef9 100644 --- a/pkg/qqbot/cmds/system/reload.py +++ b/pkg/qqbot/cmds/system/reload.py @@ -1,7 +1,9 @@ -from ..aamgr import AbstractCommandNode, Context import threading -@AbstractCommandNode.register( +from .. import aamgr + + +@aamgr.AbstractCommandNode.register( parent=None, name="reload", description="执行热重载", @@ -9,9 +11,9 @@ aliases=[], privilege=2 ) -class ReloadCommand(AbstractCommandNode): +class ReloadCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: reply = [] import pkg.utils.reloader diff --git a/pkg/qqbot/cmds/system/update.py b/pkg/qqbot/cmds/system/update.py index 6c95bb36..33a4df08 100644 --- a/pkg/qqbot/cmds/system/update.py +++ b/pkg/qqbot/cmds/system/update.py @@ -1,9 +1,10 @@ -from ..aamgr import AbstractCommandNode, Context import threading import traceback +from .. import aamgr -@AbstractCommandNode.register( + +@aamgr.AbstractCommandNode.register( parent=None, name="update", description="更新程序", @@ -11,9 +12,9 @@ aliases=[], privilege=2 ) -class UpdateCommand(AbstractCommandNode): +class UpdateCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: reply = [] import pkg.utils.updater import pkg.utils.reloader diff --git a/pkg/qqbot/cmds/system/usage.py b/pkg/qqbot/cmds/system/usage.py index 7a9c3faa..983c4de1 100644 --- a/pkg/qqbot/cmds/system/usage.py +++ b/pkg/qqbot/cmds/system/usage.py @@ -1,8 +1,7 @@ -from ..aamgr import AbstractCommandNode, Context -import logging +from .. import aamgr -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="usage", description="获取使用情况", @@ -10,9 +9,9 @@ aliases=[], privilege=1 ) -class UsageCommand(AbstractCommandNode): +class UsageCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: import config import pkg.utils.credit as credit import pkg.utils.context diff --git a/pkg/qqbot/cmds/system/version.py b/pkg/qqbot/cmds/system/version.py index 8d493c99..67bf3ef2 100644 --- a/pkg/qqbot/cmds/system/version.py +++ b/pkg/qqbot/cmds/system/version.py @@ -1,7 +1,7 @@ -from ..aamgr import AbstractCommandNode, Context +from .. import aamgr -@AbstractCommandNode.register( +@aamgr.AbstractCommandNode.register( parent=None, name="version", description="查看版本信息", @@ -9,9 +9,9 @@ aliases=[], privilege=1 ) -class VersionCommand(AbstractCommandNode): +class VersionCommand(aamgr.AbstractCommandNode): @classmethod - def process(cls, ctx: Context) -> tuple[bool, list]: + def process(cls, ctx: aamgr.Context) -> tuple[bool, list]: reply = [] import pkg.utils.updater diff --git a/pkg/qqbot/command.py b/pkg/qqbot/command.py index 5b3e3ebf..414ffee4 100644 --- a/pkg/qqbot/command.py +++ b/pkg/qqbot/command.py @@ -1,23 +1,7 @@ # 指令处理模块 import logging -import json -import datetime -import os -import threading -import traceback - -import pkg.openai.session -import pkg.openai.manager -import pkg.utils.reloader -import pkg.utils.updater -import pkg.utils.context -import pkg.qqbot.message -import pkg.utils.credit as credit -# import pkg.qqbot.cmds.model as cmdmodel -import pkg.qqbot.cmds.aamgr as cmdmgr - -from mirai import Image +from ..qqbot.cmds import aamgr as cmdmgr def process_command(session_name: str, text_message: str, mgr, config, diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 2d8c1091..922de441 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -1,32 +1,25 @@ -import asyncio import json import os -import threading - +import logging from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \ FriendMessage, Image, MessageChain, Plain -from func_timeout import func_set_timeout +import func_timeout -import pkg.openai.session -import pkg.openai.manager -from func_timeout import FunctionTimedOut -import logging +from ..openai import session as openai_session -import pkg.qqbot.filter -import pkg.qqbot.process as processor -import pkg.utils.context - -import pkg.plugin.host as plugin_host -import pkg.plugin.models as plugin_models +from ..qqbot import filter as qqbot_filter +from ..qqbot import process as processor +from ..utils import context +from ..plugin import host as plugin_host +from ..plugin import models as plugin_models import tips as tips_custom - -import pkg.qqbot.adapter as msadapter +from ..qqbot import adapter as msadapter # 检查消息是否符合泛响应匹配机制 def check_response_rule(group_id:int, text: str): - config = pkg.utils.context.get_config() + config = context.get_config() rules = config.response_rules @@ -55,7 +48,7 @@ def check_response_rule(group_id:int, text: str): def response_at(group_id: int): - config = pkg.utils.context.get_config() + config = context.get_config() use_response_rule = config.response_rules @@ -73,7 +66,7 @@ def response_at(group_id: int): def random_responding(group_id): - config = pkg.utils.context.get_config() + config = context.get_config() use_response_rule = config.response_rules @@ -130,10 +123,10 @@ def __init__(self, first_time_init=True): self.adapter = NakuruProjectAdapter(config.nakuru_config) self.bot_account_id = self.adapter.bot_account_id else: - self.adapter = pkg.utils.context.get_qqbot_manager().adapter - self.bot_account_id = pkg.utils.context.get_qqbot_manager().bot_account_id + self.adapter = context.get_qqbot_manager().adapter + self.bot_account_id = context.get_qqbot_manager().bot_account_id - pkg.utils.context.set_qqbot_manager(self) + context.set_qqbot_manager(self) # 注册诸事件 # Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码 @@ -154,7 +147,7 @@ def friend_message_handler(): self.on_person_message(event) - pkg.utils.context.get_thread_ctl().submit_user_task( + context.get_thread_ctl().submit_user_task( friend_message_handler, ) self.adapter.register_listener( @@ -179,7 +172,7 @@ def stranger_message_handler(): self.on_person_message(event) - pkg.utils.context.get_thread_ctl().submit_user_task( + context.get_thread_ctl().submit_user_task( stranger_message_handler, ) # nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件 @@ -206,7 +199,7 @@ def group_message_handler(event: GroupMessage): self.on_group_message(event) - pkg.utils.context.get_thread_ctl().submit_user_task( + context.get_thread_ctl().submit_user_task( group_message_handler, event ) @@ -250,22 +243,22 @@ def unsubscribe_all(): if hasattr(banlist, "enable_group"): self.enable_group = banlist.enable_group - config = pkg.utils.context.get_config() + config = context.get_config() if os.path.exists("sensitive.json") \ and config.sensitive_word_filter is not None \ and config.sensitive_word_filter: with open("sensitive.json", "r", encoding="utf-8") as f: sensitive_json = json.load(f) - self.reply_filter = pkg.qqbot.filter.ReplyFilter( + self.reply_filter = qqbot_filter.ReplyFilter( sensitive_words=sensitive_json['words'], mask=sensitive_json['mask'] if 'mask' in sensitive_json else '*', mask_word=sensitive_json['mask_word'] if 'mask_word' in sensitive_json else '' ) else: - self.reply_filter = pkg.qqbot.filter.ReplyFilter([]) + self.reply_filter = qqbot_filter.ReplyFilter([]) def send(self, event, msg, check_quote=True, check_at_sender=True): - config = pkg.utils.context.get_config() + config = context.get_config() if check_at_sender and config.at_sender: msg.insert( @@ -306,7 +299,7 @@ def on_person_message(self, event: MessageEvent): for i in range(self.retry): try: - @func_set_timeout(config.process_message_timeout) + @func_timeout.func_set_timeout(config.process_message_timeout) def time_ctrl_wrapper(): reply = processor.process_message('person', event.sender.id, str(event.message_chain), event.message_chain, @@ -315,16 +308,16 @@ def time_ctrl_wrapper(): reply = time_ctrl_wrapper() break - except FunctionTimedOut: + except func_timeout.FunctionTimedOut: logging.warning("person_{}: 超时,重试中({})".format(event.sender.id, i)) - pkg.openai.session.get_session('person_{}'.format(event.sender.id)).release_response_lock() - if "person_{}".format(event.sender.id) in pkg.qqbot.process.processing: - pkg.qqbot.process.processing.remove('person_{}'.format(event.sender.id)) + openai_session.get_session('person_{}'.format(event.sender.id)).release_response_lock() + if "person_{}".format(event.sender.id) in processor.processing: + processor.processing.remove('person_{}'.format(event.sender.id)) failed += 1 continue if failed == self.retry: - pkg.openai.session.get_session('person_{}'.format(event.sender.id)).release_response_lock() + openai_session.get_session('person_{}'.format(event.sender.id)).release_response_lock() self.notify_admin("{} 请求超时".format("person_{}".format(event.sender.id))) reply = [tips_custom.reply_message] @@ -344,7 +337,7 @@ def process(text=None) -> str: failed = 0 for i in range(self.retry): try: - @func_set_timeout(config.process_message_timeout) + @func_timeout.func_set_timeout(config.process_message_timeout) def time_ctrl_wrapper(): replys = processor.process_message('group', event.group.id, str(event.message_chain).strip() if text is None else text, @@ -354,16 +347,16 @@ def time_ctrl_wrapper(): replys = time_ctrl_wrapper() break - except FunctionTimedOut: + except func_timeout.FunctionTimedOut: logging.warning("group_{}: 超时,重试中({})".format(event.group.id, i)) - pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock() - if "group_{}".format(event.group.id) in pkg.qqbot.process.processing: - pkg.qqbot.process.processing.remove('group_{}'.format(event.group.id)) + openai_session.get_session('group_{}'.format(event.group.id)).release_response_lock() + if "group_{}".format(event.group.id) in processor.processing: + processor.processing.remove('group_{}'.format(event.group.id)) failed += 1 continue if failed == self.retry: - pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock() + openai_session.get_session('group_{}'.format(event.group.id)).release_response_lock() self.notify_admin("{} 请求超时".format("group_{}".format(event.group.id))) replys = [tips_custom.replys_message] @@ -392,7 +385,7 @@ def time_ctrl_wrapper(): # 通知系统管理员 def notify_admin(self, message: str): - config = pkg.utils.context.get_config() + config = context.get_config() if config.admin_qq != 0 and config.admin_qq != []: logging.info("通知管理员:{}".format(message)) if type(config.admin_qq) == int: @@ -410,7 +403,7 @@ def notify_admin(self, message: str): ) def notify_admin_message_chain(self, message): - config = pkg.utils.context.get_config() + config = context.get_config() if config.admin_qq != 0 and config.admin_qq != []: logging.info("通知管理员:{}".format(message)) if type(config.admin_qq) == int: diff --git a/pkg/qqbot/message.py b/pkg/qqbot/message.py index 131805f2..c6058abd 100644 --- a/pkg/qqbot/message.py +++ b/pkg/qqbot/message.py @@ -1,19 +1,20 @@ # 普通消息处理模块 import logging + import openai -import pkg.utils.context -import pkg.openai.session -import pkg.plugin.host as plugin_host -import pkg.plugin.models as plugin_models -import pkg.qqbot.blob as blob +from ..utils import context +from ..openai import session as openai_session + +from ..plugin import host as plugin_host +from ..plugin import models as plugin_models import tips as tips_custom def handle_exception(notify_admin: str = "", set_reply: str = "") -> list: """处理异常,当notify_admin不为空时,会通知管理员,返回通知用户的消息""" import config - pkg.utils.context.get_qqbot_manager().notify_admin(notify_admin) + context.get_qqbot_manager().notify_admin(notify_admin) if config.hide_exce_info_to_user: return [tips_custom.alter_tip_message] if tips_custom.alter_tip_message else [] else: @@ -26,7 +27,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str, logging.info("[{}]发送消息:{}".format(session_name, text_message[:min(20, len(text_message))] + ( "..." if len(text_message) > 20 else ""))) - session = pkg.openai.session.get_session(session_name) + session = openai_session.get_session(session_name) unexpected_exception_times = 0 @@ -54,7 +55,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str, "funcs_called": funcs, } - event = pkg.plugin.host.emit(plugin_models.NormalMessageResponded, **args) + event = plugin_host.emit(plugin_models.NormalMessageResponded, **args) if event.get_return_value("prefix") is not None: prefix = event.get_return_value("prefix") @@ -78,29 +79,29 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str, if 'message' in e.error and e.error['message'].__contains__('You exceeded your current quota'): # 尝试切换api-key - current_key_name = pkg.utils.context.get_openai_manager().key_mgr.get_key_name( - pkg.utils.context.get_openai_manager().key_mgr.using_key + current_key_name = context.get_openai_manager().key_mgr.get_key_name( + context.get_openai_manager().key_mgr.using_key ) - pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded() + context.get_openai_manager().key_mgr.set_current_exceeded() # 触发插件事件 args = { 'key_name': current_key_name, - 'usage': pkg.utils.context.get_openai_manager().audit_mgr - .get_usage(pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()), - 'exceeded_keys': pkg.utils.context.get_openai_manager().key_mgr.exceeded, + 'usage': context.get_openai_manager().audit_mgr + .get_usage(context.get_openai_manager().key_mgr.get_using_key_md5()), + 'exceeded_keys': context.get_openai_manager().key_mgr.exceeded, } event = plugin_host.emit(plugin_models.KeyExceeded, **args) if not event.is_prevented_default(): - switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch() + switched, name = context.get_openai_manager().key_mgr.auto_switch() if not switched: reply = handle_exception( "api-key调用额度超限({}),无可用api_key,请向OpenAI账户充值或在config.py中更换api_key;如果你认为这是误判,请尝试重启程序。".format( current_key_name), "[bot]err:API调用额度超额,请联系管理员,或等待修复") else: - openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key() + openai.api_key = context.get_openai_manager().key_mgr.get_using_key() mgr.notify_admin("api-key调用额度超限({}),接口报错,已切换到{}".format(current_key_name, name)) reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"] continue diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index 3ae95437..ed45b37f 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -5,28 +5,22 @@ import mirai import logging -from mirai import MessageChain, Plain - # 这里不使用动态引入config # 因为在这里动态引入会卡死程序 # 而此模块静态引用config与动态引入的表现一致 # 已弃用,由于超时时间现已动态使用 # import config as config_init_import -import pkg.openai.session -import pkg.openai.manager -import pkg.utils.reloader -import pkg.utils.updater -import pkg.utils.context -import pkg.qqbot.message -import pkg.qqbot.command -import pkg.qqbot.ratelimit as ratelimit - -import pkg.plugin.host as plugin_host -import pkg.plugin.models as plugin_models -import pkg.qqbot.ignore as ignore -import pkg.qqbot.banlist as banlist -import pkg.qqbot.blob as blob +from ..qqbot import ratelimit +from ..qqbot import command, message +from ..openai import session as openai_session +from ..utils import context + +from ..plugin import host as plugin_host +from ..plugin import models as plugin_models +from ..qqbot import ignore +from ..qqbot import banlist +from ..qqbot import blob import tips as tips_custom processing = [] @@ -41,11 +35,11 @@ def is_admin(qq: int) -> bool: return qq == config.admin_qq -def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: MessageChain, - sender_id: int) -> MessageChain: +def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: mirai.MessageChain, + sender_id: int) -> mirai.MessageChain: global processing - mgr = pkg.utils.context.get_qqbot_manager() + mgr = context.get_qqbot_manager() reply = [] session_name = "{}_{}".format(launcher_type, launcher_id) @@ -62,7 +56,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes import config if not config.wait_last_done and session_name in processing: - return MessageChain([Plain(tips_custom.message_drop_tip)]) + return mirai.MessageChain([mirai.Plain(tips_custom.message_drop_tip)]) # 检查是否被禁言 if launcher_type == 'group': @@ -74,9 +68,9 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes import config if config.income_msg_check: if mgr.reply_filter.is_illegal(text_message): - return MessageChain(Plain("[bot] 消息中存在不合适的内容, 请更换措辞")) + return mirai.MessageChain(mirai.Plain("[bot] 消息中存在不合适的内容, 请更换措辞")) - pkg.openai.session.get_session(session_name).acquire_response_lock() + openai_session.get_session(session_name).acquire_response_lock() text_message = text_message.strip() @@ -87,7 +81,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes # 处理消息 try: - config = pkg.utils.context.get_config() + config = context.get_config() processing.append(session_name) try: @@ -114,7 +108,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes reply = event.get_return_value("reply") if not event.is_prevented_default(): - reply = pkg.qqbot.command.process_command(session_name, text_message, + reply = command.process_command(session_name, text_message, mgr, config, launcher_type, launcher_id, sender_id, is_admin(sender_id)) else: # 消息 @@ -124,7 +118,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes if ratelimit.is_reach_limit(session_name): logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message)) - return MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else [] + return mirai.MessageChain(["[bot]"+tips_custom.rate_limit_drop_tip]) if tips_custom.rate_limit_drop_tip != "" else [] before = time.time() # 触发插件事件 @@ -146,7 +140,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes reply = event.get_return_value("reply") if not event.is_prevented_default(): - reply = pkg.qqbot.message.process_normal_message(text_message, + reply = message.process_normal_message(text_message, mgr, config, launcher_type, launcher_id, sender_id) # 限速等待时间 @@ -170,7 +164,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes finally: processing.remove(session_name) finally: - pkg.openai.session.get_session(session_name).release_response_lock() + openai_session.get_session(session_name).release_response_lock() # 检查延迟时间 if config.force_delay_range[1] == 0: @@ -191,4 +185,4 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes logging.info("[风控] 强制延迟{:.2f}秒(如需关闭,请到config.py修改force_delay_range字段)".format(delay_time)) time.sleep(delay_time) - return MessageChain(reply) + return mirai.MessageChain(reply) diff --git a/pkg/qqbot/sources/nakuru.py b/pkg/qqbot/sources/nakuru.py index 3f70b4b8..51e5e41b 100644 --- a/pkg/qqbot/sources/nakuru.py +++ b/pkg/qqbot/sources/nakuru.py @@ -1,19 +1,18 @@ -import mirai - -from ..adapter import MessageSourceAdapter, MessageConverter, EventConverter -import nakuru -import nakuru.entities.components as nkc - import asyncio import typing import traceback import logging -import json -from pkg.qqbot.blob import Forward, ForwardMessageNode, ForwardMessageDiaplay +import mirai + +import nakuru +import nakuru.entities.components as nkc + +from .. import adapter as adapter_model +from ...qqbot import blob -class NakuruProjectMessageConverter(MessageConverter): +class NakuruProjectMessageConverter(adapter_model.MessageConverter): """消息转换器""" @staticmethod def yiri2target(message_chain: mirai.MessageChain) -> list: @@ -49,7 +48,7 @@ def yiri2target(message_chain: mirai.MessageChain) -> list: nakuru_msg_list.append(nkc.Record.fromURL(component.url)) elif component.path is not None: nakuru_msg_list.append(nkc.Record.fromFileSystem(component.path)) - elif type(component) is Forward: + elif type(component) is blob.Forward: # 转发消息 yiri_forward_node_list = component.node_list nakuru_forward_node_list = [] @@ -102,7 +101,7 @@ def target2yiri(message_chain: typing.Any, message_id: int = -1) -> mirai.Messag return chain -class NakuruProjectEventConverter(EventConverter): +class NakuruProjectEventConverter(adapter_model.EventConverter): """事件转换器""" @staticmethod def yiri2target(event: typing.Type[mirai.Event]): @@ -157,7 +156,7 @@ def target2yiri(event: typing.Any) -> mirai.Event: raise Exception("未支持转换的事件类型: " + str(event)) -class NakuruProjectAdapter(MessageSourceAdapter): +class NakuruProjectAdapter(adapter_model.MessageSourceAdapter): """nakuru-project适配器""" bot: nakuru.CQHTTP bot_account_id: int diff --git a/pkg/qqbot/sources/yirimirai.py b/pkg/qqbot/sources/yirimirai.py index 570c55a1..7828be18 100644 --- a/pkg/qqbot/sources/yirimirai.py +++ b/pkg/qqbot/sources/yirimirai.py @@ -1,13 +1,14 @@ -from ..adapter import MessageSourceAdapter +import asyncio +import typing + import mirai import mirai.models.bus from mirai.bot import MiraiRunner -import asyncio -import typing +from .. import adapter as adapter_model -class YiriMiraiAdapter(MessageSourceAdapter): +class YiriMiraiAdapter(adapter_model.MessageSourceAdapter): """YiriMirai适配器""" bot: mirai.Mirai diff --git a/pkg/utils/context.py b/pkg/utils/context.py index 2f8dee44..0da18228 100644 --- a/pkg/utils/context.py +++ b/pkg/utils/context.py @@ -1,5 +1,5 @@ import threading -from pkg.utils import ThreadCtl +from . import threadctl context = { @@ -87,8 +87,8 @@ def set_thread_ctl(inst): context_lock.release() -def get_thread_ctl() -> ThreadCtl: +def get_thread_ctl() -> threadctl.ThreadCtl: context_lock.acquire() - t: ThreadCtl = context['pool_ctl'] + t: threadctl.ThreadCtl = context['pool_ctl'] context_lock.release() return t diff --git a/pkg/utils/pkgmgr.py b/pkg/utils/pkgmgr.py index 42dc67bf..741c8f48 100644 --- a/pkg/utils/pkgmgr.py +++ b/pkg/utils/pkgmgr.py @@ -1,6 +1,6 @@ from pip._internal import main as pipmain -import pkg.utils.log as log +from . import log def install(package): diff --git a/pkg/utils/reloader.py b/pkg/utils/reloader.py index f116e088..a9f7445b 100644 --- a/pkg/utils/reloader.py +++ b/pkg/utils/reloader.py @@ -1,10 +1,9 @@ import logging -import threading - import importlib import pkgutil -import pkg.utils.context as context -import pkg.plugin.host + +from . import context +from ..plugin import host as plugin_host def walk(module, prefix='', path_prefix=''): @@ -15,7 +14,7 @@ def walk(module, prefix='', path_prefix=''): walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.', path_prefix + item.name + '/') else: logging.info('reload module: {}, path: {}'.format(prefix + item.name, path_prefix + item.name + '.py')) - pkg.plugin.host.__current_module_path__ = "plugins/" + path_prefix + item.name + '.py' + plugin_host.__current_module_path__ = "plugins/" + path_prefix + item.name + '.py' importlib.reload(__import__(module.__name__ + '.' + item.name, fromlist=[''])) diff --git a/pkg/utils/text2img.py b/pkg/utils/text2img.py index 68df0d68..1da2afca 100644 --- a/pkg/utils/text2img.py +++ b/pkg/utils/text2img.py @@ -1,11 +1,11 @@ import logging - -from PIL import Image, ImageDraw, ImageFont import re import os import config import traceback +from PIL import Image, ImageDraw, ImageFont + text_render_font: ImageFont = None if config.blob_message_strategy == "image": # 仅在启用了image时才加载字体 diff --git a/pkg/utils/updater.py b/pkg/utils/updater.py index 956a36bb..e3c1714d 100644 --- a/pkg/utils/updater.py +++ b/pkg/utils/updater.py @@ -3,10 +3,9 @@ import os.path import requests -import json -import pkg.utils.constants -import pkg.utils.network as network +from . import constants +from . import network def check_dulwich_closure(): @@ -70,7 +69,7 @@ def get_release_list() -> list: def get_current_tag() -> str: """获取当前tag""" - current_tag = pkg.utils.constants.semantic_version + current_tag = constants.semantic_version if os.path.exists("current_tag"): with open("current_tag", "r") as f: current_tag = f.read()