From 6170f82349e7b678e0cc1ef105ed2621e04e2291 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 31 Jul 2024 23:38:31 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BC=98=E5=8C=96=E5=88=9D=E5=A7=8B?= =?UTF-8?q?=E5=8C=96=E3=80=81=E6=B6=88=E6=81=AF=E5=A4=84=E7=90=86=E6=97=B6?= =?UTF-8?q?=E7=9A=84=E9=85=8D=E7=BD=AE=E8=AF=BB=E5=8F=96=E8=BF=87=E7=A8=8B?= =?UTF-8?q?=EF=BC=8C=E5=87=8F=E5=B0=91=E6=80=A7=E8=83=BD=E6=8D=9F=E8=80=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/bootstrap.py | 26 ++++-- astrbot/message/handler.py | 2 +- dashboard/server.py | 3 +- model/command/internal_handler.py | 4 +- model/platform/qq_aiocqhttp.py | 6 +- model/platform/qq_nakuru.py | 8 +- model/platform/qq_official.py | 6 +- model/provider/openai_official.py | 6 +- type/config.py | 76 ++++++++++++++++- type/types.py | 11 +-- util/cmd_config.py | 67 ++++++++------- util/config_utils.py | 133 +----------------------------- 12 files changed, 157 insertions(+), 191 deletions(-) diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index 62160c3b..0facf203 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -10,6 +10,7 @@ from model.platform.manager import PlatformManager from typing import Dict, List, Union from type.types import Context +from type.config import VERSION from SparkleLogging.utils.core import LogManager from logging import Logger from util.cmd_config import CmdConfig @@ -23,15 +24,26 @@ class AstrBotBootstrap(): def __init__(self) -> None: self.context = Context() - self.config_helper: CmdConfig = CmdConfig() + self.config_helper = CmdConfig() # load configs and ensure the backward compatibility - init_configs() try_migrate_config() - self.configs = inject_to_context(self.context) - logger.info("AstrBot v" + self.context.version) self.context.config_helper = self.config_helper + self.context.base_config = self.config_helper.cached_config + self.context.default_personality = { + "name": "default", + "prompt": self.context.base_config.get("default_personality_str", ""), + } + self.context.unique_session = self.context.base_config.get("uniqueSessionMode", False) + nick_qq = self.context.base_config.get("nick_qq", ('/', '!')) + if isinstance(nick_qq, str): nick_qq = (nick_qq, ) + self.context.nick = nick_qq + self.context.t2i_mode = self.context.base_config.get("qq_pic_mode", True) + self.context.version = VERSION + + logger.info("AstrBot v" + self.context.version) + # apply proxy settings http_proxy = self.context.base_config.get("http_proxy") https_proxy = self.context.base_config.get("https_proxy") @@ -93,9 +105,9 @@ async def handle_task(self, task: Union[asyncio.Task, asyncio.Future]): await asyncio.sleep(5) def load_llm(self): - if 'openai' in self.configs and \ - len(self.configs['openai']['key']) and \ - self.configs['openai']['key'][0] is not None: + if 'openai' in self.config_helper.cached_config and \ + len(self.config_helper.cached_config['openai']['key']) and \ + self.config_helper.cached_config['openai']['key'][0] is not None: from model.provider.openai_official import ProviderOpenAIOfficial from model.command.openai_official_handler import OpenAIOfficialCommandHandler self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager) diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py index 59ac1d11..b5fafc1e 100644 --- a/astrbot/message/handler.py +++ b/astrbot/message/handler.py @@ -114,7 +114,7 @@ def __init__(self, context: Context, self.llm_wake_prefix = self.context.base_config['llm_wake_prefix'] self.nicks = self.context.nick self.provider = provider - self.reply_prefix = self.context.reply_prefix + self.reply_prefix = str(self.context.reply_prefix) def set_provider(self, provider: Provider): self.provider = provider diff --git a/dashboard/server.py b/dashboard/server.py index d3ba1f1e..d4a3681c 100644 --- a/dashboard/server.py +++ b/dashboard/server.py @@ -86,12 +86,11 @@ def authenticate(): @self.dashboard_be.post("/api/change_password") def change_password(): - password = self.context.base_config("dashboard_password", "") + password = self.context.base_config.get("dashboard_password", "") # 获得请求体 post_data = request.json if post_data["password"] == password: self.context.config_helper.put("dashboard_password", post_data["new_password"]) - self.context.base_config['dashboard_password'] = post_data["new_password"] return Response( status="success", message="修改成功。", diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py index 2fad5ef2..8b93713b 100644 --- a/model/command/internal_handler.py +++ b/model/command/internal_handler.py @@ -230,15 +230,17 @@ def web_search(self, message: AstrMessageEvent, context: Context): ) def t2i_toggle(self, message: AstrMessageEvent, context: Context): - p = context.config_helper.get("qq_pic_mode", True) + p = context.t2i_mode if p: context.config_helper.put("qq_pic_mode", False) + context.t2i_mode = False return CommandResult( hit=True, success=True, message_chain="已关闭文本转图片模式。", ) context.config_helper.put("qq_pic_mode", True) + context.t2i_mode = True return CommandResult( hit=True, diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index 8c82db7f..4abdfa89 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -117,8 +117,8 @@ async def handle_msg(self, message: AstrBotMessage): # 解析 role sender_id = str(message.sender.user_id) - if sender_id == self.context.config_helper.get('admin_qq', '') or \ - sender_id in self.context.config_helper.get('other_admins', []): + if sender_id == self.context.base_config.get('admin_qq', '') or \ + sender_id in self.context.base_config.get('other_admins', []): role = 'admin' else: role = 'member' @@ -154,7 +154,7 @@ async def reply_msg(self, res = [Plain(text=res), ] # if image mode, put all Plain texts into a new picture. - if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list): + if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list): rendered_images = await self.convert_to_t2i_chain(res) if rendered_images: try: diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py index 58a8b3d2..d4052094 100644 --- a/model/platform/qq_nakuru.py +++ b/model/platform/qq_nakuru.py @@ -112,8 +112,8 @@ async def handle_msg(self, message: AstrBotMessage): # 解析 role sender_id = str(message.raw_message.user_id) - if sender_id == self.context.config_helper.get('admin_qq', '') or \ - sender_id in self.context.config_helper.get('other_admins', []): + if sender_id == self.context.base_config.get('admin_qq', '') or \ + sender_id in self.context.base_config.get('other_admins', []): role = 'admin' else: role = 'member' @@ -152,7 +152,7 @@ async def reply_msg(self, res = [Plain(text=res), ] # if image mode, put all Plain texts into a new picture. - if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list): + if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list): rendered_images = await self.convert_to_t2i_chain(res) if rendered_images: try: @@ -186,7 +186,7 @@ async def _reply(self, source, message_chain: List[BaseMessageComponent]): plain_text_len += len(i.text) elif isinstance(i, Image): image_num += 1 - if plain_text_len > self.context.config_helper.get('qq_forward_threshold', 200): + if plain_text_len > self.context.base_config.get('qq_forward_threshold', 200): # 删除At for i in message_chain: if isinstance(i, At): diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index bd3afd36..5ca3e301 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -209,8 +209,8 @@ async def handle_msg(self, message: AstrBotMessage): # 解析出 role sender_id = message.sender.user_id - if sender_id == self.context.config_helper.get('admin_qqchan', None) or \ - sender_id in self.context.config_helper.get('other_admins', None): + if sender_id == self.context.base_config.get('admin_qqchan', None) or \ + sender_id in self.context.base_config.get('other_admins', None): role = 'admin' else: role = 'member' @@ -249,7 +249,7 @@ async def reply_msg(self, msg_ref = None rendered_images = [] - if self.context.config_helper.get("qq_pic_mode", False) and isinstance(result_message, list): + if self.context.base_config.get("qq_pic_mode", False) and isinstance(result_message, list): rendered_images = await self.convert_to_t2i_chain(result_message) if isinstance(result_message, list): diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index 73e5baaf..ac156cdb 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -53,7 +53,7 @@ def __init__(self, context: Context) -> None: os.makedirs("data/openai", exist_ok=True) - self.cc = CmdConfig + self.context = context self.key_data_path = "data/openai/keys.json" self.api_keys = [] self.chosen_api_key = None @@ -78,7 +78,7 @@ def __init__(self, context: Context) -> None: ) self.model_configs: Dict = cfg['chatGPTConfigs'] super().set_curr_model(self.model_configs['model']) - self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None) + self.image_generator_model_configs: Dict = context.base_config.get('openai_image_generate', None) self.session_memory: Dict[str, List] = {} # 会话记忆 self.session_memory_lock = threading.Lock() self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小 @@ -492,7 +492,7 @@ def dump_contexts_page(self, session_id: str, size=5, page=1,): def set_model(self, model: str): self.model_configs['model'] = model - self.cc.put_by_dot_str("openai.chatGPTConfigs.model", model) + self.context.config_helper.put_by_dot_str("openai.chatGPTConfigs.model", model) super().set_curr_model(model) def get_configs(self): diff --git a/type/config.py b/type/config.py index 089f059b..e13bc926 100644 --- a/type/config.py +++ b/type/config.py @@ -1 +1,75 @@ -VERSION = '3.3.5' \ No newline at end of file +VERSION = '3.3.7' + +DEFAULT_CONFIG = { + "qqbot": { + "enable": False, + "appid": "", + "token": "", + }, + "gocqbot": { + "enable": False, + }, + "uniqueSessionMode": False, + "dump_history_interval": 10, + "limit": { + "time": 60, + "count": 30, + }, + "notice": "", + "direct_message_mode": True, + "reply_prefix": "", + "baidu_aip": { + "enable": False, + "app_id": "", + "api_key": "", + "secret_key": "" + }, + "openai": { + "key": [], + "api_base": "", + "chatGPTConfigs": { + "model": "gpt-4o", + "max_tokens": 6000, + "temperature": 0.9, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + }, + "total_tokens_limit": 10000, + }, + "qq_forward_threshold": 200, + "qq_welcome": "", + "qq_pic_mode": True, + "gocq_host": "127.0.0.1", + "gocq_http_port": 5700, + "gocq_websocket_port": 6700, + "gocq_react_group": True, + "gocq_react_guild": True, + "gocq_react_friend": True, + "gocq_react_group_increase": True, + "other_admins": [], + "CHATGPT_BASE_URL": "", + "qqbot_secret": "", + "qqofficial_enable_group_message": False, + "admin_qq": "", + "nick_qq": ["/", "!"], + "admin_qqchan": "", + "llm_env_prompt": "", + "llm_wake_prefix": "", + "default_personality_str": "", + "openai_image_generate": { + "model": "dall-e-3", + "size": "1024x1024", + "style": "vivid", + "quality": "standard", + }, + "http_proxy": "", + "https_proxy": "", + "dashboard_username": "", + "dashboard_password": "", + "aiocqhttp": { + "enable": False, + "ws_reverse_host": "", + "ws_reverse_port": 0, + } +} \ No newline at end of file diff --git a/type/types.py b/type/types.py index e34515a4..3d733189 100644 --- a/type/types.py +++ b/type/types.py @@ -28,21 +28,22 @@ def __init__(self): self.unique_session = False # 独立会话 self.version: str = None # 机器人版本 - self.nick = None # gocq 的唤醒词 - self.stat = {} + self.nick: tuple = None # gocq 的唤醒词 self.t2i_mode = False self.web_search = False # 是否开启了网页搜索 - self.reply_prefix = "" + + self.metrics_uploader = None self.updator: AstrBotUpdator = None self.plugin_updator: PluginUpdator = None - self.metrics_uploader = None - self.plugin_command_bridge = PluginCommandBridge(self.cached_plugins) self.image_renderer = TextToImageRenderer() self.image_uploader = ImageUploader() self.message_handler = None # see astrbot/message/handler.py self.ext_tasks: List[Task] = [] + # useless + self.reply_prefix = "" + def register_commands(self, plugin_name: str, command_name: str, diff --git a/util/cmd_config.py b/util/cmd_config.py index 337534bb..80cac42a 100644 --- a/util/cmd_config.py +++ b/util/cmd_config.py @@ -1,19 +1,31 @@ import os import json -from typing import Union +from type.config import DEFAULT_CONFIG cpath = "data/cmd_config.json" def check_exist(): if not os.path.exists(cpath): with open(cpath, "w", encoding="utf-8-sig") as f: - json.dump({}, f, indent=4, ensure_ascii=False) + json.dump({}, f, ensure_ascii=False) f.flush() class CmdConfig(): + def __init__(self) -> None: + self.cached_config: dict = {} + self.init_configs() + + def init_configs(self): + ''' + 初始化必需的配置项 + ''' + self.init_config_items(DEFAULT_CONFIG) @staticmethod def get(key, default=None): + ''' + 从文件系统中直接获取配置 + ''' check_exist() with open(cpath, "r", encoding="utf-8-sig") as f: d = json.load(f) @@ -22,28 +34,33 @@ def get(key, default=None): else: return default - @staticmethod - def get_all(): + def get_all(self): + ''' + 从文件系统中获取所有配置 + ''' check_exist() with open(cpath, "r", encoding="utf-8-sig") as f: - return json.load(f) + conf_str = f.read() + if conf_str.startswith(u'/ufeff'): # remove BOM + conf_str = conf_str.encode('utf8')[3:].decode('utf8') + conf = json.loads(conf_str) + return conf - @staticmethod - def put(key, value): - check_exist() + def put(self, key, value): with open(cpath, "r", encoding="utf-8-sig") as f: d = json.load(f) d[key] = value with open(cpath, "w", encoding="utf-8-sig") as f: - json.dump(d, f, indent=4, ensure_ascii=False) + json.dump(d, f, indent=2, ensure_ascii=False) f.flush() + self.cached_config[key] = value + @staticmethod def put_by_dot_str(key: str, value): ''' 根据点分割的字符串,将值写入配置文件 ''' - check_exist() with open(cpath, "r", encoding="utf-8-sig") as f: d = json.load(f) _d = d @@ -54,30 +71,22 @@ def put_by_dot_str(key: str, value): else: _d = _d[_ks[i]] with open(cpath, "w", encoding="utf-8-sig") as f: - json.dump(d, f, indent=4, ensure_ascii=False) + json.dump(d, f, indent=2, ensure_ascii=False) f.flush() - @staticmethod - def init_attributes(key: Union[str, list], init_val=""): - check_exist() - conf_str = '' - with open(cpath, "r", encoding="utf-8-sig") as f: - conf_str = f.read() - if conf_str.startswith(u'/ufeff'): - conf_str = conf_str.encode('utf8')[3:].decode('utf8') - d = json.loads(conf_str) + def init_config_items(self, d: dict): + conf = self.get_all() + + if not self.cached_config: + self.cached_config = conf + _tag = False - if isinstance(key, str): - if key not in d: - d[key] = init_val + for key, val in d.items(): + if key not in conf: + conf[key] = val _tag = True - elif isinstance(key, list): - for k in key: - if k not in d: - d[k] = init_val - _tag = True if _tag: with open(cpath, "w", encoding="utf-8-sig") as f: - json.dump(d, f, indent=4, ensure_ascii=False) + json.dump(conf, f, indent=2, ensure_ascii=False) f.flush() diff --git a/util/config_utils.py b/util/config_utils.py index cdf783e3..8fa9a1c4 100644 --- a/util/config_utils.py +++ b/util/config_utils.py @@ -1,89 +1,5 @@ import json, os from util.cmd_config import CmdConfig -from type.config import VERSION -from type.types import Context - -def init_configs(): - ''' - 初始化必需的配置项 - ''' - cc = CmdConfig() - - cc.init_attributes("qqbot", { - "enable": False, - "appid": "", - "token": "", - }) - cc.init_attributes("gocqbot", { - "enable": False, - }) - cc.init_attributes("uniqueSessionMode", False) - cc.init_attributes("dump_history_interval", 10) - cc.init_attributes("limit", { - "time": 60, - "count": 30, - }) - cc.init_attributes("notice", "") - cc.init_attributes("direct_message_mode", True) - cc.init_attributes("reply_prefix", "") - cc.init_attributes("baidu_aip", { - "enable": False, - "app_id": "", - "api_key": "", - "secret_key": "" - }) - cc.init_attributes("openai", { - "key": [], - "api_base": "", - "chatGPTConfigs": { - "model": "gpt-4o", - "max_tokens": 6000, - "temperature": 0.9, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - }, - "total_tokens_limit": 10000, - }) - - - cc.init_attributes("qq_forward_threshold", 200) - cc.init_attributes("qq_welcome", "") - cc.init_attributes("qq_pic_mode", True) - cc.init_attributes("gocq_host", "127.0.0.1") - cc.init_attributes("gocq_http_port", 5700) - cc.init_attributes("gocq_websocket_port", 6700) - cc.init_attributes("gocq_react_group", True) - cc.init_attributes("gocq_react_guild", True) - cc.init_attributes("gocq_react_friend", True) - cc.init_attributes("gocq_react_group_increase", True) - cc.init_attributes("other_admins", []) - cc.init_attributes("CHATGPT_BASE_URL", "") - cc.init_attributes("qqbot_secret", "") - cc.init_attributes("qqofficial_enable_group_message", False) - cc.init_attributes("admin_qq", "") - cc.init_attributes("nick_qq", ["!", "!", "ai"]) - cc.init_attributes("admin_qqchan", "") - cc.init_attributes("llm_env_prompt", "") - cc.init_attributes("llm_wake_prefix", "") - cc.init_attributes("default_personality_str", "") - cc.init_attributes("openai_image_generate", { - "model": "dall-e-3", - "size": "1024x1024", - "style": "vivid", - "quality": "standard", - }) - cc.init_attributes("http_proxy", "") - cc.init_attributes("https_proxy", "") - cc.init_attributes("dashboard_username", "") - cc.init_attributes("dashboard_password", "") - - # aiocqhttp 适配器 - cc.init_attributes("aiocqhttp", { - "enable": False, - "ws_reverse_host": "", - "ws_reverse_port": 0, - }) def try_migrate_config(): ''' @@ -97,51 +13,4 @@ def try_migrate_config(): try: os.remove("cmd_config.json") except Exception as e: - pass - -def inject_to_context(context: Context): - ''' - 将配置注入到 Context 中。 - this method returns all the configs - ''' - cc = CmdConfig() - - context.version = VERSION - context.base_config = cc.get_all() - - cfg = context.base_config - - if 'reply_prefix' in cfg: - # 适配旧版配置 - if isinstance(cfg['reply_prefix'], dict): - context.reply_prefix = "" - cfg['reply_prefix'] = "" - cc.put("reply_prefix", "") - else: - context.reply_prefix = cfg['reply_prefix'] - - default_personality_str = cc.get("default_personality_str", "") - if default_personality_str == "": - context.default_personality = None - else: - context.default_personality = { - "name": "default", - "prompt": default_personality_str, - } - - if 'uniqueSessionMode' in cfg and cfg['uniqueSessionMode']: - context.unique_session = True - else: - context.unique_session = False - - nick_qq = cc.get("nick_qq", None) - if nick_qq == None: - nick_qq = ("/", ) - if isinstance(nick_qq, str): - nick_qq = (nick_qq, ) - if isinstance(nick_qq, list): - nick_qq = tuple(nick_qq) - context.nick = nick_qq - context.t2i_mode = cc.get("qq_pic_mode", True) - - return cfg \ No newline at end of file + pass \ No newline at end of file