From 88f6e336e7279e70cc7232f6a080455928a6b3a3 Mon Sep 17 00:00:00 2001
From: wxg0103 <727495428@qq.com>
Date: Thu, 19 Sep 2024 11:41:29 +0800
Subject: [PATCH] fix: resolve the defect where non-streaming responses raise
 an error
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../models_provider/impl/base_chat_open_ai.py | 31 +++++++++++++++++--
 .../impl/qwen_model_provider/model/llm.py     | 31 +++++++++++++++++--
 2 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/apps/setting/models_provider/impl/base_chat_open_ai.py b/apps/setting/models_provider/impl/base_chat_open_ai.py
index b0d5f3bbcd..e7f4990f1f 100644
--- a/apps/setting/models_provider/impl/base_chat_open_ai.py
+++ b/apps/setting/models_provider/impl/base_chat_open_ai.py
@@ -1,9 +1,11 @@
 # coding=utf-8
-from typing import List, Dict, Optional, Any, Iterator, Type
+from typing import List, Dict, Optional, Any, Iterator, Type, cast
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
-from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.runnables import RunnableConfig, ensure_config
 from langchain_openai import ChatOpenAI
 from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
 
@@ -76,3 +78,28 @@ def _stream(
             )
             is_first_chunk = False
             yield generation_chunk
+
+    def invoke(
+            self,
+            input: LanguageModelInput,
+            config: Optional[RunnableConfig] = None,
+            *,
+            stop: Optional[List[str]] = None,
+            **kwargs: Any,
+    ) -> BaseMessage:
+        config = ensure_config(config)
+        chat_result = cast(
+            ChatGeneration,
+            self.generate_prompt(
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
+                **kwargs,
+            ).generations[0][0],
+        ).message
+        self.__dict__.setdefault('_last_generation_info', {}).update(chat_result.response_metadata['token_usage'])
+        return chat_result
diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
index a1af92c621..71e5ed8f72 100644
--- a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
@@ -6,13 +6,15 @@
 @date：2024/4/28 11:44
 @desc:
 """
-from typing import List, Dict, Optional, Iterator, Any
+from typing import List, Dict, Optional, Iterator, Any, cast
 
 from langchain_community.chat_models import ChatTongyi
 from langchain_community.llms.tongyi import generate_with_last_element_mark
 from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import BaseMessage, get_buffer_string
-from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.runnables import RunnableConfig, ensure_config
 
 from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
@@ -83,3 +85,28 @@ def _stream(
             if run_manager:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
             yield chunk
+
+    def invoke(
+            self,
+            input: LanguageModelInput,
+            config: Optional[RunnableConfig] = None,
+            *,
+            stop: Optional[List[str]] = None,
+            **kwargs: Any,
+    ) -> BaseMessage:
+        config = ensure_config(config)
+        chat_result = cast(
+            ChatGeneration,
+            self.generate_prompt(
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
+                **kwargs,
+            ).generations[0][0],
+        ).message
+        self.__dict__.setdefault('_last_generation_info', {}).update(chat_result.response_metadata['token_usage'])
+        return chat_result
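
Note (not part of the patch): the override routes invoke() through generate_prompt(), i.e. the non-streaming path, and caches the provider-reported token usage on the instance. A minimal sketch of how a caller might exercise it follows, assuming the class defined in base_chat_open_ai.py is named BaseChatOpenAI and is importable from the `setting` package root; the model name, key, and base URL are placeholders.

# Minimal sketch (assumptions flagged in comments), not a definitive test.
from langchain_core.messages import HumanMessage

# Assumed class name and import path; the patch only shows the file path.
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI

model = BaseChatOpenAI(
    model="gpt-4o-mini",                          # placeholder model name
    openai_api_key="sk-...",                      # placeholder credential
    openai_api_base="https://api.openai.com/v1",  # placeholder endpoint
)

# Non-streaming call: invoke() now goes through generate_prompt() instead of
# the streaming code path, so a plain synchronous call no longer errors out.
message = model.invoke([HumanMessage(content="ping")])
print(message.content)

# The override stashes the token usage from response_metadata on the instance,
# so callers can read it back after a non-streaming call.
print(model.__dict__.get('_last_generation_info'))
# e.g. {'prompt_tokens': 8, 'completion_tokens': 2, 'total_tokens': 10}

Writing through self.__dict__ rather than plain attribute assignment appears deliberate: ChatOpenAI and ChatTongyi are pydantic models, and going through __dict__ sidesteps field validation for an attribute that is not a declared field.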