From bb219889e5aa245ca727d826c14b54457b184783 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sat, 16 Nov 2024 12:40:57 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B6=88=E6=81=AF=E5=B9=B3=E5=8F=B0?= =?UTF-8?q?=E7=83=AD=E9=87=8D=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 15 +++--- pkg/api/http/controller/groups/system.py | 22 +++++++++ pkg/api/http/controller/main.py | 15 +++++- pkg/audit/center/apigroup.py | 6 +-- pkg/core/app.py | 59 ++++++++++++++++++++---- pkg/core/bootutils/deps.py | 2 +- pkg/core/entities.py | 8 ++++ pkg/core/taskmgr.py | 22 +++++++-- pkg/pipeline/controller.py | 5 +- pkg/platform/manager.py | 8 ++-- pkg/platform/sources/qqbotpy.py | 8 +++- pkg/utils/ip.py | 9 ++++ requirements.txt | 2 +- web/src/App.vue | 25 ++++++++++ 14 files changed, 169 insertions(+), 37 deletions(-) create mode 100644 pkg/utils/ip.py diff --git a/main.py b/main.py index baf339c6..6164bf5b 100644 --- a/main.py +++ b/main.py @@ -3,13 +3,14 @@ # QChatGPT/main.py asciiart = r""" - ___ ___ _ _ ___ ___ _____ - / _ \ / __| |_ __ _| |_ / __| _ \_ _| -| (_) | (__| ' \/ _` | _| (_ | _/ | | - \__\_\\___|_||_\__,_|\__|\___|_| |_| - -⭐️开源地址: https://github.com/RockChinQ/QChatGPT -📖文档地址: https://q.rkcn.top + _ ___ _ +| | __ _ _ _ __ _| _ ) ___| |_ +| |__/ _` | ' \/ _` | _ \/ _ \ _| +|____\__,_|_||_\__, |___/\___/\__| + |___/ + +⭐️开源地址: https://github.com/RockChinQ/LangBot +📖文档地址: https://docs.langbot.app """ diff --git a/pkg/api/http/controller/groups/system.py b/pkg/api/http/controller/groups/system.py index a967d6b1..3b9c57fa 100644 --- a/pkg/api/http/controller/groups/system.py +++ b/pkg/api/http/controller/groups/system.py @@ -39,3 +39,25 @@ async def _(task_id: str) -> str: return self.http_status(404, 404, "Task not found") return self.success(data=task.to_dict()) + + @self.route('/reload', methods=['POST']) + async def _() -> str: + json_data = await quart.request.json + + scope = json_data.get("scope") + + await self.ap.reload( + scope=scope + ) + return self.success() + + @self.route('/_debug/exec', methods=['POST']) + async def _() -> str: + if not constants.debug_mode: + return self.http_status(403, 403, "Forbidden") + + py_code = await quart.request.data + + ap = self.ap + + return self.success(data=exec(py_code, {"ap": ap})) diff --git a/pkg/api/http/controller/main.py b/pkg/api/http/controller/main.py index d91f9afe..8befea43 100644 --- a/pkg/api/http/controller/main.py +++ b/pkg/api/http/controller/main.py @@ -6,7 +6,7 @@ import quart import quart_cors -from ....core import app +from ....core import app, entities as core_entities from .groups import logs, system, settings, plugins, stats from . import group @@ -32,15 +32,26 @@ async def shutdown_trigger_placeholder(): while True: await asyncio.sleep(1) + async def exception_handler(*args, **kwargs): + try: + await self.quart_app.run_task( + *args, **kwargs + ) + except Exception as e: + self.ap.logger.error(f"启动 HTTP 服务失败: {e}") + self.ap.task_mgr.create_task( - self.quart_app.run_task( + exception_handler( host=self.ap.system_cfg.data["http-api"]["host"], port=self.ap.system_cfg.data["http-api"]["port"], shutdown_trigger=shutdown_trigger_placeholder, ), name="http-api-quart", + scopes=[core_entities.LifecycleControlScope.APPLICATION], ) + # await asyncio.sleep(5) + async def register_routes(self) -> None: @self.quart_app.route("/healthz") diff --git a/pkg/audit/center/apigroup.py b/pkg/audit/center/apigroup.py index 3e3c5eb5..4b20a09a 100644 --- a/pkg/audit/center/apigroup.py +++ b/pkg/audit/center/apigroup.py @@ -9,7 +9,7 @@ import aiohttp import requests -from ...core import app +from ...core import app, entities as core_entities class APIGroup(metaclass=abc.ABCMeta): @@ -65,14 +65,12 @@ async def do( **kwargs, ) -> asyncio.Task: """执行请求""" - # task = asyncio.create_task(self._do(method, path, data, params, headers, **kwargs)) - - # self.ap.asyncio_tasks.append(task) return self.ap.task_mgr.create_task( self._do(method, path, data, params, headers, **kwargs), kind="telemetry-operation", name=f"{method} {path}", + scopes=[core_entities.LifecycleControlScope.APPLICATION], ).task def gen_rid(self): diff --git a/pkg/core/app.py b/pkg/core/app.py index 2d8afeb6..ead769f1 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -4,6 +4,7 @@ import asyncio import threading import traceback +import enum from ..platform import manager as im_mgr from ..provider.session import sessionmgr as llm_session_mgr @@ -21,8 +22,9 @@ from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr from ..persistence import mgr as persistencemgr from ..api.http.controller import main as http_controller -from ..utils import logcache +from ..utils import logcache, ip from . import taskmgr +from . import entities as core_entities class Application: @@ -114,11 +116,12 @@ async def never_ending(): while True: await asyncio.sleep(1) - self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager") - self.task_mgr.create_task(self.ctrl.run(), name="query-controller") - self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller") - self.task_mgr.create_task(never_ending(), name="never-ending-task") + self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM]) + self.task_mgr.create_task(self.ctrl.run(), name="query-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION]) + self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION]) + self.task_mgr.create_task(never_ending(), name="never-ending-task", scopes=[core_entities.LifecycleControlScope.APPLICATION]) + await self.print_web_access_info() await self.task_mgr.wait_all() except asyncio.CancelledError: pass @@ -126,9 +129,45 @@ async def never_ending(): self.logger.error(f"应用运行致命异常: {e}") self.logger.debug(f"Traceback: {traceback.format_exc()}") - async def scoped_shutdown(self, scopes: list[str]): - pass + async def print_web_access_info(self): + """打印访问 webui 的提示""" + import socket + + host_ip = socket.gethostbyname(socket.gethostname()) + + public_ip = await ip.get_myip() + + port = self.system_cfg.data['http-api']['port'] + + tips = f""" +======================================= +✨ 您可通过以下方式访问管理面板: + +🏠 本地地址:http://{host_ip}:{port}/ +🌐 公网地址:http://{public_ip}:{port}/ + +📌 如果您在容器中运行此程序,请确保容器的 {port} 端口已对外暴露 +🔗 若要使用公网地址访问,请阅读以下须知 + 1. 公网地址仅供参考,请以您的主机公网 IP 为准; + 2. 要使用公网地址访问,请确保您的主机具有公网 IP,并且系统防火墙已放行 {port} 端口; +======================================= +""".strip() + for line in tips.split("\n"): + self.logger.info(line) + + async def reload( + self, + scope: core_entities.LifecycleControlScope, + ): + match scope: + case core_entities.LifecycleControlScope.PLATFORM.value: + self.logger.info("执行热重载 scope="+scope) + await self.platform_mgr.shutdown() + + self.platform_mgr = im_mgr.PlatformManager(self) + + await self.platform_mgr.initialize() - async def shutdown(self): - for task in self.task_mgr.tasks: - task.cancel() + self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM]) + case _: + pass diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py index 3dd18c58..56938b08 100644 --- a/pkg/core/bootutils/deps.py +++ b/pkg/core/bootutils/deps.py @@ -6,7 +6,7 @@ "anthropic": "anthropic", "colorlog": "colorlog", "aiocqhttp": "aiocqhttp", - "botpy": "qq-botpy", + "botpy": "qq-botpy-rc", "PIL": "pillow", "nakuru": "nakuru-project-idk", "tiktoken": "tiktoken", diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 67b05666..464384be 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -17,6 +17,14 @@ from ..platform.types import entities as platform_entities + +class LifecycleControlScope(enum.Enum): + + APPLICATION = "application" + PLATFORM = "platform" + PLUGIN = "plugin" + + class LauncherTypes(enum.Enum): """一个请求的发起者类型""" diff --git a/pkg/core/taskmgr.py b/pkg/core/taskmgr.py index 210b2ab6..2c029c03 100644 --- a/pkg/core/taskmgr.py +++ b/pkg/core/taskmgr.py @@ -6,6 +6,7 @@ import traceback from . import app +from . import entities as core_entities class TaskContext: @@ -71,7 +72,7 @@ class TaskWrapper: task_type: str = "system" # 任务类型: system 或 user """任务类型""" - kind: str = "system_task" + kind: str = "system_task" # 由发起者确定任务种类,通常同质化的任务种类相同 """任务种类""" name: str = "" @@ -92,6 +93,9 @@ class TaskWrapper: ap: app.Application """应用实例""" + scopes: list[core_entities.LifecycleControlScope] + """任务所属生命周期控制范围""" + def __init__( self, ap: app.Application, @@ -101,6 +105,7 @@ def __init__( name: str = "", label: str = "", context: TaskContext = None, + scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION], ): self.id = TaskWrapper._id_index TaskWrapper._id_index += 1 @@ -112,6 +117,7 @@ def __init__( self.name = name self.label = label if label != "" else name self.task.set_name(name) + self.scopes = scopes def assume_exception(self): try: @@ -145,6 +151,7 @@ def to_dict(self) -> dict: "kind": self.kind, "name": self.name, "label": self.label, + "scopes": [scope.value for scope in self.scopes], "task_context": self.task_context.to_dict(), "runtime": { "done": self.task.done(), @@ -180,8 +187,9 @@ def create_task( name: str = "", label: str = "", context: TaskContext = None, + scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION], ) -> TaskWrapper: - wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context) + wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes) self.tasks.append(wrapper) return wrapper @@ -192,8 +200,9 @@ def create_user_task( name: str = "", label: str = "", context: TaskContext = None, + scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION], ) -> TaskWrapper: - return self.create_task(coro, "user", kind, name, label, context) + return self.create_task(coro, "user", kind, name, label, context, scopes) async def wait_all(self): await asyncio.gather(*[t.task for t in self.tasks], return_exceptions=True) @@ -217,3 +226,10 @@ def get_task_by_id(self, id: int) -> TaskWrapper | None: if t.id == id: return t return None + + def cancel_by_scope(self, scope: core_entities.LifecycleControlScope): + for wrapper in self.tasks: + + if not wrapper.task.done() and scope in wrapper.scopes: + + wrapper.task.cancel() diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index c5598a08..92ba8173 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -4,7 +4,6 @@ import typing import traceback - from ..core import app, entities from . import entities as pipeline_entities from ..plugin import events @@ -59,13 +58,11 @@ async def _process_query(selected_query): (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() # 通知其他协程,有新的请求可以处理了 self.ap.query_pool.condition.notify_all() - - # task = asyncio.create_task(_process_query(selected_query)) - # self.ap.asyncio_tasks.append(task) self.ap.task_mgr.create_task( _process_query(selected_query), kind="query", name=f"query-{selected_query.query_id}", + scopes=[entities.LifecycleControlScope.APPLICATION, entities.LifecycleControlScope.PLATFORM], ) except Exception as e: diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 2e60a0cd..b33c1e55 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -174,22 +174,23 @@ async def run(self): try: tasks = [] for adapter in self.adapters: - async def exception_wrapper(adapter): + async def exception_wrapper(adapter: msadapter.MessageSourceAdapter): try: await adapter.run_async() except Exception as e: + if isinstance(e, asyncio.CancelledError): + return self.ap.logger.error('平台适配器运行出错: ' + str(e)) self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") tasks.append(exception_wrapper(adapter)) for task in tasks: - # async_task = asyncio.create_task(task) - # self.ap.asyncio_tasks.append(async_task) self.ap.task_mgr.create_task( task, kind="platform-adapter", name=f"platform-adapter-{adapter.name}", + scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM], ) except Exception as e: @@ -199,3 +200,4 @@ async def exception_wrapper(adapter): async def shutdown(self): for adapter in self.adapters: await adapter.kill() + self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM) \ No newline at end of file diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/qqbotpy.py index b91377a1..2923e770 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/qqbotpy.py @@ -588,8 +588,12 @@ async def run_async(self): self.member_openid_mapping, self.group_openid_mapping ) + self.cfg['ret_coro'] = True + self.ap.logger.info("运行 QQ 官方适配器") - await self.bot.start(**self.cfg) + await (await self.bot.start(**self.cfg)) async def kill(self) -> bool: - return False + if not self.bot.is_closed(): + await self.bot.close() + return True diff --git a/pkg/utils/ip.py b/pkg/utils/ip.py new file mode 100644 index 00000000..4f54bad2 --- /dev/null +++ b/pkg/utils/ip.py @@ -0,0 +1,9 @@ +import aiohttp + +async def get_myip() -> str: + try: + async with aiohttp.ClientSession() as session: + async with session.get("https://ip.useragentinfo.com/myip") as response: + return await response.text() + except Exception as e: + return '0.0.0.0' \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index cd55555d..7eaec08a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ openai>1.0.0 anthropic colorlog~=6.6.0 aiocqhttp -qq-botpy +qq-botpy-rc nakuru-project-idk Pillow tiktoken diff --git a/web/src/App.vue b/web/src/App.vue index 00424e70..9e6ff8ca 100644 --- a/web/src/App.vue +++ b/web/src/App.vue @@ -67,6 +67,12 @@ + + + 重载消息平台 + + + @@ -137,6 +143,25 @@ function openDocs() { window.open('https://docs.langbot.app', '_blank') } +function reload(scope) { + proxy.$axios.post('/system/reload', + { scope: scope }, + { headers: { 'Content-Type': 'application/json' } } + ).then(response => { + if (response.data.code === 0) { + success('消息平台已重载') + + // 关闭菜单 + } else { + error('消息平台重载失败:' + response.data.message) + } + }).catch(error => { + console.error(error) + error('消息平台重载失败:' + error) + }) + +} + const aboutDialogShow = ref(false) function showAboutDialog() {