
Commit 558ab04

refactor: update model params

wxg0103 committed Aug 16, 2024
1 parent b596b69
Showing 6 changed files with 32 additions and 3 deletions.
@@ -25,6 +25,10 @@ def get_base_url(url: str):


 class OllamaChatModel(MaxKBBaseModel, ChatOpenAI):
+    @staticmethod
+    def is_cache_model():
+        return False
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         api_base = model_credential.get('api_base', '')
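Note: the same is_cache_model override, returning False, is added to every provider class in this commit. A plausible reading (not stated in the diff itself) is that MaxKB caches constructed model instances by default, and these classes opt out so that freshly saved parameters such as max_tokens take effect on the next request. A minimal sketch of how a registry might consult the flag; _REGISTRY, _MODEL_CACHE, and get_model are illustrative names, not MaxKB API:

from typing import Dict, Type

_REGISTRY: Dict[str, Type] = {}       # e.g. {'OLLAMA': OllamaChatModel} (hypothetical)
_MODEL_CACHE: Dict[str, object] = {}  # instances reused across requests

def get_model(model_id: str, model_type: str, model_name: str,
              credential: Dict[str, object], **model_kwargs):
    model_class = _REGISTRY[model_type]
    # A cache-averse class (is_cache_model() == False) is rebuilt on every
    # call, so updated parameters are always honored.
    if model_class.is_cache_model() and model_id in _MODEL_CACHE:
        return _MODEL_CACHE[model_id]
    instance = model_class.new_instance(model_type, model_name, credential, **model_kwargs)
    if model_class.is_cache_model():
        _MODEL_CACHE[model_id] = instance
    return instance
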
@@ -13,6 +13,11 @@


 class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
+
+    @staticmethod
+    def is_cache_model():
+        return False
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         optional_params = {}
@@ -57,4 +57,13 @@ def get_other_fields(self, model_name):
             'precision': 2,
             'tooltip': 'Higher values make the output more random, while lower values make it more focused and deterministic'
         },
+        'max_tokens': {
+            'value': 800,
+            'min': 1,
+            'max': 2048,
+            'step': 1,
+            'label': 'Maximum output tokens',
+            'precision': 0,
+            'tooltip': 'Specifies the maximum number of tokens the model can generate'
+        }
     }
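The new max_tokens field mirrors the existing temperature field: value is the default shown in the UI, min/max/step constrain the input, and precision: 0 keeps it an integer. A hedged sketch of how saved form values might be filtered into constructor params; filter_optional_params is an illustrative helper, not code from this commit:

def filter_optional_params(model_kwargs: dict) -> dict:
    # Keep only the tunable params the form defines, dropping unset ones.
    optional_params = {}
    for key in ('temperature', 'max_tokens'):
        if model_kwargs.get(key) is not None:
            optional_params[key] = model_kwargs[key]
    return optional_params

# filter_optional_params({'temperature': 0.7, 'max_tokens': 800})
# -> {'temperature': 0.7, 'max_tokens': 800}
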
@@ -19,6 +19,10 @@


 class QwenChatModel(MaxKBBaseModel, ChatTongyi):
+    @staticmethod
+    def is_cache_model():
+        return False
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         optional_params = {}
@@ -29,7 +33,7 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
         chat_tong_yi = QwenChatModel(
             model_name=model_name,
             dashscope_api_key=model_credential.get('api_key'),
-            **optional_params,
+            model_kwargs=optional_params,
         )
         return chat_tong_yi

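The change from **optional_params to model_kwargs=optional_params is more than style: langchain chat models are pydantic models, and splatting arbitrary keys into the constructor can trip field validation, while the model_kwargs dict rides along and is forwarded to the underlying request. A small self-contained illustration of the failure mode, using a stand-in class rather than the real ChatTongyi:

from typing import Any, Dict
from pydantic import BaseModel, ConfigDict, ValidationError

class FakeChatModel(BaseModel):
    # Many langchain models forbid unknown constructor fields.
    model_config = ConfigDict(extra='forbid', protected_namespaces=())
    model_name: str
    model_kwargs: Dict[str, Any] = {}

params = {'temperature': 0.7, 'max_tokens': 800}

FakeChatModel(model_name='qwen-turbo', model_kwargs=params)  # accepted
try:
    FakeChatModel(model_name='qwen-turbo', **params)  # unknown fields
except ValidationError as err:
    print(err)  # pydantic rejects temperature / max_tokens here
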
@@ -61,7 +65,7 @@ def _stream(
             if (
                 choice["finish_reason"] == "stop"
                 and message["content"] == ""
-            ):
+            ) or (choice["finish_reason"] == "length"):
                 token_usage = stream_resp["usage"]
                 self.__dict__.setdefault('_last_generation_info', {}).update(token_usage)
             if (
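The widened condition in _stream covers responses that end because they hit the token cap: with max_tokens now configurable, a stream can finish with finish_reason "length" instead of "stop", and its usage stats would previously have been dropped. An illustrative shape for such a chunk, with field names taken from the surrounding code and values invented:

choice = {"finish_reason": "length",  # generation truncated at max_tokens
          "message": {"role": "assistant", "content": "...partial text"}}
message = choice["message"]
stream_resp = {"usage": {"input_tokens": 12, "output_tokens": 800,
                         "total_tokens": 812}}  # now recorded in both cases
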
@@ -9,6 +9,9 @@


 class TencentModel(MaxKBBaseModel, ChatHunyuan):
+    @staticmethod
+    def is_cache_model():
+        return False
 
     def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool = False, **kwargs):
         hunyuan_app_id = credentials.get('hunyuan_app_id')
@@ -25,7 +28,7 @@ def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool

         super().__init__(model=model_name, hunyuan_app_id=hunyuan_app_id, hunyuan_secret_id=hunyuan_secret_id,
                          hunyuan_secret_key=hunyuan_secret_key, streaming=streaming,
-                         temperature=optional_params.get('temperature', None)
+                         temperature=optional_params.get('temperature', 1.0)
                          )
 
     @staticmethod
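Defaulting temperature to 1.0 instead of None means ChatHunyuan always receives a concrete float; presumably passing temperature=None failed validation or was rejected downstream. The mechanics are just dict.get with an explicit fallback:

optional_params = {}  # user saved no temperature
optional_params.get('temperature', None)  # -> None (old behavior)
optional_params.get('temperature', 1.0)   # -> 1.0  (new default)
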
@@ -6,6 +6,10 @@


 class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI):
+    @staticmethod
+    def is_cache_model():
+        return False
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         optional_params = {}
