
Commit

fix: resolve the error raised on non-streaming responses
wxg0103 committed Sep 19, 2024
1 parent 90a7a9d commit 88f6e33
Showing 2 changed files with 58 additions and 4 deletions.
31 changes: 29 additions & 2 deletions apps/setting/models_provider/impl/base_chat_open_ai.py
@@ -1,9 +1,11 @@
 # coding=utf-8
 
-from typing import List, Dict, Optional, Any, Iterator, Type
+from typing import List, Dict, Optional, Any, Iterator, Type, cast
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import BaseMessage, AIMessageChunk, BaseMessageChunk
-from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.runnables import RunnableConfig, ensure_config
 from langchain_openai import ChatOpenAI
 from langchain_openai.chat_models.base import _convert_delta_to_message_chunk
@@ -76,3 +78,28 @@ def _stream(
             )
             is_first_chunk = False
             yield generation_chunk
+
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> BaseMessage:
+        config = ensure_config(config)
+        chat_result = cast(
+            ChatGeneration,
+            self.generate_prompt(
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
+                **kwargs,
+            ).generations[0][0],
+        ).message
+        self.__dict__.setdefault('_last_generation_info', {}).update(chat_result.response_metadata['token_usage'])
+        return chat_result
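
The override replaces the default Runnable invoke() path with a direct, non-streaming generate_prompt() call and caches the reported token usage on the instance. A minimal usage sketch of the new path, assuming the class in this file is named BaseChatOpenAI and using a placeholder model name and API key (none of which appear in the commit):

# Hypothetical usage sketch; import path, model name, and key are assumptions.
from langchain_core.messages import HumanMessage
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI

model = BaseChatOpenAI(model='gpt-4o-mini', openai_api_key='sk-placeholder')
answer = model.invoke([HumanMessage(content='Hello')])  # non-streaming call
print(answer.content)
# Usage cached by the override, e.g. {'prompt_tokens': 9, 'completion_tokens': 12, 'total_tokens': 21}
print(model.__dict__.get('_last_generation_info'))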
31 changes: 29 additions & 2 deletions apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
@@ -6,13 +6,15 @@
 @date:2024/4/28 11:44
 @desc:
 """
-from typing import List, Dict, Optional, Iterator, Any
+from typing import List, Dict, Optional, Iterator, Any, cast
 
 from langchain_community.chat_models import ChatTongyi
 from langchain_community.llms.tongyi import generate_with_last_element_mark
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import BaseMessage, get_buffer_string
-from langchain_core.outputs import ChatGenerationChunk
+from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.runnables import RunnableConfig, ensure_config
 
 from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
@@ -83,3 +85,28 @@ def _stream(
             if run_manager:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
             yield chunk
+
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: Optional[RunnableConfig] = None,
+        *,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> BaseMessage:
+        config = ensure_config(config)
+        chat_result = cast(
+            ChatGeneration,
+            self.generate_prompt(
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
+                run_id=config.pop("run_id", None),
+                **kwargs,
+            ).generations[0][0],
+        ).message
+        self.__dict__.setdefault('_last_generation_info', {}).update(chat_result.response_metadata['token_usage'])
+        return chat_result
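
Both files use the same caching line, written through self.__dict__ rather than a normal attribute assignment, presumably because LangChain chat models are Pydantic models that reject setting undeclared attributes. A standalone sketch of the setdefault(...).update(...) idiom, using a plain class and dummy usage dicts for illustration (not code from the commit):

# Illustration only: the first call creates the cache dict, later calls
# merge into the same dict and overwrite any keys they share.
class Holder:
    pass

h = Holder()
h.__dict__.setdefault('_last_generation_info', {}).update({'prompt_tokens': 9, 'completion_tokens': 12})
h.__dict__.setdefault('_last_generation_info', {}).update({'completion_tokens': 30})
print(h.__dict__['_last_generation_info'])  # {'prompt_tokens': 9, 'completion_tokens': 30}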
