Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

初步实现代码执行器 #232

Merged
merged 9 commits into from
Jan 9, 2025
2 changes: 1 addition & 1 deletion astrbot/core/core_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def initialize(self):
)
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)

self.plugin_manager.reload()
await self.plugin_manager.reload()
'''扫描、注册插件、实例化插件类'''

await self.provider_manager.initialize()
Expand Down
15 changes: 14 additions & 1 deletion astrbot/core/message/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class ComponentType(Enum):
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
File = "File"


class BaseMessageComponent(BaseModel):
Expand Down Expand Up @@ -415,6 +416,17 @@ class Unknown(BaseMessageComponent):
def toString(self):
return ""

class File(BaseMessageComponent):
'''
目前此消息段只适配了 Napcat。
'''
type: ComponentType = "File"
name: T.Optional[str] = "" # 名字
file: T.Optional[str] = "" # url(本地路径)

def __init__(self, name: str, file: str):
super().__init__(name=name, file=file)


ComponentTypes = {
"plain": Plain,
Expand All @@ -441,5 +453,6 @@ def toString(self):
"json": Json,
"cardimage": CardImage,
"tts": TTS,
"unknown": Unknown
"unknown": Unknown,
'file': File,
}
5 changes: 4 additions & 1 deletion astrbot/core/pipeline/process_stage/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[N
yield

# 调用提供商相关请求
if self.ctx.astrbot_config['provider_settings'].get('enable', True) and not event._has_send_oper:
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
return

if not event._has_send_oper and event.is_at_or_wake_command:
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
provider = self.ctx.plugin_manager.context.get_using_provider()
match provider.meta().type:
Expand Down
3 changes: 3 additions & 0 deletions astrbot/core/pipeline/waking_check/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ async def process(
# 如果是群聊,且第一个消息段是 At 消息,但不是 At 机器人或 At 全体成员,则不唤醒
break
is_wake = True
event.is_at_or_wake_command = True
event.is_wake = True
event.message_str = event.message_str[len(wake_prefix) :].strip()
break
Expand All @@ -60,11 +61,13 @@ async def process(
is_wake = True
event.is_wake = True
wake_prefix = ""
event.is_at_or_wake_command = True
break
# 检查是否是私聊
if event.is_private_chat():
is_wake = True
event.is_wake = True
event.is_at_or_wake_command = True
wake_prefix = ""

# 检查插件的 handler filter
Expand Down
3 changes: 2 additions & 1 deletion astrbot/core/platform/astr_message_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def __init__(self,
self.platform_meta = platform_meta
self.session_id = session_id
self.role = "member"
self.is_wake = False
self.is_wake = False # 是否通过 WakingStage
self.is_at_or_wake_command = False # 是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True)
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.name,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import os
import time
import asyncio
import logging
from typing import Awaitable, Any
from aiocqhttp import CQHttp, Event
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from .aiocqhttp_message_event import *
from astrbot.api.message_components import *
from .aiocqhttp_message_event import * # noqa: F403
from astrbot.api.message_components import * # noqa: F403
from astrbot.api import logger
from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed

@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
class AiocqhttpAdapter(Platform):
Expand Down Expand Up @@ -42,7 +44,7 @@ async def send_by_session(self, session: MessageSesion, message_chain: MessageCh
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
await super().send_by_session(session, message_chain)

def convert_message(self, event: Event) -> AstrBotMessage:
async def convert_message(self, event: Event) -> AstrBotMessage:
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.tag = "aiocqhttp"
Expand Down Expand Up @@ -78,7 +80,25 @@ def convert_message(self, event: Event) -> AstrBotMessage:
a = None
if t == 'text':
message_str += m['data']['text'].strip()
a = ComponentTypes[t](**m['data'])
elif t == 'file':
try:
# Napcat, LLBot
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")

m['data'] = {
"file": ret['file'],
"name": ret['file_name']
}
except ActionFailed as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")

a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
abm.timestamp = int(time.time())
abm.message_str = message_str
Expand All @@ -91,13 +111,13 @@ def run(self) -> Awaitable[Any]:
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_message('group')
async def group(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)

@self.bot.on_message('private')
async def private(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ async def send(self, message: MessageChain):
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
case botpy.message.C2CMessage:
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
case botpy.message.Message:
if image_path:
Expand Down Expand Up @@ -73,9 +75,9 @@ async def _parse_to_qqofficial(message: MessageChain):
plain_text += i.text
elif isinstance(i, Image) and not image_base64:
if i.file and i.file.startswith("file:///"):
image_base64 = file_to_base64(i.file[8:])
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "")
image_file_path = i.file[8:]
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path)
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
return plain_text, image_base64, image_file_path
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import uuid
import asyncio
import os

from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
Expand Down Expand Up @@ -62,7 +63,7 @@ async def _(msg: model.Message):
self.start_time = int(time.time())
return self._run()


async def _run(self):
await self.client.init()
await self.client.auto_login(hot_reload=True, enable_cmd_qr=True)
Expand Down
1 change: 0 additions & 1 deletion astrbot/core/provider/sources/dify_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import base64
from typing import List
from .. import Provider
from ..entites import LLMResponse
Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/star/register/star_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def decorator(awaitable: Awaitable):
"description": arg.description
})
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(llm_tool_name, args, docstring.short_description, md.handler)
llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler)

logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
return awaitable
Expand Down
10 changes: 7 additions & 3 deletions astrbot/core/star/star_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMeta

return metadata

def reload(self):
async def reload(self):
'''扫描并加载所有的 Star'''
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
Expand Down Expand Up @@ -231,6 +231,10 @@ def reload(self):
if metadata.module_path in inactivated_plugins:
metadata.activated = False

# 执行 initialize 函数
if hasattr(metadata.star_cls, "initialize"):
await metadata.star_cls.initialize()

except BaseException as e:
traceback.print_exc()
fail_rec += f"加载 {path} 插件时出现问题,原因 {str(e)}\n"
Expand All @@ -247,7 +251,7 @@ def reload(self):
async def install_plugin(self, repo_url: str):
plugin_path = await self.updator.install(repo_url)
# reload the plugin
self.reload()
await self.reload()
return plugin_path

async def uninstall_plugin(self, plugin_name: str):
Expand Down Expand Up @@ -288,7 +292,7 @@ async def update_plugin(self, plugin_name: str):
raise Exception("该插件是 AstrBot 保留插件,无法更新。")

await self.updator.update(plugin)
self.reload()
await self.reload()

async def turn_off_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
Expand Down
3 changes: 3 additions & 0 deletions astrbot/core/utils/param_validation_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def validate_and_convert_params(self, params: List[Any], param_type: Dict[str, T
result[param_name] = int(params[i])
else:
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, str):
# 如果 param_type_or_default_val 是字符串,直接赋值
result[param_name] = params[i]
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
Expand Down
Loading
Loading