Skip to content

Commit

Permalink
feat: 重构图片消息传递逻辑 (#957, #955)
Browse files Browse the repository at this point in the history
  • Loading branch information
RockChinQ committed Dec 24, 2024
1 parent 535c4a8 commit 12cfce3
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 54 deletions.
4 changes: 2 additions & 2 deletions pkg/pipeline/preproc/preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ async def process(
)
elif isinstance(me, platform_message.Image):
if self.ap.provider_cfg.data['enable-vision'] and (self.ap.provider_cfg.data['runner'] != 'local-agent' or query.use_model.vision_supported):
if me.url is not None:
if me.base64 is not None:
content_list.append(
llm_entities.ContentElement.from_image_url(str(me.url))
llm_entities.ContentElement.from_image_base64(me.base64)
)

query.user_message = llm_entities.Message(
Expand Down
28 changes: 15 additions & 13 deletions pkg/platform/sources/aiocqhttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@
import datetime

import aiocqhttp
import aiohttp

from .. import adapter
from ...pipeline.longtext.strategies import forward
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities

from ...utils import image

class AiocqhttpMessageConverter(adapter.MessageConverter):

@staticmethod
def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
async def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
msg_list = aiocqhttp.Message()

msg_id = 0
Expand Down Expand Up @@ -59,15 +60,15 @@ def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[li
elif type(msg) is forward.Forward:

for node in msg.node_list:
msg_list.extend(AiocqhttpMessageConverter.yiri2target(node.message_chain)[0])
msg_list.extend(await AiocqhttpMessageConverter.yiri2target(node.message_chain)[0])

else:
msg_list.append(aiocqhttp.MessageSegment.text(str(msg)))

return msg_list, msg_id, msg_time

@staticmethod
def target2yiri(message: str, message_id: int = -1):
async def target2yiri(message: str, message_id: int = -1):
message = aiocqhttp.Message(message)

yiri_msg_list = []
Expand All @@ -89,7 +90,8 @@ def target2yiri(message: str, message_id: int = -1):
elif msg.type == "text":
yiri_msg_list.append(platform_message.Plain(text=msg.data["text"]))
elif msg.type == "image":
yiri_msg_list.append(platform_message.Image(url=msg.data["url"]))
image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url'])
yiri_msg_list.append(platform_message.Image(base64=f"data:image/{image_format};base64,{image_base64}"))

chain = platform_message.MessageChain(yiri_msg_list)

Expand All @@ -99,9 +101,9 @@ def target2yiri(message: str, message_id: int = -1):
class AiocqhttpEventConverter(adapter.EventConverter):

@staticmethod
def yiri2target(event: platform_events.Event, bot_account_id: int):
async def yiri2target(event: platform_events.Event, bot_account_id: int):

msg, msg_id, msg_time = AiocqhttpMessageConverter.yiri2target(event.message_chain)
msg, msg_id, msg_time = await AiocqhttpMessageConverter.yiri2target(event.message_chain)

if type(event) is platform_events.GroupMessage:
role = "member"
Expand Down Expand Up @@ -164,8 +166,8 @@ def yiri2target(event: platform_events.Event, bot_account_id: int):
return aiocqhttp.Event.from_payload(payload)

@staticmethod
def target2yiri(event: aiocqhttp.Event):
yiri_chain = AiocqhttpMessageConverter.target2yiri(
async def target2yiri(event: aiocqhttp.Event):
yiri_chain = await AiocqhttpMessageConverter.target2yiri(
event.message, event.message_id
)

Expand Down Expand Up @@ -242,7 +244,7 @@ async def shutdown_trigger_placeholder():
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
aiocq_msg = await AiocqhttpMessageConverter.yiri2target(message)[0]

if target_type == "group":
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
Expand All @@ -255,8 +257,8 @@ async def reply_message(
message: platform_message.MessageChain,
quote_origin: bool = False,
):
aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
aiocq_event = await AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
if quote_origin:
aiocq_msg = aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg

Expand All @@ -276,7 +278,7 @@ def register_listener(
async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id
try:
return await callback(self.event_converter.target2yiri(event), self)
return await callback(await self.event_converter.target2yiri(event), self)
except:
traceback.print_exc()

Expand Down
6 changes: 6 additions & 0 deletions pkg/provider/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class ContentElement(pydantic.BaseModel):

image_url: typing.Optional[ImageURLContentObject] = None

image_base64: typing.Optional[str] = None

def __str__(self):
if self.type == 'text':
return self.text
Expand All @@ -53,6 +55,10 @@ def from_text(cls, text: str):
@classmethod
def from_image_url(cls, image_url: str):
    """Build an ``image_url`` content element that wraps the given URL."""
    url_object = ImageURLContentObject(url=image_url)
    return cls(type='image_url', image_url=url_object)

@classmethod
def from_image_base64(cls, image_base64: str):
    """Build an ``image_base64`` content element holding the given string."""
    fields = {'type': 'image_base64', 'image_base64': image_base64}
    return cls(**fields)


class Message(pydantic.BaseModel):
Expand Down
1 change: 1 addition & 0 deletions pkg/provider/modelmgr/requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ async def preprocess(
@abc.abstractmethod
async def call(
self,
query: core_entities.Query,
model: modelmgr_entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
Expand Down
18 changes: 9 additions & 9 deletions pkg/provider/modelmgr/requesters/anthropicmsgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import typing
import traceback
import base64

import anthropic
import httpx
Expand Down Expand Up @@ -39,6 +40,7 @@ async def initialize(self):

async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
Expand Down Expand Up @@ -70,28 +72,26 @@ async def call(
if isinstance(m.content, str) and m.content.strip() != "":
req_messages.append(m.dict(exclude_none=True))
elif isinstance(m.content, list):
# m.content = [
# c for c in m.content if c.type == "text"
# ]

# if len(m.content) > 0:
# req_messages.append(m.dict(exclude_none=True))

msg_dict = m.dict(exclude_none=True)

for i, ce in enumerate(m.content):
if ce.type == "image_url":
base64_image, image_format = await image.qq_image_url_to_base64(ce.image_url.url)

if ce.type == "image_base64":
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)

alter_image_ele = {
"type": "image",
"source": {
"type": "base64",
"media_type": f"image/{image_format}",
"data": base64_image
"data": image_b64
}
}
msg_dict["content"][i] = alter_image_ele

print(msg_dict)

req_messages.append(msg_dict)

args["messages"] = req_messages
Expand Down
20 changes: 9 additions & 11 deletions pkg/provider/modelmgr/requesters/chatcmpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ async def _make_msg(

async def _closure(
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None,
Expand All @@ -87,8 +88,12 @@ async def _closure(
for msg in messages:
if 'content' in msg and isinstance(msg["content"], list):
for me in msg["content"]:
if me["type"] == "image_url":
me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url'])
if me["type"] == "image_base64":
me["image_url"] = {
"url": me["image_base64"]
}
me["type"] = "image_url"
del me["image_base64"]

args["messages"] = messages

Expand All @@ -102,6 +107,7 @@ async def _closure(

async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
Expand All @@ -118,7 +124,7 @@ async def call(
req_messages.append(msg_dict)

try:
return await self._closure(req_messages, model, funcs)
return await self._closure(query, req_messages, model, funcs)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
except openai.BadRequestError as e:
Expand All @@ -134,11 +140,3 @@ async def call(
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}')

@async_lru.alru_cache(maxsize=128)
async def get_base64_str(
self,
original_url: str,
) -> str:
base64_image, image_format = await image.qq_image_url_to_base64(original_url)
return f"data:image/{image_format};base64,{base64_image}"
22 changes: 8 additions & 14 deletions pkg/provider/modelmgr/requesters/ollamachat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from typing import Union, Mapping, Any, AsyncIterator
import uuid
import json
import base64

import async_lru
import ollama

from .. import entities, errors, requester
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....core import app
from ....core import app, entities as core_entities
from ....utils import image

REQUESTER_NAME: str = "ollama-chat"
Expand Down Expand Up @@ -43,7 +44,7 @@ async def _req(self,
**args
)

async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelInfo,
async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo,
user_funcs: list[tools_entities.LLMFunction] = None) -> (
llm_entities.Message):
args: Any = self.request_cfg['args'].copy()
Expand All @@ -57,9 +58,9 @@ async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelI
for me in msg["content"]:
if me["type"] == "text":
text_content.append(me["text"])
elif me["type"] == "image_url":
image_url = await self.get_base64_str(me["image_url"]['url'])
image_urls.append(image_url)
elif me["type"] == "image_base64":
image_urls.append(me["image_base64"])

msg["content"] = "\n".join(text_content)
msg["images"] = [url.split(',')[1] for url in image_urls]
if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
Expand Down Expand Up @@ -109,6 +110,7 @@ async def _make_msg(

async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
Expand All @@ -122,14 +124,6 @@ async def call(
msg_dict["content"] = "\n".join(part["text"] for part in content)
req_messages.append(msg_dict)
try:
return await self._closure(req_messages, model, funcs)
return await self._closure(query, req_messages, model, funcs)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')

@async_lru.alru_cache(maxsize=128)
async def get_base64_str(
self,
original_url: str,
) -> str:
base64_image, image_format = await image.qq_image_url_to_base64(original_url)
return f"data:image/{image_format};base64,{base64_image}"
4 changes: 2 additions & 2 deletions pkg/provider/runners/localagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_ent
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]

# 首次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)

yield msg

Expand Down Expand Up @@ -61,7 +61,7 @@ async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_ent
req_messages.append(err_msg)

# 处理完所有调用,再次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)

yield msg

Expand Down
21 changes: 18 additions & 3 deletions pkg/utils/image.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import base64
import typing
import io
from urllib.parse import urlparse, parse_qs
import ssl

import aiohttp
import PIL.Image


def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]:
Expand All @@ -14,7 +16,7 @@ def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]:


async def get_qq_image_bytes(image_url: str) -> tuple[bytes, str]:
"""获取QQ图片的bytes"""
"""[弃用]获取QQ图片的bytes"""
image_url, query = get_qq_image_downloadable_url(image_url)
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
Expand All @@ -24,8 +26,11 @@ async def get_qq_image_bytes(image_url: str) -> tuple[bytes, str]:
resp.raise_for_status()
file_bytes = await resp.read()
content_type = resp.headers.get('Content-Type')
if not content_type or not content_type.startswith('image/'):
if not content_type:
image_format = 'jpeg'
elif not content_type.startswith('image/'):
pil_img = PIL.Image.open(io.BytesIO(file_bytes))
image_format = pil_img.format.lower()
else:
image_format = content_type.split('/')[-1]
return file_bytes, image_format
Expand All @@ -34,7 +39,7 @@ async def get_qq_image_bytes(image_url: str) -> tuple[bytes, str]:
async def qq_image_url_to_base64(
image_url: str
) -> typing.Tuple[str, str]:
"""将QQ图片URL转为base64,并返回图片格式
"""[弃用]将QQ图片URL转为base64,并返回图片格式
Args:
image_url (str): QQ图片URL
Expand All @@ -52,3 +57,13 @@ async def qq_image_url_to_base64(
base64_str = base64.b64encode(file_bytes).decode()

return base64_str, image_format

async def extract_b64_and_format(image_base64_data: str) -> typing.Tuple[str, str]:
    """Split a base64 image data URI into its payload and image format.

    The expected input looks like ``data:image/<format>;base64,<payload>``.

    Args:
        image_base64_data: The data URI string, or a bare base64 payload.

    Returns:
        A ``(base64_str, image_format)`` tuple. ``base64_str`` is the text
        after the last comma (the whole input when there is no comma).
        ``image_format`` is the MIME subtype read from the header, e.g.
        ``png``; it falls back to ``jpeg`` when the header carries no format.
    """
    # Everything after the last comma is the payload; with no comma at all,
    # rpartition leaves the header empty and the payload is the whole input.
    header, _, base64_str = image_base64_data.rpartition(',')

    # Parse the header only, so '/', ':' or ';' characters in the payload
    # can never be mistaken for part of the MIME type.
    mime_type = header.split(':')[-1].split(';')[0]
    image_format = mime_type.split('/')[-1]

    if not image_format:
        # No usable header: report a default instead of echoing payload text.
        image_format = 'jpeg'
    return base64_str, image_format

0 comments on commit 12cfce3

Please sign in to comment.