Skip to content

Commit

Permalink
fix: 优化初始化、消息处理时的配置读取过程,减少性能损耗
Browse files Browse the repository at this point in the history
  • Loading branch information
Soulter committed Jul 31, 2024
1 parent 114a6a8 commit 6170f82
Show file tree
Hide file tree
Showing 12 changed files with 157 additions and 191 deletions.
26 changes: 19 additions & 7 deletions astrbot/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from model.platform.manager import PlatformManager
from typing import Dict, List, Union
from type.types import Context
from type.config import VERSION
from SparkleLogging.utils.core import LogManager
from logging import Logger
from util.cmd_config import CmdConfig
Expand All @@ -23,15 +24,26 @@
class AstrBotBootstrap():
def __init__(self) -> None:
self.context = Context()
self.config_helper: CmdConfig = CmdConfig()
self.config_helper = CmdConfig()

# load configs and ensure the backward compatibility
init_configs()
try_migrate_config()
self.configs = inject_to_context(self.context)
logger.info("AstrBot v" + self.context.version)
self.context.config_helper = self.config_helper
self.context.base_config = self.config_helper.cached_config

self.context.default_personality = {
"name": "default",
"prompt": self.context.base_config.get("default_personality_str", ""),
}
self.context.unique_session = self.context.base_config.get("uniqueSessionMode", False)
nick_qq = self.context.base_config.get("nick_qq", ('/', '!'))
if isinstance(nick_qq, str): nick_qq = (nick_qq, )
self.context.nick = nick_qq
self.context.t2i_mode = self.context.base_config.get("qq_pic_mode", True)
self.context.version = VERSION

logger.info("AstrBot v" + self.context.version)

# apply proxy settings
http_proxy = self.context.base_config.get("http_proxy")
https_proxy = self.context.base_config.get("https_proxy")
Expand Down Expand Up @@ -93,9 +105,9 @@ async def handle_task(self, task: Union[asyncio.Task, asyncio.Future]):
await asyncio.sleep(5)

def load_llm(self):
if 'openai' in self.configs and \
len(self.configs['openai']['key']) and \
self.configs['openai']['key'][0] is not None:
if 'openai' in self.config_helper.cached_config and \
len(self.config_helper.cached_config['openai']['key']) and \
self.config_helper.cached_config['openai']['key'][0] is not None:
from model.provider.openai_official import ProviderOpenAIOfficial
from model.command.openai_official_handler import OpenAIOfficialCommandHandler
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
Expand Down
2 changes: 1 addition & 1 deletion astrbot/message/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, context: Context,
self.llm_wake_prefix = self.context.base_config['llm_wake_prefix']
self.nicks = self.context.nick
self.provider = provider
self.reply_prefix = self.context.reply_prefix
self.reply_prefix = str(self.context.reply_prefix)

def set_provider(self, provider: Provider):
self.provider = provider
Expand Down
3 changes: 1 addition & 2 deletions dashboard/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,11 @@ def authenticate():

@self.dashboard_be.post("/api/change_password")
def change_password():
password = self.context.base_config("dashboard_password", "")
password = self.context.base_config.get("dashboard_password", "")
# 获得请求体
post_data = request.json
if post_data["password"] == password:
self.context.config_helper.put("dashboard_password", post_data["new_password"])
self.context.base_config['dashboard_password'] = post_data["new_password"]
return Response(
status="success",
message="修改成功。",
Expand Down
4 changes: 3 additions & 1 deletion model/command/internal_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,17 @@ def web_search(self, message: AstrMessageEvent, context: Context):
)

def t2i_toggle(self, message: AstrMessageEvent, context: Context):
p = context.config_helper.get("qq_pic_mode", True)
p = context.t2i_mode
if p:
context.config_helper.put("qq_pic_mode", False)
context.t2i_mode = False
return CommandResult(
hit=True,
success=True,
message_chain="已关闭文本转图片模式。",
)
context.config_helper.put("qq_pic_mode", True)
context.t2i_mode = True

return CommandResult(
hit=True,
Expand Down
6 changes: 3 additions & 3 deletions model/platform/qq_aiocqhttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ async def handle_msg(self, message: AstrBotMessage):

# 解析 role
sender_id = str(message.sender.user_id)
if sender_id == self.context.config_helper.get('admin_qq', '') or \
sender_id in self.context.config_helper.get('other_admins', []):
if sender_id == self.context.base_config.get('admin_qq', '') or \
sender_id in self.context.base_config.get('other_admins', []):
role = 'admin'
else:
role = 'member'
Expand Down Expand Up @@ -154,7 +154,7 @@ async def reply_msg(self,
res = [Plain(text=res), ]

# if image mode, put all Plain texts into a new picture.
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list):
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
Expand Down
8 changes: 4 additions & 4 deletions model/platform/qq_nakuru.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ async def handle_msg(self, message: AstrBotMessage):

# 解析 role
sender_id = str(message.raw_message.user_id)
if sender_id == self.context.config_helper.get('admin_qq', '') or \
sender_id in self.context.config_helper.get('other_admins', []):
if sender_id == self.context.base_config.get('admin_qq', '') or \
sender_id in self.context.base_config.get('other_admins', []):
role = 'admin'
else:
role = 'member'
Expand Down Expand Up @@ -152,7 +152,7 @@ async def reply_msg(self,
res = [Plain(text=res), ]

# if image mode, put all Plain texts into a new picture.
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list):
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
Expand Down Expand Up @@ -186,7 +186,7 @@ async def _reply(self, source, message_chain: List[BaseMessageComponent]):
plain_text_len += len(i.text)
elif isinstance(i, Image):
image_num += 1
if plain_text_len > self.context.config_helper.get('qq_forward_threshold', 200):
if plain_text_len > self.context.base_config.get('qq_forward_threshold', 200):
# 删除At
for i in message_chain:
if isinstance(i, At):
Expand Down
6 changes: 3 additions & 3 deletions model/platform/qq_official.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ async def handle_msg(self, message: AstrBotMessage):

# 解析出 role
sender_id = message.sender.user_id
if sender_id == self.context.config_helper.get('admin_qqchan', None) or \
sender_id in self.context.config_helper.get('other_admins', None):
if sender_id == self.context.base_config.get('admin_qqchan', None) or \
sender_id in self.context.base_config.get('other_admins', None):
role = 'admin'
else:
role = 'member'
Expand Down Expand Up @@ -249,7 +249,7 @@ async def reply_msg(self,
msg_ref = None
rendered_images = []

if self.context.config_helper.get("qq_pic_mode", False) and isinstance(result_message, list):
if self.context.base_config.get("qq_pic_mode", False) and isinstance(result_message, list):
rendered_images = await self.convert_to_t2i_chain(result_message)

if isinstance(result_message, list):
Expand Down
6 changes: 3 additions & 3 deletions model/provider/openai_official.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, context: Context) -> None:

os.makedirs("data/openai", exist_ok=True)

self.cc = CmdConfig
self.context = context
self.key_data_path = "data/openai/keys.json"
self.api_keys = []
self.chosen_api_key = None
Expand All @@ -78,7 +78,7 @@ def __init__(self, context: Context) -> None:
)
self.model_configs: Dict = cfg['chatGPTConfigs']
super().set_curr_model(self.model_configs['model'])
self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None)
self.image_generator_model_configs: Dict = context.base_config.get('openai_image_generate', None)
self.session_memory: Dict[str, List] = {} # 会话记忆
self.session_memory_lock = threading.Lock()
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
Expand Down Expand Up @@ -492,7 +492,7 @@ def dump_contexts_page(self, session_id: str, size=5, page=1,):

def set_model(self, model: str):
self.model_configs['model'] = model
self.cc.put_by_dot_str("openai.chatGPTConfigs.model", model)
self.context.config_helper.put_by_dot_str("openai.chatGPTConfigs.model", model)
super().set_curr_model(model)

def get_configs(self):
Expand Down
76 changes: 75 additions & 1 deletion type/config.py
Original file line number Diff line number Diff line change
@@ -1 +1,75 @@
VERSION = '3.3.5'
VERSION = '3.3.7'

DEFAULT_CONFIG = {
"qqbot": {
"enable": False,
"appid": "",
"token": "",
},
"gocqbot": {
"enable": False,
},
"uniqueSessionMode": False,
"dump_history_interval": 10,
"limit": {
"time": 60,
"count": 30,
},
"notice": "",
"direct_message_mode": True,
"reply_prefix": "",
"baidu_aip": {
"enable": False,
"app_id": "",
"api_key": "",
"secret_key": ""
},
"openai": {
"key": [],
"api_base": "",
"chatGPTConfigs": {
"model": "gpt-4o",
"max_tokens": 6000,
"temperature": 0.9,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
},
"total_tokens_limit": 10000,
},
"qq_forward_threshold": 200,
"qq_welcome": "",
"qq_pic_mode": True,
"gocq_host": "127.0.0.1",
"gocq_http_port": 5700,
"gocq_websocket_port": 6700,
"gocq_react_group": True,
"gocq_react_guild": True,
"gocq_react_friend": True,
"gocq_react_group_increase": True,
"other_admins": [],
"CHATGPT_BASE_URL": "",
"qqbot_secret": "",
"qqofficial_enable_group_message": False,
"admin_qq": "",
"nick_qq": ["/", "!"],
"admin_qqchan": "",
"llm_env_prompt": "",
"llm_wake_prefix": "",
"default_personality_str": "",
"openai_image_generate": {
"model": "dall-e-3",
"size": "1024x1024",
"style": "vivid",
"quality": "standard",
},
"http_proxy": "",
"https_proxy": "",
"dashboard_username": "",
"dashboard_password": "",
"aiocqhttp": {
"enable": False,
"ws_reverse_host": "",
"ws_reverse_port": 0,
}
}
11 changes: 6 additions & 5 deletions type/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,22 @@ def __init__(self):

self.unique_session = False # 独立会话
self.version: str = None # 机器人版本
self.nick = None # gocq 的唤醒词
self.stat = {}
self.nick: tuple = None # gocq 的唤醒词
self.t2i_mode = False
self.web_search = False # 是否开启了网页搜索
self.reply_prefix = ""

self.metrics_uploader = None
self.updator: AstrBotUpdator = None
self.plugin_updator: PluginUpdator = None
self.metrics_uploader = None

self.plugin_command_bridge = PluginCommandBridge(self.cached_plugins)
self.image_renderer = TextToImageRenderer()
self.image_uploader = ImageUploader()
self.message_handler = None # see astrbot/message/handler.py
self.ext_tasks: List[Task] = []

# useless
self.reply_prefix = ""

def register_commands(self,
plugin_name: str,
command_name: str,
Expand Down
Loading

0 comments on commit 6170f82

Please sign in to comment.