From 736f8b613ce5cff49d1e7a937ffcacaef53ba8b9 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sun, 15 Dec 2024 17:05:56 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=B8=BA=20ollama=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E8=A7=86=E8=A7=89=E5=92=8C=E5=87=BD=E6=95=B0=E8=B0=83=E7=94=A8?= =?UTF-8?q?=20(#950)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../modelmgr/requesters/ollamachat.py | 42 ++++++++++++++++--- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index f66a2a78..d7ec9614 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -4,6 +4,8 @@ import os import typing from typing import Union, Mapping, Any, AsyncIterator +import uuid +import json import async_lru import ollama @@ -60,21 +62,49 @@ async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelI image_urls.append(image_url) msg["content"] = "\n".join(text_content) msg["images"] = [url.split(',')[1] for url in image_urls] + if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict + for tool_call in msg['tool_calls']: + tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments']) args["messages"] = messages - resp: Mapping[str, Any] | AsyncIterator[Mapping[str, Any]] = await self._req(args) + args["tools"] = [] + if user_funcs: + tools = await self.ap.tool_mgr.generate_tools_for_openai(user_funcs) + if tools: + args["tools"] = tools + + resp = await self._req(args) message: llm_entities.Message = await self._make_msg(resp) return message async def _make_msg( self, - chat_completions: Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]) -> llm_entities.Message: - message: Any = chat_completions.pop('message', None) + chat_completions: ollama.ChatResponse) -> llm_entities.Message: + message: ollama.Message = chat_completions.message if message is None: raise ValueError("chat_completions must contain a 'message' field") - message.update(chat_completions) - ret_msg: llm_entities.Message = llm_entities.Message(**message) + ret_msg: llm_entities.Message = None + + if message.content is not None: + ret_msg = llm_entities.Message( + role="assistant", + content=message.content + ) + if message.tool_calls is not None and len(message.tool_calls) > 0: + tool_calls: list[llm_entities.ToolCall] = [] + + for tool_call in message.tool_calls: + tool_calls.append(llm_entities.ToolCall( + id=uuid.uuid4().hex, + type="function", + function=llm_entities.FunctionCall( + name=tool_call.function.name, + arguments=json.dumps(tool_call.function.arguments) + ) + )) + ret_msg.tool_calls = tool_calls + return ret_msg async def call( @@ -92,7 +122,7 @@ async def call( msg_dict["content"] = "\n".join(part["text"] for part in content) req_messages.append(msg_dict) try: - return await self._closure(req_messages, model) + return await self._closure(req_messages, model, funcs) except asyncio.TimeoutError: raise errors.RequesterError('请求超时')