Skip to content

Commit

Permalink
chore: clean codes
Browse files Browse the repository at this point in the history
  • Loading branch information
Soulter committed Jun 4, 2024
1 parent 045c415 commit b943c62
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 81 deletions.
101 changes: 21 additions & 80 deletions astrbot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,13 @@
# 百度内容审核实例
baidu_judge = None

# CLI
PLATFORM_CLI = 'cli'

# 全局对象
_global_object: GlobalObject = None


def privider_chooser(cfg):
l = []
if 'openai' in cfg and len(cfg['openai']['key']) > 0 and cfg['openai']['key'][0] is not None:
if 'openai' in cfg and len(cfg['openai']['key']) and cfg['openai']['key'][0]:
l.append('openai_official')
return l

Expand Down Expand Up @@ -157,7 +154,7 @@ def init():
logger.info("独立会话配置错误: "+str(e))

nick_qq = cc.get("nick_qq", None)
if nick_qq == None:
if not nick_qq:
nick_qq = ("ai", "!", "!")
if isinstance(nick_qq, str):
nick_qq = (nick_qq,)
Expand All @@ -177,28 +174,24 @@ def init():
logger.info(
f"成功载入 {len(_global_object.cached_plugins)} 个插件")
else:
logger.info(err)
logger.error(err)

if chosen_provider is None:
llm_command_instance[NONE_LLM] = _command
chosen_provider = NONE_LLM

logger.info("正在载入机器人消息平台")
# logger.info("提示:需要添加管理员 ID 才能使用 update/plugin 等指令),可在可视化面板添加。(如已添加可忽略)")
platform_str = ""
# GOCQ
if 'gocqbot' in cfg and cfg['gocqbot']['enable']:
logger.info("启用 QQ_GOCQ 机器人消息平台")
threading.Thread(target=run_gocq_bot, args=(
cfg, _global_object), daemon=True).start()
platform_str += "QQ_GOCQ,"

# QQ频道
if 'qqbot' in cfg and cfg['qqbot']['enable'] and cfg['qqbot']['appid'] != None:
logger.info("启用 QQ_OFFICIAL 机器人消息平台")
threading.Thread(target=run_qqchan_bot, args=(
cfg, _global_object), daemon=True).start()
platform_str += "QQ_OFFICIAL,"

# 初始化dashboard
_global_object.dashboard_data = DashBoardData(
Expand All @@ -219,19 +212,15 @@ def init():
logger.info(
"如果有任何问题, 请在 https://github.com/Soulter/AstrBot 上提交 issue 或加群 322154837。")
logger.info("请给 https://github.com/Soulter/AstrBot 点个 star。")
if platform_str == '':
platform_str = "(未启动任何平台,请前往面板添加)"
logger.info(f"🎉 项目启动完成")

dashboard_thread.join()


'''
运行 QQ_OFFICIAL 机器人
'''


def run_qqchan_bot(cfg: dict, global_object: GlobalObject):
'''
运行 QQ_OFFICIAL 机器人
'''
try:
from model.platform.qq_official import QQOfficial
qqchannel_bot = QQOfficial(
Expand All @@ -244,14 +233,11 @@ def run_qqchan_bot(cfg: dict, global_object: GlobalObject):
logger.error(r"如果您是初次启动,请前往可视化面板填写配置。详情请看:https://astrbot.soulter.top/center/。")


'''
运行 QQ_GOCQ 机器人
'''


def run_gocq_bot(cfg: dict, _global_object: GlobalObject):
'''
运行 QQ_GOCQ 机器人
'''
from model.platform.qq_gocq import QQGOCQ

noticed = False
host = cc.get("gocq_host", "127.0.0.1")
port = cc.get("gocq_websocket_port", 6700)
Expand All @@ -278,12 +264,10 @@ def run_gocq_bot(cfg: dict, _global_object: GlobalObject):
input("启动QQ机器人出现错误"+str(e))


'''
检查发言频率
'''


def check_frequency(id) -> bool:
'''
检查发言频率
'''
ts = int(time.time())
if id in user_frequency:
if ts-user_frequency[id]['time'] > frequency_time:
Expand Down Expand Up @@ -324,11 +308,10 @@ async def oper_msg(message: AstrBotMessage,
platform: str 所注册的平台的名称。如果没有注册,将抛出一个异常。
"""
global chosen_provider, _global_object
message_str = ''
session_id = session_id
role = role
message_str = message.message_str
hit = False # 是否命中指令
command_result = () # 调用指令返回的结果
llm_result_str = ""

# 获取平台实例
reg_platform: RegisteredPlatform = None
Expand All @@ -342,35 +325,13 @@ async def oper_msg(message: AstrBotMessage,
# 统计数据,如频道消息量
await record_message(platform, session_id)

for i in message.message:
if isinstance(i, Plain):
message_str += i.text.strip()
if message_str == "":
if not message_str:
return MessageResult("Hi~")

# 检查发言频率
if not check_frequency(message.sender.user_id):
return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置{frequency_time}秒内只能提问{frequency_count}次。')

# 检查是否是更换语言模型的请求
temp_switch = ""
if message_str.startswith('/gpt'):
target = chosen_provider
if message_str.startswith('/gpt'):
target = OPENAI_OFFICIAL
l = message_str.split(' ')
if len(l) > 1 and l[1] != "":
# 临时对话模式,先记录下之前的语言模型,回答完毕后再切回
temp_switch = chosen_provider
chosen_provider = target
message_str = l[1]
else:
chosen_provider = target
cc.put("chosen_provider", chosen_provider)
return MessageResult(f"已切换至【{chosen_provider}】")

llm_result_str = ""

# check commands and plugins
message_str_no_wake_prefix = message_str
for wake_prefix in _global_object.nick: # nick: tuple
Expand Down Expand Up @@ -400,7 +361,7 @@ async def oper_msg(message: AstrBotMessage,
logger.info("一条消息由于 Bot 未启动任何语言模型并且未触发指令而将被忽略。")
return
try:
if llm_wake_prefix != "" and not message_str.startswith(llm_wake_prefix):
if llm_wake_prefix and not message_str.startswith(llm_wake_prefix):
return
# check image url
image_url = None
Expand All @@ -418,7 +379,7 @@ async def oper_msg(message: AstrBotMessage,
message_str = message_str[3:]
web_sch_flag = True
else:
message_str += " " + cc.get("llm_env_prompt", "")
message_str += "\n" + cc.get("llm_env_prompt", "")
if chosen_provider == OPENAI_OFFICIAL:
if _global_object.web_search or web_sch_flag:
official_fc = chosen_provider == OPENAI_OFFICIAL
Expand All @@ -431,32 +392,15 @@ async def oper_msg(message: AstrBotMessage,
logger.error(f"调用异常:{traceback.format_exc()}")
return MessageResult(f"调用异常。详细原因:{str(e)}")

# 切换回原来的语言模型
if temp_switch != "":
chosen_provider = temp_switch

if hit:
# 有指令或者插件触发
# command_result 是一个元组:(指令调用是否成功, 指令返回的文本结果, 指令类型)
if command_result == None:
if not command_result:
return
command = command_result[2]

if not command_result[0]:
return MessageResult(f"指令调用错误: \n{str(command_result[1])}")

# 画图指令
if command == 'draw':
# 保存到本地
path = await gu.download_image_by_url(command_result[1])
return MessageResult([Image.fromFileSystem(path)])
# 其他指令
else:
try:
return MessageResult(command_result[1])
except BaseException as e:
return MessageResult(f"回复消息出错: {str(e)}")
return
if isinstance(command_result[1], (list, str)):
return MessageResult(command_result[1])

# 敏感过滤
# 过滤不合适的词
Expand All @@ -468,7 +412,4 @@ async def oper_msg(message: AstrBotMessage,
if not check:
return MessageResult(f"你的提问得到的回复【百度内容审核】未通过,不予回复。\n\n{msg}")
# 发送信息
try:
return MessageResult(llm_result_str)
except BaseException as e:
logger.info("回复消息错误: \n"+str(e))
return MessageResult(llm_result_str)
6 changes: 5 additions & 1 deletion model/command/openai_official.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from model.command.command import Command
from model.provider.openai_official import ProviderOpenAIOfficial, MODELS
from util.personality import personalities
from util.general_utils import download_image_by_url
from type.types import GlobalObject
from type.command import CommandItem
from SparkleLogging.utils.core import LogManager
from logging import Logger
from openai._exceptions import NotFoundError
from nakuru.entities.components import Image

logger: Logger = LogManager.GetLogger(log_name='astrbot-core')

Expand Down Expand Up @@ -248,4 +250,6 @@ async def draw(self, message: str):
return False, "未启用 OpenAI 官方 API", "draw"
message = message.removeprefix("/").removeprefix("画")
img_url = await self.provider.image_generate(message)
return True, img_url, "draw"
p = await download_image_by_url(url=img_url)
with open(p, 'rb') as f:
return True, [Image.fromBytes(f.read())], "draw"

0 comments on commit b943c62

Please sign in to comment.