From 558ab04bf0f7b6bbab94076e757a5792e9969a68 Mon Sep 17 00:00:00 2001
From: wxg0103 <727495428@qq.com>
Date: Fri, 16 Aug 2024 12:22:26 +0800
Subject: [PATCH] refactor: update model params

Disable instance caching for the Ollama, OpenAI, Qwen, Tencent and
Volcanic Engine chat models via a new is_cache_model() hook, expose a
max_tokens field for Qwen credentials, pass Qwen optional params through
model_kwargs, collect token usage when a stream ends with
finish_reason == "length", and default the Tencent temperature to 1.0.

---
 .../impl/ollama_model_provider/model/llm.py          | 4 ++++
 .../impl/openai_model_provider/model/llm.py          | 5 +++++
 .../impl/qwen_model_provider/credential/llm.py       | 9 +++++++++
 .../impl/qwen_model_provider/model/llm.py            | 8 ++++++--
 .../impl/tencent_model_provider/model/llm.py         | 5 ++++-
 .../impl/volcanic_engine_model_provider/model/llm.py | 4 ++++
 6 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py
index 9ae88558b2..d70f0de111 100644
--- a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py
@@ -25,6 +25,10 @@ def get_base_url(url: str):
 
 
 class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
+    @staticmethod
+    def is_cache_model():
+        return False
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         api_base = model_credential.get('api_base', '')
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
index 0708803b93..d73065e86d 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
@@ -13,6 +13,11 @@
 
 
 class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
+
+    @staticmethod
+    def is_cache_model():
+        return False
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         optional_params = {}
diff --git a/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py
index 714a8eaa45..d8c9b85e0e 100644
--- a/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py
+++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py
@@ -57,4 +57,13 @@ def get_other_fields(self, model_name):
                 'precision': 2,
                 'tooltip': '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'
             },
+            'max_tokens': {
+                'value': 800,
+                'min': 1,
+                'max': 2048,
+                'step': 1,
+                'label': '输出最大Tokens',
+                'precision': 0,
+                'tooltip': '指定模型可生成的最大token个数'
+            }
         }
diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
index 2bf7f6ac95..a1af92c621 100644
--- a/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/qwen_model_provider/model/llm.py
@@ -19,6 +19,10 @@
 
 
 class QwenChatModel(MaxKBBaseModel, ChatTongyi):
+    @staticmethod
+    def is_cache_model():
+        return False
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         optional_params = {}
@@ -29,7 +33,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         chat_tong_yi = QwenChatModel(
             model_name=model_name,
             dashscope_api_key=model_credential.get('api_key'),
-            **optional_params,
+            model_kwargs=optional_params,
         )
         return chat_tong_yi
 
@@ -61,7 +65,7 @@ def _stream(
             if (
                     choice["finish_reason"] == "stop"
                     and message["content"] == ""
-            ):
+            ) or (choice["finish_reason"] == "length"):
                 token_usage = stream_resp["usage"]
                 self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
             if (
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
index dd726b214a..0f879f73b5 100644
--- a/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/tencent_model_provider/model/llm.py
@@ -9,6 +9,9 @@
 
 
 class TencentModel(MaxKBBaseModel, ChatHunyuan):
+    @staticmethod
+    def is_cache_model():
+        return False
 
     def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool = False, **kwargs):
         hunyuan_app_id = credentials.get('hunyuan_app_id')
@@ -25,7 +28,7 @@ def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool = False, **kwargs):
         super().__init__(model=model_name, hunyuan_app_id=hunyuan_app_id,
                          hunyuan_secret_id=hunyuan_secret_id,
                          hunyuan_secret_key=hunyuan_secret_key, streaming=streaming,
-                         temperature=optional_params.get('temperature', None)
+                         temperature=optional_params.get('temperature', 1.0)
                          )
 
     @staticmethod
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
index e7ce56b6da..c549710e5d 100644
--- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py
@@ -6,6 +6,10 @@
 
 
 class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI):
+    @staticmethod
+    def is_cache_model():
+        return False
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         optional_params = {}
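
Note: a minimal sketch of how a caller might honor the new is_cache_model()
hook, assuming it gates a model-instance cache. The cache layer below
(get_model, _model_cache, cache_key) is hypothetical and not part of this
patch; only is_cache_model() and new_instance() come from the diff above.

    from typing import Dict

    _model_cache: Dict[str, object] = {}

    def get_model(cache_key: str, model_class, model_type, model_name,
                  model_credential: Dict[str, object], **model_kwargs):
        # Classes that return False from is_cache_model() (all five
        # providers touched by this patch) are rebuilt on every call,
        # so changed credentials or params are never served stale.
        if model_class.is_cache_model() and cache_key in _model_cache:
            return _model_cache[cache_key]
        instance = model_class.new_instance(model_type, model_name,
                                            model_credential, **model_kwargs)
        if model_class.is_cache_model():
            _model_cache[cache_key] = instance
        return instance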