diff --git a/.gitignore b/.gitignore index c8381da506..02f08b760f 100644 --- a/.gitignore +++ b/.gitignore @@ -10,8 +10,6 @@ test* cache/* assets/arc.apk assets/arcaea -assets/cytoid-avatar/ -assets/cytoid-cover/ assets/maimai* !assets/maimai/mai_utage_info.json !assets/maimai/mai_grade_info.json diff --git a/bot.py b/bot.py index 1d6c0ee182..4d1dfc1b54 100644 --- a/bot.py +++ b/bot.py @@ -7,16 +7,11 @@ from datetime import datetime from time import sleep -import core.scripts.config_generate # noqa -from core.config import Config, CFGManager -from core.constants.default import base_superuser_default -from core.constants.path import cache_path -from core.database import BotDBUtil, session, DBVersion -from core.logger import Logger -from core.utils.info import Info +from loguru import logger as loggerFallback + ascii_art = r''' - ._. _ .____ ._. + _ _ ____ _ /\ | | (_) | _ \ | | / \ | | ____ _ _ __ _ | |_) | ___ | |_ / /\ \ | |/ / _` | '__| | | _ < / _ \| __| @@ -49,9 +44,29 @@ class RestartBot(Exception): failed_to_start_attempts = {} +disabled_bots = [] +processes = [] def init_bot(): + import core.scripts.config_generate # noqa + from core.config import Config, CFGManager # noqa + from core.constants.default import base_superuser_default # noqa + from core.database import BotDBUtil, session, DBVersion # noqa + from core.logger import Logger # noqa + + query_dbver = session.query(DBVersion).first() + if not query_dbver: + session.add_all([DBVersion(value=str(BotDBUtil.database_version))]) + session.commit() + query_dbver = session.query(DBVersion).first() + if (current_ver := int(query_dbver.value)) < (target_ver := BotDBUtil.database_version): + Logger.info(f'Updating database from {current_ver} to {target_ver}...') + from core.database.update import update_database + + update_database() + Logger.info('Database updated successfully!') + print(ascii_art) base_superuser = Config('base_superuser', base_superuser_default, cfg_type=(str, list)) if base_superuser: if isinstance(base_superuser, str): @@ -59,10 +74,35 @@ def init_bot(): for bu in base_superuser: BotDBUtil.SenderInfo(bu).init() BotDBUtil.SenderInfo(bu).edit('isSuperUser', True) - print(ascii_art) + else: + Logger.warning("The base superuser was not found, please setup it in the config file.") + + disabled_bots.clear() + for t in CFGManager.values: + if t.startswith('bot_') and not t.endswith('_secret'): + if 'enable' in CFGManager.values[t][t]: + if not CFGManager.values[t][t]['enable']: + disabled_bots.append(t[4:]) + + +def multiprocess_run_until_complete(func): + p = multiprocessing.Process( + target=func,) + p.start() + + while True: + if not p.is_alive(): + break + sleep(1) + p.terminate() + p.join() + p.close() def go(bot_name: str = None, subprocess: bool = False, binary_mode: bool = False): + from core.logger import Logger # noqa + from core.utils.info import Info # noqa + Logger.info(f"[{bot_name}] Here we go!") Info.subprocess = subprocess Info.binary_mode = binary_mode @@ -76,41 +116,34 @@ def go(bot_name: str = None, subprocess: bool = False, binary_mode: bool = False sys.exit(1) -disabled_bots = [] -processes = [] - -for t in CFGManager.values: - if t.startswith('bot_') and not t.endswith('_secret'): - if 'enable' in CFGManager.values[t][t]: - if not CFGManager.values[t][t]['enable']: - disabled_bots.append(t[4:]) - - -def restart_process(bot_name: str): - if bot_name not in failed_to_start_attempts or datetime.now( - ).timestamp() - failed_to_start_attempts[bot_name]['timestamp'] > 60: - failed_to_start_attempts[bot_name] = {} - failed_to_start_attempts[bot_name]['count'] = 0 +def run_bot(): + from core.constants.path import cache_path # noqa + from core.config import Config # noqa + from core.logger import Logger # noqa + + def restart_process(bot_name: str): + if bot_name not in failed_to_start_attempts or datetime.now( + ).timestamp() - failed_to_start_attempts[bot_name]['timestamp'] > 60: + failed_to_start_attempts[bot_name] = {} + failed_to_start_attempts[bot_name]['count'] = 0 + failed_to_start_attempts[bot_name]['timestamp'] = datetime.now().timestamp() + failed_to_start_attempts[bot_name]['count'] += 1 failed_to_start_attempts[bot_name]['timestamp'] = datetime.now().timestamp() - failed_to_start_attempts[bot_name]['count'] += 1 - failed_to_start_attempts[bot_name]['timestamp'] = datetime.now().timestamp() - if failed_to_start_attempts[bot_name]['count'] >= 3: - Logger.error(f'Bot {bot_name} failed to start 3 times, abort to restart, please check the log.') - return - - Logger.warning(f'Restarting bot {bot_name}...') - p = multiprocessing.Process( - target=go, - args=( - bot_name, - True, - bool(not sys.argv[0].endswith('.py'))), - name=bot_name) - p.start() - processes.append(p) + if failed_to_start_attempts[bot_name]['count'] >= 3: + Logger.error(f'Bot {bot_name} failed to start 3 times, abort to restart, please check the log.') + return + Logger.warning(f'Restarting bot {bot_name}...') + p = multiprocessing.Process( + target=go, + args=( + bot_name, + True, + bool(not sys.argv[0].endswith('.py'))), + name=bot_name) + p.start() + processes.append(p) -def run_bot(): if os.path.exists(cache_path): shutil.rmtree(cache_path) os.makedirs(cache_path, exist_ok=True) @@ -153,45 +186,37 @@ def run_bot(): processes.remove(p) p.terminate() p.join() + p.close() restart_process(p.name) break if not processes: break sleep(1) + Logger.critical('All bots exited unexpectedly, please check the output.') if __name__ == '__main__': - query_dbver = session.query(DBVersion).first() - if not query_dbver: - session.add_all([DBVersion(value=str(BotDBUtil.database_version))]) - session.commit() - query_dbver = session.query(DBVersion).first() - if (current_ver := int(query_dbver.value)) < (target_ver := BotDBUtil.database_version): - Logger.info(f'Updating database from {current_ver} to {target_ver}...') - from core.database.update import update_database - - update_database() - Logger.info('Database updated successfully!') - init_bot() try: while True: try: + multiprocess_run_until_complete(init_bot) run_bot() # Process will block here so - Logger.critical('All bots exited unexpectedly, please check the output.') break except RestartBot: for ps in processes: ps.terminate() ps.join() + ps.close() processes.clear() continue except Exception: - Logger.critical('An error occurred, please check the output.') + loggerFallback.critical('An error occurred, please check the output.') traceback.print_exc() break except (KeyboardInterrupt, SystemExit): for ps in processes: ps.terminate() ps.join() + ps.close() processes.clear() diff --git a/bots/qqbot/message.py b/bots/qqbot/message.py index 1635be9969..b583cf9ee1 100644 --- a/bots/qqbot/message.py +++ b/bots/qqbot/message.py @@ -22,13 +22,16 @@ class FinishedSession(FinishedSessionT): async def delete(self): - if self.session.target.target_from == target_guild_prefix: - try: - from bots.qqbot.bot import client # noqa + try: + from bots.qqbot.bot import client # noqa + if self.session.target.target_from == target_guild_prefix: for x in self.message_id: await client.api.recall_message(channel_id=self.session.target.target_id.split('|')[-1], message_id=x, hidetip=True) - except Exception: - Logger.error(traceback.format_exc()) + elif self.session.target.target_from == target_group_prefix: + for x in self.message_id: + await client.api.recall_group_message(group_openid=self.session.target.target_id.split('|')[-1], message_id=x) + except Exception: + Logger.error(traceback.format_exc()) class MessageSession(MessageSessionT): @@ -226,6 +229,8 @@ async def delete(self): except Exception: Logger.error(traceback.format_exc()) return False + else: + return False sendMessage = send_message asDisplay = as_display @@ -253,7 +258,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): class FetchedSession(Bot.FetchedSession): async def send_direct_message(self, message_chain, disable_secret_check=False, enable_parse_message=True, enable_split_image=True): - from bots.qqbot.bot import client + from bots.qqbot.bot import client # noqa if self.target.target_from == target_guild_prefix: self.session.message = Message(api=client.api, event_id=None, data={ "channel_id": self.target.target_id.split('|')[-1]}) diff --git a/core/builtins/__init__.py b/core/builtins/__init__.py index 1391cac4d0..f4f0f1dd70 100644 --- a/core/builtins/__init__.py +++ b/core/builtins/__init__.py @@ -23,6 +23,7 @@ class Bot: ModuleHookContext = ModuleHookContext ExecutionLockList = ExecutionLockList Info = Info + Temp = Temp @staticmethod async def send_message(target: Union[FetchedSession, MessageSession, str], @@ -37,7 +38,10 @@ async def send_message(target: Union[FetchedSession, MessageSession, str], if isinstance(msg, list): msg = MessageChain(msg) Logger.info(target.__dict__) - await target.send_direct_message(msg, disable_secret_check, enable_split_image) + await target.send_direct_message(message_chain=msg, + disable_secret_check=disable_secret_check, + enable_parse_message=enable_parse_message, + enable_split_image=enable_split_image) @staticmethod async def fetch_target(target: str): diff --git a/core/builtins/message/__init__.py b/core/builtins/message/__init__.py index 902ba601e6..290128882e 100644 --- a/core/builtins/message/__init__.py +++ b/core/builtins/message/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from datetime import datetime, UTC as datetimeUTC from typing import Any, Coroutine, Dict, List, Optional, Union @@ -16,21 +18,24 @@ class ExecutionLockList: + """ + 执行锁。 + """ _list = set() @staticmethod - def add(msg: 'MessageSession'): + def add(msg: MessageSession): target_id = msg.target.sender_id ExecutionLockList._list.add(target_id) @staticmethod - def remove(msg: 'MessageSession'): + def remove(msg: MessageSession): target_id = msg.target.sender_id if target_id in ExecutionLockList._list: ExecutionLockList._list.remove(target_id) @staticmethod - def check(msg: 'MessageSession'): + def check(msg: MessageSession): target_id = msg.target.sender_id return target_id in ExecutionLockList._list @@ -40,66 +45,75 @@ def get(): class MessageTaskManager: - _list = {} + """ + 消息计划管理器。 + """ + _task_list = {} _callback_list = {} @classmethod - def add_task(cls, session: 'MessageSession', flag, all_=False, reply=None, timeout=120): + def add_task( + cls, + session: MessageSession, + flag: asyncio.Event, + all_: bool = False, + reply: Optional[Union[List[int], List[str], int, str]] = None, + timeout: Optional[float] = 120): sender = session.target.sender_id task_type = 'reply' if reply else 'wait' if all_: sender = 'all' - if session.target.target_id not in cls._list: - cls._list[session.target.target_id] = {} - if sender not in cls._list[session.target.target_id]: - cls._list[session.target.target_id][sender] = {} - cls._list[session.target.target_id][sender][session] = { + if session.target.target_id not in cls._task_list: + cls._task_list[session.target.target_id] = {} + if sender not in cls._task_list[session.target.target_id]: + cls._task_list[session.target.target_id][sender] = {} + cls._task_list[session.target.target_id][sender][session] = { 'flag': flag, 'active': True, 'type': task_type, 'reply': reply, 'ts': datetime.now().timestamp(), 'timeout': timeout} - Logger.debug(cls._list) + Logger.debug(cls._task_list) @classmethod - def add_callback(cls, message_id, callback): + def add_callback(cls, message_id: Union[List[int], List[str], int, str], callback: Optional[Coroutine]): cls._callback_list[message_id] = {'callback': callback, 'ts': datetime.now().timestamp()} @classmethod - def get_result(cls, session: 'MessageSession'): - if 'result' in cls._list[session.target.target_id][session.target.sender_id][session]: - return cls._list[session.target.target_id][session.target.sender_id][session]['result'] + def get_result(cls, session: MessageSession): + if 'result' in cls._task_list[session.target.target_id][session.target.sender_id][session]: + return cls._task_list[session.target.target_id][session.target.sender_id][session]['result'] else: return None @classmethod def get(cls): - return cls._list + return cls._task_list @classmethod async def bg_check(cls): - for target in cls._list: - for sender in cls._list[target]: - for session in cls._list[target][sender]: - if cls._list[target][sender][session]['active']: - if (datetime.now().timestamp() - cls._list[target][sender][session]['ts'] > - cls._list[target][sender][session].get('timeout', 3600)): - cls._list[target][sender][session]['active'] = False - cls._list[target][sender][session]['flag'].set() # no result = cancel + for target in cls._task_list: + for sender in cls._task_list[target]: + for session in cls._task_list[target][sender]: + if cls._task_list[target][sender][session]['active']: + if (datetime.now().timestamp() - cls._task_list[target][sender][session]['ts'] > + cls._task_list[target][sender][session].get('timeout', 3600)): + cls._task_list[target][sender][session]['active'] = False + cls._task_list[target][sender][session]['flag'].set() # no result = cancel for message_id in cls._callback_list.copy(): if datetime.now().timestamp() - cls._callback_list[message_id]['ts'] > 3600: del cls._callback_list[message_id] @classmethod - async def check(cls, session: 'MessageSession'): - if session.target.target_id in cls._list: + async def check(cls, session: MessageSession): + if session.target.target_id in cls._task_list: senders = [] - if session.target.sender_id in cls._list[session.target.target_id]: + if session.target.sender_id in cls._task_list[session.target.target_id]: senders.append(session.target.sender_id) - if 'all' in cls._list[session.target.target_id]: + if 'all' in cls._task_list[session.target.target_id]: senders.append('all') if senders: for sender in senders: - for s in cls._list[session.target.target_id][sender]: - get_ = cls._list[session.target.target_id][sender][s] + for s in cls._task_list[session.target.target_id][sender]: + get_ = cls._task_list[session.target.target_id][sender][s] if get_['type'] == 'wait': get_['result'] = session get_['active'] = False @@ -122,6 +136,9 @@ async def check(cls, session: 'MessageSession'): class FinishedSession: + """ + 结束会话。 + """ def __init__(self, session, message_id: Union[List[int], List[str], int, str], result): self.session = session if isinstance(message_id, (int, str)): @@ -140,14 +157,18 @@ def __str__(self): class MessageSession: + """ + 消息会话。 + """ + def __init__(self, target: MsgInfo, session: Session): self.target = target self.session = session self.sent: List[MessageChain] = [] - self.parsed_msg: Optional[dict] = None self.trigger_msg: Optional[str] = None + self.parsed_msg: Optional[dict] = None self.prefixes: List[str] = [] self.data = exports.get("BotDBUtil").TargetInfo(self.target.target_id) self.info = exports.get("BotDBUtil").SenderInfo(self.target.sender_id) @@ -184,7 +205,7 @@ async def send_message(self, raise NotImplementedError async def finish(self, - message_chain: Union[MessageChain, str, list, MessageElement] = None, + message_chain: Optional[Union[MessageChain, str, list, MessageElement]] = None, quote: bool = True, disable_secret_check: bool = False, enable_parse_message: bool = True, @@ -193,7 +214,7 @@ async def finish(self, """ 用于向消息发送者返回消息并终结会话(模块后续代码不再执行)。 - :param message_chain: 消息链,若传入str则自动创建一条带有Plain元素的消息链。 + :param message_chain: 消息链,若传入str则自动创建一条带有Plain元素的消息链,可不填。 :param quote: 是否引用传入dict中的消息。(默认为True) :param disable_secret_check: 是否禁用消息检查。(默认为False) :param enable_parse_message: 是否允许解析消息。(此参数作接口兼容用,仅QQ平台使用,默认为True) @@ -259,6 +280,8 @@ async def check_native_permission(self) -> bool: async def fake_forward_msg(self, nodelist: List[Dict[str, Union[str, Any]]]): """ 用于发送假转发消息(QQ)。 + + :param nodelist: 消息段列表,即`type`键名为`node`的字典列表,详情参考OneBot文档。 """ raise NotImplementedError @@ -269,7 +292,7 @@ async def get_text_channel_list(self) -> List[str]: raise NotImplementedError class Typing: - def __init__(self, msg: 'MessageSession'): + def __init__(self, msg: MessageSession): """ :param msg: 本条消息,由于此class需要被一同传入下游方便调用,所以作为子class存在,将来可能会有其他的解决办法。 """ @@ -287,7 +310,7 @@ async def wait_confirm(self, message_chain: Optional[Union[MessageChain, str, list, MessageElement]] = None, quote: bool = True, delete: bool = True, - timeout: int = 120, + timeout: Optional[float] = 120, append_instruction: bool = True) -> bool: """ 一次性模板,用于等待触发对象确认。 @@ -330,8 +353,8 @@ async def wait_next_message(self, message_chain: Optional[Union[MessageChain, str, list, MessageElement]] = None, quote: bool = True, delete: bool = False, - timeout: int = 120, - append_instruction: bool = True) -> 'MessageSession': + timeout: Optional[float] = 120, + append_instruction: bool = True) -> MessageSession: """ 一次性模板,用于等待对象的下一条消息。 @@ -369,9 +392,9 @@ async def wait_reply(self, message_chain: Union[MessageChain, str, list, MessageElement], quote: bool = True, delete: bool = False, - timeout: int = 120, + timeout: Optional[float] = 120, all_: bool = False, - append_instruction: bool = True) -> 'MessageSession': + append_instruction: bool = True) -> MessageSession: """ 一次性模板,用于等待触发对象回复消息。 @@ -410,7 +433,7 @@ async def wait_anyone(self, message_chain: Optional[Union[MessageChain, str, list, MessageElement]] = None, quote: bool = False, delete: bool = False, - timeout: int = 120) -> 'MessageSession': + timeout: Optional[float] = 120) -> MessageSession: """ 一次性模板,用于等待触发对象所属对话内任意成员确认。 @@ -529,6 +552,10 @@ class Feature: class FetchedSession: + """ + 获取消息会话。 + """ + def __init__(self, target_from: str, target_id: Union[str, int], @@ -568,6 +595,9 @@ async def send_direct_message(self, class FetchTarget: + """ + 获取消息会话对象。 + """ name = '' @staticmethod diff --git a/core/builtins/message/chain.py b/core/builtins/message/chain.py index 5a1d00f56e..00be41f803 100644 --- a/core/builtins/message/chain.py +++ b/core/builtins/message/chain.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import re from typing import List, Optional, Tuple, Union, Any @@ -32,7 +34,7 @@ def __init__( List[MessageElement], Tuple[MessageElement], MessageElement, - 'MessageChain' + MessageChain ]] = None, ): """ @@ -146,7 +148,7 @@ def unsafeprompt(name, secret, text): return False return True - def as_sendable(self, msg: 'MessageSession' = None, embed: bool = True) -> list: + def as_sendable(self, msg: MessageSession = None, embed: bool = True) -> list: """ 将消息链转换为可发送的格式。 """ diff --git a/core/builtins/message/elements.py b/core/builtins/message/elements.py index a2df406b5b..56c70a5b3d 100644 --- a/core/builtins/message/elements.py +++ b/core/builtins/message/elements.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import base64 import os import random import re from datetime import datetime, timezone -from typing import Tuple, Optional, TYPE_CHECKING, Dict, Any, Union, List +from typing import Optional, TYPE_CHECKING, Dict, Any, Union, List from urllib import parse import aiohttp @@ -19,6 +21,8 @@ from core.utils.cache import random_cache_path from core.utils.i18n import Locale +from copy import deepcopy + if TYPE_CHECKING: from core.builtins import MessageSession @@ -41,20 +45,16 @@ class PlainElement(MessageElement): @classmethod def assign(cls, - *texts: Tuple[str], - disable_joke: bool = False, - comment: Optional[str] = None): + *texts: Any, + disable_joke: bool = False): """ :param texts: 文本内容。 :param disable_joke: 是否禁用玩笑功能。(默认为False) - :param comment: 注释文本,不受玩笑功能影响。 """ text = ''.join([str(x) for x in texts]) if not disable_joke: text = joke(text) - if comment: - text += '\n' + comment - return cls(text=text) + return deepcopy(cls(text=text)) @define @@ -80,7 +80,7 @@ def assign(cls, url: str, use_mm: bool = False): "nopqrstuvwxyzabcdefghijklmNOPQRSTUVWXYZABCDEFGHIJKLM") url = mm_url % parse.quote(parse.unquote(url).translate(rot13)) - return cls(url=url) + return deepcopy(cls(url=url)) def __str__(self): if self.md_format: @@ -102,7 +102,7 @@ class FormattedTimeElement(MessageElement): seconds: bool = True timezone: bool = True - def to_str(self, msg: Optional['MessageSession'] = None): + def to_str(self, msg: Optional[MessageSession] = None): ftime_template = [] if msg: if self.date: @@ -149,7 +149,7 @@ def assign(cls, timestamp: float, :param seconds: 是否显示秒。(默认为True) :param timezone: 是否显示时区。(默认为True) """ - return cls(timestamp=timestamp, date=date, iso=iso, time=time, seconds=seconds, timezone=timezone) + return deepcopy(cls(timestamp=timestamp, date=date, iso=iso, time=time, seconds=seconds, timezone=timezone)) @define @@ -168,7 +168,7 @@ def assign(cls, :param key: 多语言的键名。 :param kwargs: 多语言中的变量。 """ - return cls(key=key, kwargs=kwargs) + return deepcopy(cls(key=key, kwargs=kwargs)) @define @@ -207,7 +207,7 @@ def assign(cls, error_message += '\n' + \ locale.t('error.prompt.address', url=str(report_url)) - return cls(error_message) + return deepcopy(cls(error_message)) def __str__(self): return self.error_message @@ -238,7 +238,7 @@ def assign(cls, path: Union[str, PILImage.Image], path = save elif re.match('^https?://.*', path): need_get = True - return cls(path, need_get, headers) + return deepcopy(cls(path, need_get, headers)) async def get(self): """ @@ -298,7 +298,7 @@ def assign(cls, path: str): """ :param path: 语音路径。 """ - return cls(path) + return deepcopy(cls(path)) @define @@ -321,7 +321,7 @@ def assign(cls, name: str, value: str, inline: bool = False): :param value: 字段值。 :param inline: 是否内联。(默认为False) """ - return cls(name=name, value=value, inline=inline) + return deepcopy(cls(name=name, value=value, inline=inline)) @define @@ -359,7 +359,7 @@ def assign(cls, title: Optional[str] = None, author: Optional[str] = None, footer: Optional[str] = None, fields: Optional[List[EmbedFieldElement]] = None): - return cls( + return deepcopy(cls( title=title, description=description, url=url, @@ -369,9 +369,9 @@ def assign(cls, title: Optional[str] = None, thumbnail=thumbnail, author=author, footer=footer, - fields=fields) + fields=fields)) - def to_message_chain(self, msg: Optional['MessageSession'] = None): + def to_message_chain(self, msg: Optional[MessageSession] = None): """ 将Embed转换为消息链。 """ diff --git a/core/component.py b/core/component.py index f80551cc0e..7098c5f03d 100644 --- a/core/component.py +++ b/core/component.py @@ -184,22 +184,22 @@ def module( :param exclude_from: 此命令排除的平台列表。 :param support_languages: 此命令支持的语言列表。 """ - module = Module(alias=alias, - bind_prefix=bind_prefix, - desc=desc, - recommend_modules=recommend_modules, - developers=developers, - base=base, - doc=doc, - hidden=hidden, - load=load, - rss=rss, - required_admin=required_admin, - required_superuser=required_superuser, - required_base_superuser=required_base_superuser, - available_for=available_for, - exclude_from=exclude_from, - support_languages=support_languages) + module = Module.assign(alias=alias, + bind_prefix=bind_prefix, + desc=desc, + recommend_modules=recommend_modules, + developers=developers, + base=base, + doc=doc, + hidden=hidden, + load=load, + rss=rss, + required_admin=required_admin, + required_superuser=required_superuser, + required_base_superuser=required_base_superuser, + available_for=available_for, + exclude_from=exclude_from, + support_languages=support_languages) frame = inspect.currentframe() ModulesManager.add_module(module, frame.f_back.f_globals["__name__"]) return Bind.Module(bind_prefix) diff --git a/core/config/__init__.py b/core/config/__init__.py index b1d15a5b4f..4da9637b97 100644 --- a/core/config/__init__.py +++ b/core/config/__init__.py @@ -9,6 +9,7 @@ from tomlkit.exceptions import KeyAlreadyPresent from tomlkit.items import Table +from . import update # noqa from core.constants.default import default_locale from core.constants.exceptions import ConfigValueError, ConfigOperationError from core.constants.path import config_path @@ -110,11 +111,9 @@ def get(cls, :param cfg_type: 配置项类型。 :param secret: 是否为密钥配置项。(默认为False) :param table_name: 配置项表名。 - :param _global: 是否搜索所有表的配置项,仅内部使用。(默认为False) - :param _generate: 是否标记为生成配置文件,仅内部使用。(默认为False) :return: 配置文件中对应配置项的值。 - ''' + ''' cls.watch() q = q.lower() value = None @@ -218,7 +217,6 @@ def write(cls, q: str, value: Union[Any, None], cfg_type: Union[type, tuple, Non :param cfg_type: 配置项类型。 :param secret: 是否为密钥配置项。(默认为False) :param table_name: 配置项表名。 - :param _generate: 是否标记为生成配置文件,仅内部使用。(默认为False) ''' cls.watch() q = q.lower() @@ -351,7 +349,7 @@ def delete(cls, q: str, table_name: Optional[str] = None) -> bool: :param q: 配置项键名。 :param table_name: 配置项表名。 - ''' + ''' cls.watch() q = q.lower() found = False @@ -406,8 +404,6 @@ def Config(q: str, :param secret: 是否为密钥配置项。(默认为False) :param table_name: 配置项表名。 :param get_url: 是否为URL配置项。(默认为False) - :param _global: 是否搜索所有表的配置项,仅内部使用。(默认为False) - :param _generate: 是否标记为生成配置文件,仅内部使用。(默认为False) :return: 配置文件中对应配置项的值。 ''' diff --git a/core/constants/info.py b/core/constants/info.py index 183b8aa236..0cdff65f0b 100644 --- a/core/constants/info.py +++ b/core/constants/info.py @@ -9,6 +9,20 @@ def add(cls, secret): class Info: + ''' + 机器人信息。 + + :param version: 机器人版本。 + :param subprocess: 是否为子进程。 + :param binary_mode: 是否为二进制模式。 + :param command_parsed: 已处理命令数量。 + :param client_name: 客户端名称。 + :param dirty_word_check: 是否启用文本过滤。 + :param web_render_status: WebRender 状态。 + :param web_render_local_status: 本地 WebRender 状态。 + :param use_url_manager: 是否启用 URLManager。 + :param use_url_md_format: 是否启用 URL MarkDown 格式。 + ''' version = None subprocess = False binary_mode = False diff --git a/core/joke.py b/core/joke.py index 9acea3dc0a..9e91f5bab0 100644 --- a/core/joke.py +++ b/core/joke.py @@ -5,13 +5,19 @@ from core.utils.http import url_pattern -def joke(text: str) -> str: +def check_apr_fools() -> bool: current_date = datetime.now().date() enable_joke = Config('enable_joke', True, cfg_type=bool) - if enable_joke and (current_date.month == 4 and current_date.day == 1): + return (enable_joke and (current_date.month == 4 and current_date.day == 1)) + + +def joke(text: str) -> str: + if check_apr_fools(): + # 这里可能会增加使用不同玩笑方法的区分,但现在不太想做XD return shuffle_joke(text) - return text + else: + return text def shuffle_joke(text: str) -> str: @@ -37,6 +43,3 @@ def shuffle_joke(text: str) -> str: text_list[j], text_list[j + 1] = text_list[j + 1], text_list[j] parts[i] = ''.join(text_list) return ''.join(parts) - - -__all__ = ['joke', 'shuffle_joke'] diff --git a/core/parser/message.py b/core/parser/message.py index e88e3ded36..cb550f7852 100644 --- a/core/parser/message.py +++ b/core/parser/message.py @@ -489,10 +489,10 @@ async def execute_submodule(msg: Bot.MessageSession, command_first_word, command if Config('bug_report_url', bug_report_url_default, cfg_type=str): errmsg += '\n' + msg.locale.t('error.prompt.address', - url=str(Url(Config('bug_report_url', - bug_report_url_default, - cfg_type=str), - use_mm=False))) + url=Url(Config('bug_report_url', + bug_report_url_default, + cfg_type=str), + use_mm=False)) await msg.send_message(errmsg) if not timeout and report_targets: diff --git a/core/types/message/__init__.py b/core/types/message/__init__.py index c6268c1df6..f9542bdd5e 100644 --- a/core/types/message/__init__.py +++ b/core/types/message/__init__.py @@ -7,8 +7,8 @@ @define class MsgInfo: - target_id: Union[int, str] - sender_id: Union[int, str] + target_id: str + sender_id: str sender_prefix: str target_from: str sender_from: str diff --git a/core/types/module/__init__.py b/core/types/module/__init__.py index 69538759dd..bfe21f07f0 100644 --- a/core/types/module/__init__.py +++ b/core/types/module/__init__.py @@ -7,57 +7,47 @@ from .component_matches import * +from .utils import convert2lst -def convert2lst(elements: Union[str, list, tuple]) -> list: - if isinstance(elements, str): - return [elements] - elif isinstance(elements, tuple): - return list(elements) - return elements +from attrs import define, field, Converter +from copy import deepcopy + +def alias_converter(value, _self) -> dict: + if isinstance(value, str): + return {value: _self.bind_prefix} + elif isinstance(value, (tuple, list)): + return {x: _self.bind_prefix for x in value} + return value + + +@define class Module: - def __init__(self, - bind_prefix: str, - alias: Union[str, list, tuple, dict, None] = None, - desc: str = None, - recommend_modules: Union[str, list, tuple, None] = None, - developers: Union[str, list, tuple, None] = None, - required_admin: bool = False, - base: bool = False, - doc: bool = False, - hidden: bool = False, - load: bool = True, - rss: bool = False, - required_superuser: bool = False, - required_base_superuser: bool = False, - available_for: Union[str, list, tuple, None] = '*', - exclude_from: Union[str, list, tuple, None] = '', - support_languages: Union[str, list, tuple, None] = None): - self.bind_prefix: str = bind_prefix - if isinstance(alias, str): - alias = {alias: bind_prefix} - elif isinstance(alias, (tuple, list)): - alias = {x: bind_prefix for x in alias} - self.alias: Dict[str, str] = alias - self.desc: str = desc - self.recommend_modules: List[str] = convert2lst(recommend_modules) - self.developers: List[str] = convert2lst(developers) - self.required_admin: bool = required_admin - self.base: bool = base - self.doc: bool = doc - self.hidden: bool = hidden - self.load: bool = load - self.rss: bool = rss - self.required_superuser: bool = required_superuser - self.required_base_superuser: bool = required_base_superuser - self.available_for: List[str] = convert2lst(available_for) - self.exclude_from: List[str] = convert2lst(exclude_from) - self.support_languages: List[str] = convert2lst(support_languages) - self.command_list = CommandMatches() - self.regex_list = RegexMatches() - self.schedule_list = ScheduleMatches() - self.hooks_list = HookMatches() + bind_prefix: str + alias: dict = field(converter=Converter(alias_converter, takes_self=True)) + recommend_modules: list = field(converter=convert2lst) + developers: list = field(converter=convert2lst) + available_for: list = field(default=['*'], converter=convert2lst) + exclude_from: list = field(default=[], converter=convert2lst) + support_languages: list = field(default=None, converter=convert2lst) + desc: Union[str] = '' + required_admin: bool = False + base: bool = False + doc: bool = False + hidden: bool = False + load: bool = True + rss: bool = False + required_superuser: bool = False + required_base_superuser: bool = False + command_list: CommandMatches = CommandMatches.init() + regex_list: RegexMatches = RegexMatches.init() + schedule_list: ScheduleMatches = ScheduleMatches.init() + hooks_list: HookMatches = HookMatches.init() + + @classmethod + def assign(cls, **kwargs): + return deepcopy(cls(**kwargs)) __all__ = ["Module", "AndTrigger", "OrTrigger", "DateTrigger", diff --git a/core/types/module/component_matches.py b/core/types/module/component_matches.py index d272904449..11c1bfca41 100644 --- a/core/types/module/component_matches.py +++ b/core/types/module/component_matches.py @@ -2,15 +2,27 @@ from .component_meta import * +from attrs import define, field +from copy import deepcopy -class CommandMatches: - def __init__(self): - self.set: List[CommandMeta] = [] + +@define +class BaseMatches: + set: List[ModuleMeta] = [] def add(self, meta): self.set.append(meta) return self.set + @classmethod + def init(cls): + return deepcopy(cls()) + + +@define +class CommandMatches(BaseMatches): + set: List[CommandMeta] = [] + def get(self, target_from: str, show_required_superuser: bool = False, show_required_base_superuser: bool = False) -> List[CommandMeta]: metas = [] @@ -28,13 +40,9 @@ def get(self, target_from: str, show_required_superuser: bool = False, return metas -class RegexMatches: - def __init__(self): - self.set: List[RegexMeta] = [] - - def add(self, meta): - self.set.append(meta) - return self.set +@define +class RegexMatches(BaseMatches): + set: List[RegexMeta] = [] def get(self, target_from: str, show_required_superuser: bool = False, show_required_base_superuser: bool = False) -> List[RegexMeta]: @@ -53,22 +61,14 @@ def get(self, target_from: str, show_required_superuser: bool = False, return metas -class ScheduleMatches: - def __init__(self): - self.set: List[ScheduleMeta] = [] - - def add(self, meta): - self.set.append(meta) - return self.set +@define +class ScheduleMatches(BaseMatches): + set: List[ScheduleMeta] = [] -class HookMatches: - def __init__(self): - self.set: List[HookMeta] = [] - - def add(self, meta): - self.set.append(meta) - return self.set +@define +class HookMatches(BaseMatches): + set: List[HookMeta] = [] __all__ = ["CommandMatches", "RegexMatches", "ScheduleMatches", "HookMatches"] diff --git a/core/types/module/component_meta.py b/core/types/module/component_meta.py index a35d8cc469..36f5678a8b 100644 --- a/core/types/module/component_meta.py +++ b/core/types/module/component_meta.py @@ -8,100 +8,55 @@ from core.parser.args import Template - -class Meta: - def __init__(self, **kwargs): - raise NotImplementedError - - -class CommandMeta: - def __init__(self, - function: Callable = None, - help_doc: List[Template] = None, - options_desc: dict = None, - required_admin: bool = False, - required_superuser: bool = False, - required_base_superuser: bool = False, - available_for: Union[str, list, tuple] = '*', - exclude_from: Union[str, list, tuple] = '', - load: bool = True, - priority: int = 1 - ): - self.function = function - if isinstance(help_doc, str): - help_doc = [help_doc] - elif isinstance(help_doc, tuple): - help_doc = list(help_doc) - self.help_doc: List[Template] = help_doc - self.options_desc = options_desc - self.required_admin = required_admin - self.required_superuser = required_superuser - self.required_base_superuser = required_base_superuser - if isinstance(available_for, str): - available_for = [available_for] - elif isinstance(available_for, tuple): - available_for = list(available_for) - if isinstance(exclude_from, str): - exclude_from = [exclude_from] - elif isinstance(exclude_from, tuple): - exclude_from = list(exclude_from) - self.available_for = available_for - self.exclude_from = exclude_from - self.load = load - self.priority = priority - - -class RegexMeta: - def __init__(self, - function: Callable = None, - pattern: Union[str, re.Pattern] = None, - mode: str = None, - desc: str = None, - required_admin: bool = False, - required_superuser: bool = False, - required_base_superuser: bool = False, - available_for: Union[str, list, tuple] = '*', - exclude_from: Union[str, list, tuple] = '', - flags: re.RegexFlag = 0, - load: bool = True, - show_typing: bool = True, - logging: bool = True - ): - self.function = function - self.pattern = pattern - self.mode = mode - self.flags = flags - self.desc = desc - self.required_admin = required_admin - self.required_superuser = required_superuser - self.required_base_superuser = required_base_superuser - if isinstance(available_for, str): - available_for = [available_for] - elif isinstance(available_for, tuple): - available_for = list(available_for) - if isinstance(exclude_from, str): - exclude_from = [exclude_from] - elif isinstance(exclude_from, tuple): - exclude_from = list(exclude_from) - self.available_for = available_for - self.exclude_from = exclude_from - self.load = load - self.show_typing = show_typing - self.logging = logging - - -class ScheduleMeta: - def __init__(self, trigger: Union[AndTrigger, OrTrigger, DateTrigger, CronTrigger, IntervalTrigger], - function: Callable = None - ): - self.trigger = trigger - self.function = function - - -class HookMeta: - def __init__(self, function: Callable, name: str = None): - self.function = function - self.name = name - - -__all__ = ["Meta", "CommandMeta", "RegexMeta", "ScheduleMeta", "HookMeta"] +from attrs import define, field +from .utils import convert2lst + + +class ModuleMeta: + pass + + +@define +class CommandMeta(ModuleMeta): + function: Callable = None + help_doc: List[Template] = field(default=[], converter=convert2lst) + options_desc: dict = None + required_admin: bool = False + required_superuser: bool = False + required_base_superuser: bool = False + available_for: list = field(default=['*'], converter=convert2lst) + exclude_from: list = field(default=[], converter=convert2lst) + load: bool = True + priority: int = 1 + + +@define +class RegexMeta(ModuleMeta): + function: Callable = None + pattern: Union[str, re.Pattern] = None + mode: str = None + desc: str = None + required_admin: bool = False + required_superuser: bool = False + required_base_superuser: bool = False + available_for: list = field(default=['*'], converter=convert2lst) + exclude_from: list = field(default=[], converter=convert2lst) + flags: re.RegexFlag = 0 + load: bool = True + show_typing: bool = True + logging: bool = True + + +@define +class ScheduleMeta(ModuleMeta): + trigger: Union[AndTrigger, OrTrigger, DateTrigger, CronTrigger, IntervalTrigger] + function: Callable = None + + +@define +class HookMeta(ModuleMeta): + function: Callable = None + name: str = None + + +__all__ = ["ModuleMeta", "CommandMeta", "RegexMeta", "ScheduleMeta", "HookMeta"] diff --git a/core/types/module/utils.py b/core/types/module/utils.py new file mode 100644 index 0000000000..e3696498b0 --- /dev/null +++ b/core/types/module/utils.py @@ -0,0 +1,9 @@ +from typing import Union + + +def convert2lst(elements: Union[str, list, tuple]) -> list: + if isinstance(elements, str): + return [elements] + elif isinstance(elements, tuple): + return list(elements) + return elements diff --git a/core/utils/cooldown.py b/core/utils/cooldown.py index 1e474b649a..e66cbe6d26 100644 --- a/core/utils/cooldown.py +++ b/core/utils/cooldown.py @@ -1,50 +1,65 @@ -import datetime -from typing import Dict +from collections import defaultdict +from datetime import datetime from core.builtins import MessageSession -_cd_lst: Dict[str, Dict[MessageSession, float]] = {} +_cd_lst = defaultdict(lambda: defaultdict(dict)) class CoolDown: + ''' + 冷却事件构造器。 + + :param key: 冷却事件名称。 + :param msg: 消息会话。 + :param all: 是否应用至全对话。(默认为False) + ''' def __init__(self, key: str, msg: MessageSession, all: bool = False): self.key = key self.msg = msg - self.sender_id = self.msg - if isinstance(self.sender_id, MessageSession): - if all: - self.sender_id = self.msg.target.target_id - else: - self.sender_id = self.sender_id.target.sender_id + self.all = all + self.target_id = self.msg.target.target_id + self.sender_id = self.msg.target.sender_id + + def _get_cd_dict(self): + target_dict = _cd_lst[self.target_id] + if self.all: + return target_dict.setdefault(self.key, {'_timestamp': 0.0}) + else: + sender_dict = target_dict.setdefault(self.sender_id, {}) + return sender_dict.setdefault(self.key, {'_timestamp': 0.0}) def add(self): ''' 添加冷却事件。 ''' - if self.key not in _cd_lst: - _cd_lst[self.key] = {} - _cd_lst[self.key][self.sender_id] = datetime.datetime.now().timestamp() + cooldown_dict = self._get_cd_dict() + cooldown_dict['_timestamp'] = datetime.now().timestamp() - def check(self, delay: int) -> float: + def check(self, delay: float) -> float: ''' 检查冷却事件剩余冷却时间。 + + :param delay: 设定的冷却时间。 + :return: 剩余的冷却时间。 ''' if self.key not in _cd_lst: return 0 - if self.sender_id in _cd_lst[self.key]: - if (d := (datetime.datetime.now().timestamp() - _cd_lst[self.key][self.sender_id])) > delay: - return 0 - else: - return d + target_dict = _cd_lst[self.target_id] + if self.all: + ts = target_dict.get(self.key, {}).get('_timestamp', 0.0) else: + sender_dict = target_dict.get(self.sender_id, {}) + ts = sender_dict.get(self.key, {}).get('_timestamp', 0.0) + + if (d := (datetime.now().timestamp() - ts)) > delay: return 0 + else: + return d def reset(self): ''' 重置冷却事件。 ''' - if self.key in _cd_lst: - if self.sender_id in _cd_lst[self.key]: - _cd_lst[self.key].pop(self.sender_id) self.add() diff --git a/core/utils/game.py b/core/utils/game.py index 1cfc140cd7..39c32c9983 100644 --- a/core/utils/game.py +++ b/core/utils/game.py @@ -1,15 +1,23 @@ from collections import defaultdict from datetime import datetime -from typing import Any, Optional +from typing import Any, Dict, Optional from core.builtins import MessageSession from core.logger import Logger -playstate_lst = defaultdict(lambda: defaultdict(dict)) +_ps_lst = defaultdict(lambda: defaultdict(dict)) GAME_EXPIRED = 3600 class PlayState: + ''' + 游戏事件构造器。 + + :param game: 游戏事件名称。 + :param msg: 消息会话。 + :param all: 是否应用至全对话。(默认为False) + ''' + def __init__(self, game: str, msg: MessageSession, all: bool = False): self.game = game self.msg = msg @@ -17,8 +25,8 @@ def __init__(self, game: str, msg: MessageSession, all: bool = False): self.target_id = self.msg.target.target_id self.sender_id = self.msg.target.sender_id - def _get_game_dict(self): - target_dict = playstate_lst[self.target_id] + def _get_ps_dict(self): + target_dict = _ps_lst[self.target_id] if self.all: return target_dict.setdefault(self.game, {'_status': False, '_timestamp': 0.0}) else: @@ -29,21 +37,21 @@ def enable(self) -> None: ''' 开启游戏事件。 ''' - game_dict = self._get_game_dict() - game_dict['_status'] = True - game_dict['_timestamp'] = datetime.now().timestamp() + playstate_dict = self._get_ps_dict() + playstate_dict['_status'] = True + playstate_dict['_timestamp'] = datetime.now().timestamp() if self.all: Logger.info(f'[{self.target_id}]: Enabled {self.game} by {self.sender_id}.') else: Logger.info(f'[{self.sender_id}]: Enabled {self.game} at {self.target_id}.') - def disable(self, auto=False) -> None: + def disable(self, _auto=False) -> None: ''' 关闭游戏事件。 ''' - if self.target_id not in playstate_lst: + if self.target_id not in _ps_lst: return - target_dict = playstate_lst[self.target_id] + target_dict = _ps_lst[self.target_id] if self.all: game_dict = target_dict.get(self.game) if game_dict: @@ -54,7 +62,7 @@ def disable(self, auto=False) -> None: game_dict = sender_dict.get(self.game) if game_dict: game_dict['_status'] = False - if auto: + if _auto: if self.all: Logger.info(f'[{self.target_id}]: Disabled {self.game} automatically.') else: @@ -65,12 +73,14 @@ def disable(self, auto=False) -> None: else: Logger.info(f'[{self.sender_id}]: Disabled {self.game} at {self.target_id}.') - def update(self, **kwargs) -> None: + def update(self, **kwargs: Dict[str, Any]) -> None: ''' 更新游戏事件中需要的值。 + + :param kwargs: 键值对。 ''' - game_dict = self._get_game_dict() - game_dict.update(kwargs) + playstate_dict = self._get_ps_dict() + playstate_dict.update(kwargs) if self.all: Logger.debug(f'[{self.game}]: Updated {str(kwargs)} at {self.target_id}.') else: @@ -80,9 +90,9 @@ def check(self) -> bool: ''' 检查游戏事件状态,若超过时间则自动关闭。 ''' - if self.target_id not in playstate_lst: + if self.target_id not in _ps_lst: return False - target_dict = playstate_lst[self.target_id] + target_dict = _ps_lst[self.target_id] if self.all: status = target_dict.get(self.game, {}).get('_status', False) ts = target_dict.get(self.game, {}).get('_timestamp', 0.0) @@ -91,16 +101,19 @@ def check(self) -> bool: status = sender_dict.get(self.game, {}).get('_status', False) ts = sender_dict.get(self.game, {}).get('_timestamp', 0.0) if datetime.now().timestamp() - ts >= GAME_EXPIRED: - self.disable(auto=True) + self.disable(_auto=True) return status def get(self, key: str) -> Optional[Any]: ''' 获取游戏事件中需要的值。 + + :param key: 键名。 + :return: 值。 ''' - if self.target_id not in playstate_lst: + if self.target_id not in _ps_lst: return None - target_dict = playstate_lst[self.target_id] + target_dict = _ps_lst[self.target_id] if self.all: return target_dict.get(self.game, {}).get(key, None) else: diff --git a/core/utils/http.py b/core/utils/http.py index f8bf345933..63584d21c0 100644 --- a/core/utils/http.py +++ b/core/utils/http.py @@ -44,11 +44,11 @@ def private_ip_check(url: str): async def get_url(url: str, - status_code: int = False, + status_code: int = 200, headers: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None, fmt: Optional[str] = None, - timeout: int = 20, + timeout: Optional[float] = 20, attempt: int = 3, request_private_ip: bool = False, logging_err_resp: bool = True, @@ -113,10 +113,10 @@ async def get_(): async def post_url(url: str, data: Any = None, - status_code: int = False, + status_code: int = 200, headers: Optional[Dict[str, Any]] = None, fmt: Optional[str] = None, - timeout: int = 20, + timeout: Optional[float] = 20, attempt: int = 3, request_private_ip: bool = False, logging_err_resp: bool = True, @@ -182,11 +182,11 @@ async def _post(): async def download(url: str, filename: Optional[str] = None, path: Optional[str] = None, - status_code: int = False, + status_code: int = 200, method: str = "GET", post_data: Any = None, headers: Optional[Dict[str, Any]] = None, - timeout: int = 20, + timeout: Optional[float] = 20, attempt: int = 3, request_private_ip: bool = False, logging_err_resp: bool = True) -> Union[str, bool]: diff --git a/core/utils/i18n.py b/core/utils/i18n.py index f4a1cfd9df..5d117183d9 100644 --- a/core/utils/i18n.py +++ b/core/utils/i18n.py @@ -115,8 +115,11 @@ def get_available_locales() -> List[str]: class Locale: + """ + 创建一个本地化对象。 + """ + def __init__(self, locale: str, fallback_lng: Optional[List[str]] = None): - """创建一个本地化对象。""" if not fallback_lng: fallback_lng = supported_locales.copy() fallback_lng.remove(locale) diff --git a/core/utils/image_table.py b/core/utils/image_table.py index 34e86b8dcd..8c6ad57222 100644 --- a/core/utils/image_table.py +++ b/core/utils/image_table.py @@ -2,7 +2,7 @@ import re from html import escape from io import BytesIO -from typing import List, Union +from typing import Any, List, Union import aiohttp import orjson as json @@ -18,12 +18,28 @@ class ImageTable: - def __init__(self, data, headers): + ''' + 图片表格。 + :param data: 表格内容,表格行数需与表格标头的数量相符。 + :param headers: 表格表头。 + ''' + + def __init__(self, data: List[List[Any]], headers: List[str]): self.data = data self.headers = headers -async def image_table_render(table: Union[ImageTable, List[ImageTable]], save_source=True, use_local=True) -> Union[List[PILImage], bool]: +async def image_table_render(table: Union[ImageTable, List[ImageTable]], + save_source: bool = True, + use_local: bool = True) -> Union[List[PILImage.Image], bool]: + ''' + 使用WebRender渲染图片表格。 + + :param table: 要渲染的表格。 + :param save_source: 是否保存源文件。 + :param use_local: 是否使用本地WebRender渲染。 + :return: 图片的PIL对象。 + ''' if not Info.web_render_status: return False elif not Info.web_render_local_status: @@ -70,7 +86,7 @@ async def image_table_render(table: Union[ImageTable, List[ImageTable]], save_so try: pic = await download( - webrender(use_local=use_local), + webrender(use_local=use_local, method=''), method='POST', post_data=json.dumps(html), request_private_ip=True, @@ -81,7 +97,7 @@ async def image_table_render(table: Union[ImageTable, List[ImageTable]], save_so except aiohttp.ClientConnectorError: if use_local: pic = await download( - webrender(use_local=False), + webrender(use_local=False, method=''), method='POST', post_data=json.dumps(html), request_private_ip=True, diff --git a/core/utils/random.py b/core/utils/random.py index 2d8797afe3..d23f76b5d2 100644 --- a/core/utils/random.py +++ b/core/utils/random.py @@ -9,6 +9,9 @@ class Random: + """ + 机器人内置的随机数生成器。在配置文件中将`use_secrets_random`设为`true`时使用`secret`库,否则默认使用`random`库。 + """ use_secrets = Config('use_secrets_random', False) @classmethod diff --git a/core/utils/text.py b/core/utils/text.py index c3ab0f9c72..13615a89ac 100644 --- a/core/utils/text.py +++ b/core/utils/text.py @@ -4,6 +4,9 @@ def isfloat(num_str: Any) -> bool: + ''' + 检查字符串是否符合float。 + ''' try: float(num_str) return True @@ -12,6 +15,9 @@ def isfloat(num_str: Any) -> bool: def isint(num_str: Any) -> bool: + ''' + 检查字符串是否符合int。 + ''' try: int(num_str) return True @@ -40,8 +46,4 @@ def parse_time_string(time_str: str) -> timedelta: return timedelta() -def random_string(length: int) -> str: - return ''.join(random.choices("0123456789ABCDEF", k=length)) - - -__all__ = ["parse_time_string", "random_string", "isint", "isfloat"] +__all__ = ["isint", "isfloat"] diff --git a/core/utils/web_render.py b/core/utils/web_render.py index 8cd6975982..579a2987d8 100644 --- a/core/utils/web_render.py +++ b/core/utils/web_render.py @@ -1,5 +1,5 @@ import traceback -from typing import Tuple, Union +from typing import Tuple, Optional, Union from core.config import Config from core.constants.info import Info @@ -10,7 +10,7 @@ web_render_local = Config('web_render_local', get_url=True) -def webrender(method: str = '', url: str = '', use_local: bool = True) -> Union[str, None]: +def webrender(method: str = 'source', url: Optional[str] = None, use_local: bool = True, _ignore_status=False) -> str: '''根据请求方法生成 WebRender URL。 :param method: API 方法。 @@ -18,18 +18,19 @@ def webrender(method: str = '', url: str = '', use_local: bool = True) -> Union[ :param use_local: 是否使用本地 WebRender。 :returns: 生成的 WebRender URL。 ''' - if use_local and not Info.web_render_local_status: + if use_local and (not Info.web_render_local_status or _ignore_status): use_local = False if method == 'source': - if Info.web_render_status: + url = '' if not url else url + if Info.web_render_status or _ignore_status: return f'{(web_render_local if use_local else web_render)}source?url={url}' else: return url else: - if Info.web_render_status: + if Info.web_render_status or _ignore_status: return (web_render_local if use_local else web_render) + method else: - return None + return '' async def check_web_render() -> Tuple[bool, bool]: @@ -47,7 +48,7 @@ async def check_web_render() -> Tuple[bool, bool]: if web_render_status: try: Logger.info('[WebRender] Checking WebRender status...') - await get_url(webrender('source', ping_url), 200, request_private_ip=True) + await get_url(webrender('source', ping_url, _ignore_status=True), 200, request_private_ip=True) Logger.info('[WebRender] WebRender is working as expected.') except Exception: Logger.error('[WebRender] WebRender is not working as expected.') diff --git a/modules/cytoid/profile.py b/modules/cytoid/profile.py index 15b9aac745..09862f722e 100644 --- a/modules/cytoid/profile.py +++ b/modules/cytoid/profile.py @@ -4,8 +4,14 @@ from core.utils.http import get_url -async def cytoid_profile(msg: Bot.MessageSession, uid): - profile_url = 'http://services.cytoid.io/profile/' + uid +async def cytoid_profile(msg: Bot.MessageSession, username): + if username: + query_id = username.lower() + else: + query_id = CytoidBindInfoManager(msg).get_bind_username() + if not query_id: + await msg.finish(msg.locale.t('cytoid.message.user_unbound', prefix=msg.prefixes[0])) + profile_url = f'http://services.cytoid.io/profile/{query_id}' try: profile = json.loads(await get_url(profile_url, 200)) except ValueError as e: diff --git a/modules/cytoid/rating.py b/modules/cytoid/rating.py index 3da8da601c..52a6dcb914 100644 --- a/modules/cytoid/rating.py +++ b/modules/cytoid/rating.py @@ -13,11 +13,11 @@ from core.builtins import Bot from core.config import Config -from core.constants.path import assets_path, noto_sans_demilight_path, nunito_regular_path, nunito_light_path +from core.constants.path import assets_path, cache_path, noto_sans_demilight_path, nunito_regular_path, nunito_light_path from core.logger import Logger from core.utils.cache import random_cache_path from core.utils.html2text import html2text -from core.utils.http import get_url +from core.utils.http import get_url, download from core.utils.image import get_fontsize from core.utils.text import parse_time_string @@ -28,7 +28,7 @@ async def get_rating(msg: Bot.MessageSession, uid, query_type): query_type = 'bestRecords' elif query_type == 'r30': query_type = 'recentRecords' - profile_url = 'http://services.cytoid.io/profile/' + uid + profile_url = f'http://services.cytoid.io/profile/{uid}' profile_json = json.loads(await get_url(profile_url, 200)) if 'statusCode' in profile_json: if profile_json['statusCode'] == 404: @@ -219,35 +219,29 @@ async def mkresources(msg: Bot.MessageSession, x, rank): async def download_cover_thumb(uid): try: - d = os.path.join(assets_path, 'cytoid-cover', uid) + filename = 'thumbnail.png' + d = os.path.join(cache_path, 'cytoid-cover', uid) os.makedirs(d, exist_ok=True) - path = os.path.join(d, 'thumbnail.png') + path = os.path.join(d, filename) if not os.path.exists(path): - level_url = 'http://services.cytoid.io/levels/' + uid + level_url = f'http://services.cytoid.io/levels/{uid}' get_level = json.loads(await get_url(level_url)) - cover_thumbnail = get_level['cover']['original'] + "?h=240&w=384" - async with aiohttp.ClientSession() as session, session.get(cover_thumbnail) as resp, async_open(path, 'wb+') as jpg: - await jpg.write(await resp.read()) - return path - else: - return path - except BaseException: + cover_thumbnail = f"{get_level['cover']['original']}?h=240&w=384" + path = await download(cover_thumbnail, filename=filename, path=d, logging_err_resp=False) + return path + except Exception: Logger.error(traceback.format_exc()) return False async def download_avatar_thumb(link, id): - Logger.debug(f'Downloading avatar for {str(id)}') + Logger.debug(f'Downloading avatar for {id}') try: - d = os.path.join(assets_path, 'cytoid-avatar') + d = os.path.join(cache_path, 'cytoid-avatar') os.makedirs(d, exist_ok=True) - path = os.path.join(d, f'{id}.png') - if os.path.exists(path): - os.remove(path) - async with aiohttp.ClientSession() as session, session.get(link, timeout=aiohttp.ClientTimeout(total=20)) as resp, async_open(path, 'wb+') as jpg: - await jpg.write(await resp.read()) - return path - except BaseException: + path = await download(link, filename=f'{id}.png', path=d, logging_err_resp=False) + return path + except Exception: Logger.error(traceback.format_exc()) return False diff --git a/modules/cytoid/utils.py b/modules/cytoid/utils.py index 7ebdf2e48f..2631f8935b 100644 --- a/modules/cytoid/utils.py +++ b/modules/cytoid/utils.py @@ -5,7 +5,7 @@ async def get_profile_name(userid): try: - profile_url = 'http://services.cytoid.io/profile/' + userid + profile_url = f'http://services.cytoid.io/profile/{userid}' profile = json.loads(await get_url(profile_url, 200)) except BaseException: return False diff --git a/modules/maimai/libraries/chunithm_apidata.py b/modules/maimai/libraries/chunithm_apidata.py index ce3170dbe2..823e45e5ba 100644 --- a/modules/maimai/libraries/chunithm_apidata.py +++ b/modules/maimai/libraries/chunithm_apidata.py @@ -22,7 +22,9 @@ async def get_info(music: Music, *details) -> MessageChain: async def get_record(msg: Bot.MessageSession, payload: dict, use_cache: bool = True) -> Optional[str]: - cache_dir = os.path.join(cache_path, f'{msg.target.sender_id.replace('|', '_')}_maimai_record.json') + dir = os.path.join(cache_path, 'maimai-record') + os.makedirs(dir, exist_ok=True) + cache_dir = os.path.join(dir, f'{msg.target.sender_id.replace('|', '_')}_chunithm_record.json') url = "https://www.diving-fish.com/api/chunithmprober/query/player" if 'username' in payload: use_cache = False diff --git a/modules/maimai/libraries/maimaidx_apidata.py b/modules/maimai/libraries/maimaidx_apidata.py index f43441b31a..74eb9038b4 100644 --- a/modules/maimai/libraries/maimaidx_apidata.py +++ b/modules/maimai/libraries/maimaidx_apidata.py @@ -19,21 +19,21 @@ async def update_cover() -> bool: - id_list = ['00000', '01000'] + id_list = ['0', '1000'] for song in (await total_list.get()): id_list.append(song['id']) os.makedirs(mai_cover_path, exist_ok=True) for id in id_list: - cover_path = os.path.join(mai_cover_path, f'{get_cover_len5_id(id)}.png') + cover_path = os.path.join(mai_cover_path, f'{id}.png') if not os.path.exists(cover_path): try: url = f"https://www.diving-fish.com/covers/{get_cover_len5_id(id)}.png" - await download(url, status_code=200, path=mai_cover_path, filename=f'{get_cover_len5_id(id)}.png', attempt=1, logging_err_resp=False) - Logger.debug(f'Successfully download {get_cover_len5_id(id)}.png') + await download(url, status_code=200, path=mai_cover_path, filename=f'{id}.png', attempt=1, logging_err_resp=False) + Logger.debug(f'Successfully download {id}.png') except Exception as e: if str(e).startswith('404'): if Config('debug', False): - Logger.error(f'Failed to download {get_cover_len5_id(id)}.png') + Logger.error(f'Failed to download {id}.png') continue else: Logger.error(traceback.format_exc()) @@ -56,11 +56,11 @@ async def update_alias() -> bool: async def get_info(music: Music, *details) -> MessageChain: info = [Plain(f"{music.id} - {music.title}{' (DX)' if music['type'] == 'DX' else ''}")] - cover_path = os.path.join(mai_cover_path, f'{get_cover_len5_id(music.id)}.png') + cover_path = os.path.join(mai_cover_path, f'{music.id}.png') if os.path.exists(cover_path): info.append(Image(cover_path)) else: - cover_path = os.path.join(mai_cover_path, '00000.png') + cover_path = os.path.join(mai_cover_path, '0.png') if os.path.exists(cover_path): info.append(Image(cover_path)) if details: @@ -107,7 +107,9 @@ async def search_by_alias(input_: str) -> list: async def get_record(msg: Bot.MessageSession, payload: dict, use_cache: bool = True) -> Optional[str]: - cache_dir = os.path.join(cache_path, f"{msg.target.sender_id.replace('|', '_')}_maimaidx_record.json") + dir = os.path.join(cache_path, 'maimai-record') + os.makedirs(dir, exist_ok=True) + cache_dir = os.path.join(dir, f"{msg.target.sender_id.replace('|', '_')}_maimaidx_record.json") url = "https://www.diving-fish.com/api/maimaidxprober/query/player" try: data = await post_url(url, @@ -148,7 +150,9 @@ async def get_record(msg: Bot.MessageSession, payload: dict, use_cache: bool = T async def get_song_record(msg: Bot.MessageSession, payload: dict, sid: Union[str, list[str]], use_cache: bool = True) -> Optional[str]: if DEVELOPER_TOKEN: - cache_dir = os.path.join(cache_path, f"{msg.target.sender_id.replace('|', '_')}_maimaidx_song_record.json") + dir = os.path.join(cache_path, 'maimai-record') + os.makedirs(dir, exist_ok=True) + cache_dir = os.path.join(dir, f"{msg.target.sender_id.replace('|', '_')}_maimaidx_song_record.json") url = "https://www.diving-fish.com/api/maimaidxprober/dev/player/record" try: payload.update({'music_id': sid}) @@ -193,7 +197,9 @@ async def get_song_record(msg: Bot.MessageSession, payload: dict, sid: Union[str async def get_total_record(msg: Bot.MessageSession, payload: dict, utage: bool = False, use_cache: bool = True): - cache_dir = os.path.join(cache_path, f"{msg.target.sender_id.replace('|', '_')}_maimaidx_total_record.json") + dir = os.path.join(cache_path, 'maimai-record') + os.makedirs(dir, exist_ok=True) + cache_dir = os.path.join(dir, f"{msg.target.sender_id.replace('|', '_')}_maimaidx_total_record.json") url = "https://www.diving-fish.com/api/maimaidxprober/query/plate" payload['version'] = versions try: @@ -237,7 +243,9 @@ async def get_total_record(msg: Bot.MessageSession, payload: dict, utage: bool = async def get_plate(msg: Bot.MessageSession, payload: dict, version: str, use_cache: bool = True) -> Optional[str]: version = '舞' if version == '覇' else version # “覇者”属于舞代 - cache_dir = os.path.join(cache_path, f"{msg.target.sender_id.replace('|', '_')}_maimaidx_plate_{version}.json") + dir = os.path.join(cache_path, 'maimai-record') + os.makedirs(dir, exist_ok=True) + cache_dir = os.path.join(dir, f"{msg.target.sender_id.replace('|', '_')}_maimaidx_plate_{version}.json") url = "https://www.diving-fish.com/api/maimaidxprober/query/plate" try: data = await post_url(url, diff --git a/modules/maimai/libraries/maimaidx_best50.py b/modules/maimai/libraries/maimaidx_best50.py index 38f54d9e11..8aadc71583 100644 --- a/modules/maimai/libraries/maimaidx_best50.py +++ b/modules/maimai/libraries/maimaidx_best50.py @@ -7,7 +7,7 @@ from core.constants.path import noto_sans_demilight_path, noto_sans_symbol_path from .maimaidx_apidata import get_record from .maimaidx_mapping import mai_cover_path, rate_mapping, combo_mapping, sync_mapping, diff_list -from .maimaidx_music import get_cover_len5_id, TotalList +from .maimaidx_music import TotalList from .maimaidx_utils import compute_rating, calc_dxstar total_list = TotalList() @@ -143,9 +143,9 @@ def _drawBestList(self, img: Image.Image, sdBest: BestList, dxBest: BestList): i = num // 5 j = num % 5 chartInfo = sdBest[num] - pngPath = os.path.join(mai_cover_path, f'{get_cover_len5_id(chartInfo.idNum)}.png') + pngPath = os.path.join(mai_cover_path, f'{chartInfo.idNum}.png') if not os.path.exists(pngPath): - pngPath = os.path.join(mai_cover_path, '01000.png') + pngPath = os.path.join(mai_cover_path, '0.png') if os.path.exists(pngPath): temp = Image.open(pngPath).convert('RGB') @@ -195,9 +195,9 @@ def _drawBestList(self, img: Image.Image, sdBest: BestList, dxBest: BestList): i = num // 5 j = num % 5 chartInfo = dxBest[num] - pngPath = os.path.join(mai_cover_path, f'{get_cover_len5_id(chartInfo.idNum)}.png') + pngPath = os.path.join(mai_cover_path, f'{chartInfo.idNum}.png') if not os.path.exists(pngPath): - pngPath = os.path.join(mai_cover_path, '01000.png') + pngPath = os.path.join(mai_cover_path, '0.png') if os.path.exists(pngPath): temp = Image.open(pngPath).convert('RGB') diff --git a/modules/maimai/libraries/maimaidx_music.py b/modules/maimai/libraries/maimaidx_music.py index c305e461ad..9bfc441adb 100644 --- a/modules/maimai/libraries/maimaidx_music.py +++ b/modules/maimai/libraries/maimaidx_music.py @@ -10,7 +10,7 @@ from .maimaidx_mapping import * -def get_cover_len5_id(mid) -> str: +def get_cover_len5_id(mid: Union[int, str]) -> str: mid = int(mid) if 10000 < mid <= 11000: mid -= 10000 diff --git a/modules/meme/moegirl.py b/modules/meme/moegirl.py index f7a03ae837..a2cfc047e5 100644 --- a/modules/meme/moegirl.py +++ b/modules/meme/moegirl.py @@ -7,12 +7,12 @@ async def moegirl(term: str, locale: Locale): - result = await query_pages(QueryInfo('https://mzh.moegirl.org.cn/api.php', headers={'accept': '*/*', - 'accept-encoding': 'gzip, deflate', - 'accept-language': 'zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7,en-GB;q=0.6', - 'content-type': 'application/json', - 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36 Edg/96.0.1054.62'}, - locale=locale.locale), + result = await query_pages(QueryInfo.assign('https://mzh.moegirl.org.cn/api.php', headers={'accept': '*/*', + 'accept-encoding': 'gzip, deflate', + 'accept-language': 'zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7,en-GB;q=0.6', + 'content-type': 'application/json', + 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36 Edg/96.0.1054.62'}, + locale=locale.locale), term) msg = '' if result['msg_list']: @@ -25,13 +25,13 @@ async def moegirl(term: str, locale: Locale): r'(?<=是:\[)(.*?)(?=\]。)', msg_item.text).group(0) Logger.debug(redirect) if redirect: - wait = await query_pages(QueryInfo('https://mzh.moegirl.org.cn/api.php', headers={'accept': '*/*', - 'accept-encoding': 'gzip, deflate', - 'accept-language': 'zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7,en-GB;q=0.6', - 'content-type': 'application/json', - 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36 Edg/96.0.1054.62' - }, - locale=locale.locale), redirect) + wait = await query_pages(QueryInfo.assign('https://mzh.moegirl.org.cn/api.php', headers={'accept': '*/*', + 'accept-encoding': 'gzip, deflate', + 'accept-language': 'zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7,en-GB;q=0.6', + 'content-type': 'application/json', + 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36 Edg/96.0.1054.62' + }, + locale=locale.locale), redirect) msg += wait['msg_list'][0].text return f'[{locale.t("meme.message.moegirl")}] {msg}' diff --git a/modules/wiki/audit.py b/modules/wiki/audit.py index 525646178b..0b86cea5ad 100644 --- a/modules/wiki/audit.py +++ b/modules/wiki/audit.py @@ -121,8 +121,8 @@ async def _(msg: Bot.MessageSession): for im in block_image: send_msgs.append(Image(im)) if send_msgs: - await msg.finish(send_msgs) legacy = False + await msg.finish(send_msgs) if legacy: wikis = [] if allow_list: diff --git a/modules/wiki/inline.py b/modules/wiki/inline.py index f726c3d729..79f5d96fa5 100644 --- a/modules/wiki/inline.py +++ b/modules/wiki/inline.py @@ -201,7 +201,7 @@ async def _callback(msg: Bot.MessageSession): i_msg_lst.append(I18NContext('wiki.message.invalid_section.select')) i_msg_lst.append(I18NContext('message.reply.prompt')) - async def _callback(msg: Bot.MessageSession, forum_data=forum_data, get_page=get_page): + async def _callback(msg: Bot.MessageSession): display = msg.as_display(text_only=True) if isint(display) and int(display) <= len(forum_data) - 1: await query_pages(msg, title=forum_data[display]['text'], diff --git a/modules/wiki/search.py b/modules/wiki/search.py index 80f0b93f11..823426ffdc 100644 --- a/modules/wiki/search.py +++ b/modules/wiki/search.py @@ -15,7 +15,7 @@ async def _(msg: Bot.MessageSession, pagename: str): await search_pages(msg, pagename) -async def search_pages(msg: Bot.MessageSession, title: Union[str, list, tuple], use_prefix=True): +async def search_pages(msg: Bot.MessageSession, title: Union[str, list, tuple], use_prefix: bool = True): target = WikiTargetInfo(msg) start_wiki = target.get_start_wiki() interwiki_list = target.get_interwikis() diff --git a/modules/wiki/utils/mapping.py b/modules/wiki/utils/mapping.py index b9f9f1ef22..97d6ea0c63 100644 --- a/modules/wiki/utils/mapping.py +++ b/modules/wiki/utils/mapping.py @@ -2,11 +2,20 @@ # re.compile(r'.*runescape\.wiki'), ] -generate_screenshot_v2_blocklist = ['https://mzh.moegirl.org.cn', 'https://zh.moegirl.org.cn'] - -special_namespace_list = ['special', '特殊'] +infobox_elements = ['div#infoboxborder', + '.arcaeabox', + '.infobox', + '.infoboxtable', + '.infotemplatebox', + '.moe-infobox', + '.notaninfobox', + '.portable-infobox', + '.rotable', + '.skin-infobox', + '.tpl-infobox', + ] -random_title_list = ['random', '随机页面', '隨機頁面'] +generate_screenshot_v2_blocklist = ['https://mzh.moegirl.org.cn', 'https://zh.moegirl.org.cn'] redirect_list = {'https://zh.moegirl.org.cn/api.php': 'https://mzh.moegirl.org.cn/api.php', # 萌娘百科强制使用移动版 API 'https://minecraft.fandom.com/api.php': 'https://minecraft.wiki/api.php', # no more Fandom then diff --git a/modules/wiki/utils/screenshot_image.py b/modules/wiki/utils/screenshot_image.py index d5fa43cfcc..ab6520ad2b 100644 --- a/modules/wiki/utils/screenshot_image.py +++ b/modules/wiki/utils/screenshot_image.py @@ -17,14 +17,12 @@ from core.logger import Logger from core.utils.http import download from core.utils.web_render import webrender - -elements = ['.notaninfobox', '.portable-infobox', '.infobox', '.tpl-infobox', '.infoboxtable', '.infotemplatebox', - '.skin-infobox', '.arcaeabox', '.moe-infobox', '.rotable'] +from .mapping import infobox_elements async def generate_screenshot_v2(page_link: str, section: str = None, allow_special_page=False, content_mode=False, use_local=True, element=None) -> Union[List[PILImage], bool]: - elements_ = elements.copy() + elements_ = infobox_elements.copy() if element and isinstance(element, List): elements_ += element if not Info.web_render_status: diff --git a/modules/wiki/utils/wikilib.py b/modules/wiki/utils/wikilib.py index 1ab5b54dcd..3a8ba6e0f4 100644 --- a/modules/wiki/utils/wikilib.py +++ b/modules/wiki/utils/wikilib.py @@ -3,6 +3,7 @@ import re import traceback import urllib.parse +from copy import deepcopy from typing import Union, Dict, List import orjson as json @@ -21,6 +22,8 @@ from .dbutils import WikiSiteInfo as DBSiteInfo, Audit from .mapping import * +from attrs import define + default_locale = Config("default_locale", cfg_type=str) enable_tos = Config('enable_tos', True) @@ -41,107 +44,70 @@ class PageNotFound(Exception): pass +@define class QueryInfo: - def __init__(self, api, headers=None, prefix=None, locale=None): - self.api = api - self.headers = headers if headers else { - 'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6'} - self.prefix = prefix - self.locale = Locale(locale if locale else default_locale) - + api: str + headers: Dict[str, str] = { + 'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6'} + prefix: str = '' + locale: Locale = Locale(default_locale) -class WikiInfo: - def __init__(self, - api: str = '', - articlepath: str = '', - extensions=None, - interwiki=None, - realurl: str = '', - name: str = '', - namespaces=None, - namespaces_local=None, - namespacealiases=None, - in_allowlist=False, - in_blocklist=False, - script: str = '', - logo_url: str = ''): - if not extensions: - extensions = [] - if not interwiki: - interwiki = {} - self.api = api - self.articlepath = articlepath - self.extensions = extensions - self.interwiki = interwiki - self.realurl = realurl - self.name = name - self.namespaces = namespaces - self.namespaces_local = namespaces_local - self.namespacealiases = namespacealiases - self.in_allowlist = in_allowlist - self.in_blocklist = in_blocklist - self.script = script - self.logo_url = logo_url + @classmethod + def assign(cls, api: str, headers: Dict[str, str], prefix: str = '', locale: str = default_locale): + return deepcopy(cls(api=api, headers=headers, prefix=prefix, locale=Locale(locale))) +@define +class WikiInfo: + api: str = '' + articlepath: str = '' + extensions: List[str] = [] + interwiki: Dict[str, str] = {} + realurl: str = '' + name: str = '' + namespaces: Dict[str, int] = {} + namespaces_local: Dict[str, str] = {} + namespacealiases: Dict[str, str] = {} + in_allowlist: bool = False + in_blocklist: bool = False + script: str = '' + logo_url: str = '' + + +@define class WikiStatus: - def __init__(self, - available: bool, - value: Union[WikiInfo, bool], - message: str): - self.available = available - self.value = value - self.message = message + available: bool + value: Union[WikiInfo, bool] + message: str +@define class PageInfo: - def __init__(self, - info: WikiInfo, - title: str, - id: int = -1, - before_title: str = None, - link: str = None, - edit_link: str = None, - file: str = None, - desc: str = None, - args: str = None, - selected_section: str = None, - sections: List[str] = None, - interwiki_prefix: str = '', - status: bool = True, - templates: List[str] = None, - before_page_property: str = 'page', - page_property: str = 'page', - has_template_doc: bool = False, - invalid_namespace: Union[str, bool] = False, - possible_research_title: List[str] = None, - body_class: List[str] = None - ): - self.info = info - self.id = id - self.title = title - self.before_title = before_title - self.link = link - self.edit_link = edit_link - self.file = file - self.desc = desc - self.args = args - self.selected_section = selected_section - self.sections = sections - self.interwiki_prefix = interwiki_prefix - self.templates = templates - self.status = status - self.before_page_property = before_page_property - self.page_property = page_property - self.has_template_doc = has_template_doc - self.invalid_namespace = invalid_namespace - self.possible_research_title = possible_research_title - self.invalid_section = False - self.body_class = body_class - self.is_talk_page = False - self.is_forum = False - self.forum_data = {} - self.is_forum_topic = False + info: WikiInfo + title: str + id: int = -1 + before_title: str = None + link: str = None + edit_link: str = None + file: str = None + desc: str = None + args: str = None + selected_section: str = None + sections: List[str] = None + interwiki_prefix: str = '' + status: bool = True + templates: List[str] = None + before_page_property: str = 'page' + page_property: str = 'page' + has_template_doc: bool = False + invalid_namespace: Union[str, bool] = False + possible_research_title: List[str] = None + body_class: List[str] = None + invalid_section: bool = False + is_talk_page: bool = False + is_forum: bool = False + is_forum_topic: bool = False + forum_data: dict = {} class WikiLib: @@ -370,9 +336,9 @@ async def get_html_to_text(self, page_name, section=None): h.ignore_tables = True h.single_line_break = True parse_text = get_parse['parse']['text']['*'] - if len(parse_text) > 65535: + t = h.handle(parse_text) + if len(t) > 65535: return self.locale.t("wiki.message.utils.wikilib.error.text_too_long") - t = h.handle(get_parse['parse']['text']['*']) if section: for i in range(1, 7): s = re.split(r'(.*' + '#' * i + r'[^#].*\[.*?])', t, re.M | re.S) @@ -566,7 +532,7 @@ async def parse_page_info(self, title: str = None, pageid: int = None, inline=Fa get_page = await self.get_json(**query_string) query = get_page.get('query') if not query: - return PageInfo(title=title, link=None, desc=self.locale.t("wiki.message.utils.wikilib.error.empty"), + return PageInfo(title=title, desc=self.locale.t("wiki.message.utils.wikilib.error.empty"), info=self.wiki_info) redirects_: List[Dict[str, str]] = query.get('redirects') diff --git a/modules/wiki/wiki.py b/modules/wiki/wiki.py index 5ce649eaef..4ca21e25a6 100644 --- a/modules/wiki/wiki.py +++ b/modules/wiki/wiki.py @@ -1,6 +1,6 @@ import asyncio import re -from typing import Union +from typing import Optional, Union import filetype @@ -15,7 +15,7 @@ from core.utils.image_table import image_table_render, ImageTable from core.utils.text import isint from .utils.dbutils import WikiTargetInfo -from .utils.mapping import generate_screenshot_v2_blocklist, special_namespace_list, random_title_list +from .utils.mapping import generate_screenshot_v2_blocklist from .utils.screenshot_image import generate_screenshot_v1, generate_screenshot_v2 from .utils.wikilib import WikiLib, PageInfo, InvalidWikiError, QueryInfo @@ -57,10 +57,17 @@ async def _(msg: Bot.MessageSession, pageid: str): await query_pages(msg, pageid=pageid, iw=iw, lang=lang) -async def query_pages(session: Union[Bot.MessageSession, QueryInfo], title: Union[str, list, tuple] = None, - pageid: str = None, iw: str = None, lang: str = None, - template=False, mediawiki=False, use_prefix=True, inline_mode=False, preset_message=None, - start_wiki_api=None): +async def query_pages(session: Union[Bot.MessageSession, QueryInfo], + title: Optional[Union[str, list, tuple]] = None, + pageid: Optional[str] = None, + iw: Optional[str] = None, + lang: Optional[str] = None, + preset_message: Optional[str] = None, + start_wiki_api: Optional[str] = None, + template: bool = False, + mediawiki: bool = False, + use_prefix: bool = True, + inline_mode: bool = False): if isinstance(session, MessageSession): target = WikiTargetInfo(session) start_wiki = target.get_start_wiki() @@ -148,24 +155,20 @@ async def query_pages(session: Union[Bot.MessageSession, QueryInfo], title: Unio try: tasks = [] for rd in ready_for_query_pages: - if rd.split(":")[0].lower() in special_namespace_list and rd.split(":")[1].lower() in random_title_list: - tasks.append(asyncio.create_task( - WikiLib(q, headers, locale=session.locale.locale).random_page())) - else: - if template: - rd = f'Template:{rd}' - if mediawiki: - rd = f'MediaWiki:{rd}' - tasks.append(asyncio.ensure_future( - WikiLib(q, headers, locale=session.locale.locale) - .parse_page_info(title=rd, inline=inline_mode, lang=lang))) + if template: + rd = f'Template:{rd}' + if mediawiki: + rd = f'MediaWiki:{rd}' + tasks.append(asyncio.ensure_future( + WikiLib(q, headers, locale=session.locale.locale) + .parse_page_info(title=rd, inline=inline_mode, lang=lang))) for rdp in ready_for_query_ids: tasks.append(asyncio.ensure_future( WikiLib(q, headers, locale=session.locale.locale) .parse_page_info(pageid=rdp, inline=inline_mode, lang=lang))) query = await asyncio.gather(*tasks) for result in query: - Logger.debug(result.__dict__) + Logger.debug(result) r: PageInfo = result display_title = None display_before_title = None @@ -259,12 +262,12 @@ async def _callback(msg: Bot.MessageSession): i_msg_lst.append(Plain(session.locale.t('wiki.message.invalid_section.select'))) i_msg_lst.append(Plain(session.locale.t('message.reply.prompt'))) - async def _callback(msg: Bot.MessageSession, forum_data=forum_data, r=r): - display = msg.as_display(text_only=True) - if isint(display) and int(display) <= len(forum_data) - 1: - await query_pages(session, title=forum_data[display]['text'], start_wiki_api=r.info.api) + async def _callback(msg: Bot.MessageSession): + display = msg.as_display(text_only=True) + if isint(display) and int(display) <= len(forum_data) - 1: + await query_pages(session, title=forum_data[display]['text'], start_wiki_api=r.info.api) - await session.send_message(i_msg_lst, callback=_callback) + await session.send_message(i_msg_lst, callback=_callback) else: plain_slice = [] diff --git a/modules/wikilog/__init__.py b/modules/wikilog/__init__.py index f13df98f44..79e281b4c8 100644 --- a/modules/wikilog/__init__.py +++ b/modules/wikilog/__init__.py @@ -2,7 +2,7 @@ import orjson as json -from core.builtins import Bot +from core.builtins import Bot, I18NContext from core.component import module from core.config import Config from core.constants import Info @@ -154,6 +154,8 @@ async def _(msg: Bot.MessageSession, apilink: str, logtype: str): @wikilog.handle('bot enable {{wikilog.help.bot.enable}}', required_superuser=True) @wikilog.handle('bot disable {{wikilog.help.bot.disable}}', required_superuser=True) +@wikilog.handle('keepalive enable {{wikilog.help.keepalive.enable}}', required_superuser=True) +@wikilog.handle('keepalive disable {{wikilog.help.keepalive.disable}}', required_superuser=True) async def _(msg: Bot.MessageSession, apilink: str): t = WikiLogUtil(msg) infos = json.loads(t.query.infos) @@ -161,7 +163,11 @@ async def _(msg: Bot.MessageSession, apilink: str): status = await wiki_info.check_wiki_available() if status.available: if status.value.api in infos: - if t.set_use_bot(status.value.api, 'enable' in msg.parsed_msg): + if 'keepalive' in msg.parsed_msg: + r = t.set_keep_alive(status.value.api, 'enable' in msg.parsed_msg) + else: + r = t.set_use_bot(status.value.api, 'enable' in msg.parsed_msg) + if r: await msg.finish(msg.locale.t('wikilog.message.config.wiki.success', wiki=status.value.name)) else: await msg.finish(msg.locale.t('wikilog.message.filter.set.failed')) @@ -238,3 +244,23 @@ async def _(fetch: Bot.FetchTarget, ctx: Bot.ModuleHookContext): for x in rc: await ft.send_direct_message(f'{wiki_info.name}\n{x}' if len(matched[id_]) > 1 else x) + + +@wikilog.hook('keepalive') +async def _(fetch: Bot.FetchTarget, ctx: Bot.ModuleHookContext): + data_ = WikiLogUtil.return_all_data() + for target in data_: + for wiki in data_[target]: + if 'keep_alive' in data_[target][wiki] and data_[target][wiki]['keep_alive']: + fetch_target = await fetch.fetch_target(target) + if fetch_target: + try: + wiki_ = WikiLib(wiki) + await wiki_.fixup_wiki_info() + get_user_info = await wiki_.get_json(action='query', meta='userinfo') + if n := get_user_info['query']['userinfo']['name']: + await fetch_target.send_direct_message(I18NContext('wikilog.message.keepalive.logged.as', name=n, + wiki=wiki_.wiki_info.name)) + except Exception as e: + Logger.error(f'Keep alive failed: {e}') + await fetch_target.send_direct_message(I18NContext('wikilog.message.keepalive.failed', link=wiki)) diff --git a/modules/wikilog/dbutils.py b/modules/wikilog/dbutils.py index 21c467078c..ce6b9f36f3 100644 --- a/modules/wikilog/dbutils.py +++ b/modules/wikilog/dbutils.py @@ -31,12 +31,11 @@ def conf_wiki(self, apilink: dict, add=False, reset=False): infos = json.loads(self.query.infos) if add or reset: if apilink not in infos or reset: - infos[apilink] = {'AbuseLog': {'enable': False, - 'filters': ['*']}, - 'RecentChanges': {'enable': False, - 'filters': ['*'], - 'rcshow': ['!bot']}, - 'use_bot': False} + infos[apilink] = {} + infos[apilink].setdefault('AbuseLog', {'enable': False, 'filters': ['*']}) + infos[apilink].setdefault('RecentChanges', {'enable': False, 'filters': ['*'], 'rcshow': ['!bot']}) + infos[apilink].setdefault('use_bot', False) + infos[apilink].setdefault('keep_alive', False) self.query.infos = json.dumps(infos) session.commit() session.expire_all() @@ -122,6 +121,26 @@ def get_use_bot(self, apilink: str): return infos[apilink]['use_bot'] return False + @retry(stop=stop_after_attempt(3), reraise=True) + @auto_rollback_error + def set_keep_alive(self, apilink: str, keep_alive: bool): # oh no it smells shit, will rewrite it in the future + infos = json.loads(self.query.infos) + if apilink in infos: + infos[apilink]['keep_alive'] = keep_alive + self.query.infos = json.dumps(infos) + session.commit() + session.expire_all() + return True + return False + + @retry(stop=stop_after_attempt(3), reraise=True) + @auto_rollback_error + def get_keep_alive(self, apilink: str): + infos = json.loads(self.query.infos) + if apilink in infos and 'keep_alive' in infos[apilink]: + return infos[apilink]['keep_alive'] + return False + @staticmethod def return_all_data(): all_data = session.query(WikiLogTargetSetInfo).all() diff --git a/modules/wikilog/locales/zh_cn.json b/modules/wikilog/locales/zh_cn.json index db43fb1b39..cc47a93115 100644 --- a/modules/wikilog/locales/zh_cn.json +++ b/modules/wikilog/locales/zh_cn.json @@ -36,5 +36,10 @@ "wikilog.help.rcshow.reset": "重置最近更改日志设置的筛选条件。", "wikilog.help.list": "列出所有已设置的内容。", "wikilog.help.api.get": "获取指定 Wiki 的指定 API 数据链接。", - "wikilog.message.untrust.wiki": "失败:此 Wiki 当前没有加入机器人的白名单列表中。" + "wikilog.help.keepalive.enable": "启用机器人登录状态检查。", + "wikilog.help.keepalive.disable": "禁用机器人登录状态检查。", + "wikilog.message.untrust.wiki": "失败:此 Wiki 当前没有加入机器人的白名单列表中。", + "wikilog.message.keepalive.logged.as": "[Keep alive] 以 ${name} 的身份登录:${wiki}", + "wikilog.message.keepalive.failed": "[Keep alive] 登录状态检查失败:${link}" + } diff --git a/modules/wordle/__init__.py b/modules/wordle/__init__.py index 8169b43f95..1419faae8f 100644 --- a/modules/wordle/__init__.py +++ b/modules/wordle/__init__.py @@ -124,7 +124,7 @@ def format_board(self): return '\n'.join(''.join(row) for row in formatted) def is_game_over(self): - return bool(len(self.board) != 0 and self.word == self.board[-1]) + return bool(len(self.board) != 0 and (self.word == self.board[-1]) or (len(self.board) > 5)) @staticmethod def from_random_word(): @@ -231,11 +231,7 @@ async def _(msg: Bot.MessageSession): await msg.send_message(start_msg) while board.get_trials() <= 6 and play_state.check() and not board.is_game_over(): - if not play_state.check(): - return wait = await msg.wait_next_message(timeout=None) - if not play_state.check(): - return word = wait.as_display(text_only=True).strip().lower() if len(word) != 5 or not (word.isalpha() and word.isascii()): continue diff --git a/poetry.lock b/poetry.lock index 9368d6d981..2176a95269 100644 --- a/poetry.lock +++ b/poetry.lock @@ -372,22 +372,22 @@ reference = "mirrors" [[package]] name = "attrs" -version = "23.2.0" +version = "24.2.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" files = [ - {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, - {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, + {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, + {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, ] [package.extras] -cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] -dev = ["attrs[tests]", "pre-commit"] -docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] -tests = ["attrs[tests-no-zope]", "zope-interface"] -tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] -tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] +benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] [package.source] type = "legacy" @@ -2378,15 +2378,19 @@ name = "langconv" version = "0.3.0" description = "A Python library for conversion between Traditional and Simplified Chinese, inspired by MediaWiki's LanguageConverter." optional = false -python-versions = ">=3.9,<4.0" -files = [ - {file = "langconv-0.3.0-py3-none-any.whl", hash = "sha256:dfd3484e0373a07ed8271ab60293648e9d216ef460c58ff7dda80315292d0566"}, - {file = "langconv-0.3.0.tar.gz", hash = "sha256:816bedf81db368a410959293a31aeebe4cd75de516427b50370727003f3bd3ce"}, -] +python-versions = ">=3.9" +files = [] +develop = false [package.dependencies] -attrs = ">=23.2.0,<24.0.0" -iso639-lang = ">=2.2.3,<3.0.0" +attrs = "^24.2.0" +iso639-lang = "^2.2.3" + +[package.source] +type = "git" +url = "https://github.com/OasisAkari/langconv.py.git" +reference = "HEAD" +resolved_reference = "975ee0896096b63148f0930db762607a5a2113df" [package.source] type = "legacy" @@ -5317,4 +5321,4 @@ reference = "mirrors" [metadata] lock-version = "2.0" python-versions = "^3.12.0" -content-hash = "da8ebfb92146be8808b96bfbf4e259908ebfdd98702bac00070595d29010fda3" +content-hash = "78bd3610a1b3fc1005ea68cf7bf4f181df6a14bc9cfacef8ede379835fc559c5" diff --git a/pyproject.toml b/pyproject.toml index c1cb892832..7584ac6acf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ tiktoken = "^0.8.0" pycryptodome = "^3.18.0" khl-py = "^0.3.16" matrix-nio = "^0.25.0" -attrs = "^23.1.0" +attrs = "^24.2.0" uvicorn = {extras = ["standard"], version = "^0.32.0"} pyjwt = {extras = ["crypto"], version = "^2.8.0"} python-whois = "^0.9.0" @@ -62,7 +62,7 @@ fastapi = "^0.115.0" inputimeout = "^1.0.4" prompt-toolkit = "^3.0.47" emoji = "^2.12.1" -langconv = "0.3.0" +langconv = {git = "https://github.com/OasisAkari/langconv.py.git"} ff3 = "^1.0.2" orjson = "^3.10.9" jinja2 = "^3.1.4" diff --git a/requirements.txt b/requirements.txt index 73262db4a4..7eec73aea1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ anyio==4.6.2.post1 ; python_full_version >= "3.12.0" and python_version < "4.0" appdirs==1.4.4 ; python_full_version >= "3.12.0" and python_full_version < "4.0.0" apscheduler==3.10.4 ; python_version >= "3.12" and python_version < "4.0" asyncio-dgram==2.2.0 ; python_full_version >= "3.12.0" and python_version < "4" -attrs==23.2.0 ; python_version >= "3.12" and python_version < "4.0" +attrs==24.2.0 ; python_version >= "3.12" and python_version < "4.0" backoff==2.2.1 ; python_full_version >= "3.12.0" and python_version < "4.0" beautifulsoup4==4.12.3 ; python_full_version >= "3.12.0" and python_full_version < "4.0.0" blinker==1.8.2 ; python_full_version >= "3.12.0" and python_full_version < "4.0.0" @@ -62,7 +62,7 @@ hyperframe==6.0.1 ; python_full_version >= "3.12.0" and python_full_version < "4 identify==2.6.1 ; python_version >= "3.12" and python_version < "4.0" idna==3.10 ; python_version >= "3.12" and python_version < "4.0" inputimeout==1.0.4 ; python_full_version >= "3.12.0" and python_full_version < "4.0.0" -iso639-lang==2.5.0 ; python_full_version >= "3.12.0" and python_version < "4.0" +iso639-lang==2.5.0 ; python_full_version >= "3.12.0" and python_full_version < "4.0.0" itsdangerous==2.2.0 ; python_full_version >= "3.12.0" and python_full_version < "4.0.0" jaraco-context==6.0.1 ; python_full_version >= "3.12.0" and python_full_version < "4.0.0" jinja2==3.1.4 ; python_full_version >= "3.12.0" and python_full_version < "4.0.0" @@ -76,8 +76,8 @@ kiwisolver==1.4.7 ; python_full_version >= "3.12.0" and python_full_version < "4 langchain-core==0.3.15 ; python_full_version >= "3.12.0" and python_version < "4.0" langchain-text-splitters==0.3.2 ; python_full_version >= "3.12.0" and python_version < "4.0" langchain==0.3.7 ; python_full_version >= "3.12.0" and python_version < "4.0" -langconv==0.3.0 ; python_full_version >= "3.12.0" and python_version < "4.0" -langsmith==0.1.140 ; python_full_version >= "3.12.0" and python_version < "4.0" +langconv @ git+https://github.com/OasisAkari/langconv.py.git@975ee0896096b63148f0930db762607a5a2113df ; python_full_version >= "3.12.0" and python_full_version < "4.0.0" +langsmith==0.1.139 ; python_full_version >= "3.12.0" and python_version < "4.0" loguru==0.7.2 ; python_full_version >= "3.12.0" and python_full_version < "4.0.0" magic-filter==1.0.12 ; python_full_version >= "3.12.0" and python_full_version < "4.0.0" markupsafe==3.0.2 ; python_full_version >= "3.12.0" and python_full_version < "4.0.0" diff --git a/schedulers/wiki_bot.py b/schedulers/wiki_bot.py index 361c66c703..6c44dcece4 100644 --- a/schedulers/wiki_bot.py +++ b/schedulers/wiki_bot.py @@ -12,4 +12,5 @@ async def login_bots(): Logger.info('Start login wiki bot account...') await BotAccount.login() await JobQueue.trigger_hook_all('wiki_bot.login_wiki_bots', cookies=BotAccount.cookies) + await JobQueue.trigger_hook_all('wikilog.keepalive') Logger.info('Successfully login wiki bot account.')