Commit d57f857

Merge branch 'main' into pr@main@function_lib
shaohuzhang1 authored Aug 15, 2024
2 parents 561b26b + 845225c commit d57f857
Showing 66 changed files with 2,628 additions and 521 deletions.
17 changes: 17 additions & 0 deletions .github/workflows/create-pr-from-push.yml
@@ -0,0 +1,17 @@
on:
  push:
    branches:
      - 'pr@**'
      - 'repr@**'

name: Auto-create a PR for matching branch names

jobs:
  generic_handler:
    name: Auto-create PR
    runs-on: ubuntu-latest
    steps:
      - name: Create pull request
        uses: jumpserver/action-generic-handler@master
        env:
          GITHUB_TOKEN: ${{ secrets.GH_TOKEN }}
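
For reference, a rough sketch of how those branch filters match, assuming GitHub's documented glob semantics for workflow branch filters (where '*' stops at '/' and '**' crosses it); this is an illustration, not the Actions implementation:

import re

def branch_filter_to_regex(pattern: str) -> re.Pattern:
    """Approximate a GitHub Actions branch filter as a regex."""
    parts = []
    i = 0
    while i < len(pattern):
        if pattern.startswith('**', i):
            parts.append('.*')        # '**' matches any characters, including '/'
            i += 2
        elif pattern[i] == '*':
            parts.append('[^/]*')     # '*' matches anything except '/'
            i += 1
        else:
            parts.append(re.escape(pattern[i]))
            i += 1
    return re.compile('^' + ''.join(parts) + '$')

# The source branch of this very merge matches the first filter:
assert branch_filter_to_regex('pr@**').match('pr@main@function_lib')
assert not branch_filter_to_regex('repr@**').match('feature/foo')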
3 changes: 3 additions & 0 deletions apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -73,6 +73,9 @@ class InstanceSerializer(serializers.Serializer):
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("no-reference segment setting"))

user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("user id"))
temperature = serializers.FloatField(required=False, allow_null=True, error_messages=ErrMessage.float("temperature"))
max_tokens = serializers.IntegerField(required=False, allow_null=True,
                                      error_messages=ErrMessage.integer("maximum tokens"))

def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@@ -111,7 +111,7 @@ def execute(self, message_list: List[BaseMessage],
client_id=None, client_type=None,
no_references_setting=None,
**kwargs):
-        chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None
+        chat_model = get_model_instance_by_model_user_id(model_id, user_id, **kwargs) if model_id is not None else None
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
35 changes: 35 additions & 0 deletions apps/application/serializers/application_serializers.py
@@ -43,6 +43,7 @@
from embedding.models import SearchMode
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR
from django.conf import settings
@@ -109,6 +110,10 @@ class DatasetSettingSerializer(serializers.Serializer):

class ModelSettingSerializer(serializers.Serializer):
prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("prompt"))
temperature = serializers.FloatField(required=False, allow_null=True,
                                     error_messages=ErrMessage.float("temperature"))
max_tokens = serializers.IntegerField(required=False, allow_null=True,
                                     error_messages=ErrMessage.integer("maximum tokens"))


class ApplicationWorkflowSerializer(serializers.Serializer):
@@ -541,6 +546,7 @@ def edit(self, with_valid=True):
class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("application id"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("user id"))
model_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("model id"))

def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@@ -722,6 +728,35 @@ def list_dataset(self, with_valid=True):
[self.data.get('user_id') if self.data.get('user_id') == str(application.user_id) else None,
application.user_id, self.data.get('user_id')])

def get_other_file_list(self):
temperature = None
max_tokens = None
application = Application.objects.filter(id=self.initial_data.get("application_id")).first()
if application:
setting_dict = application.model_setting
temperature = setting_dict.get("temperature")
max_tokens = setting_dict.get("max_tokens")
model = Model.objects.filter(id=self.initial_data.get("model_id")).first()
if model:
res = ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
model.model_name).get_other_fields(
model.model_name)
            if temperature is not None and res.get('temperature'):
res['temperature']['value'] = temperature
            if max_tokens is not None and res.get('max_tokens'):
res['max_tokens']['value'] = max_tokens
return res

def save_other_config(self, data):
application = Application.objects.filter(id=self.initial_data.get("application_id")).first()
if application:
setting_dict = application.model_setting
for key in ['max_tokens', 'temperature']:
if key in data:
setting_dict[key] = data[key]
application.model_setting = setting_dict
application.save()

class ApplicationKeySerializerModel(serializers.ModelSerializer):
class Meta:
model = ApplicationApiKey
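
A standalone sketch of the override merge that get_other_file_list performs above (the dictionaries are illustrative; the real defaults come from the provider's get_other_fields, and application.model_setting overrides only each field's 'value'):

provider_defaults = {                        # hypothetical get_other_fields() output
    'temperature': {'value': 0.7, 'min': 0.0, 'max': 1.0, 'step': 0.01},
    'max_tokens': {'value': 1024, 'min': 1, 'max': 8192, 'step': 1},
}
model_setting = {'temperature': 0.3}         # hypothetical application.model_setting

for key in ('temperature', 'max_tokens'):
    override = model_setting.get(key)
    if override is not None and key in provider_defaults:
        provider_defaults[key]['value'] = override

assert provider_defaults['temperature']['value'] == 0.3
assert provider_defaults['max_tokens']['value'] == 1024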
2 changes: 2 additions & 0 deletions apps/application/serializers/chat_message_serializers.py
@@ -80,6 +80,8 @@ def to_base_pipeline_manage_params(self):
'model_id': self.application.model.id if self.application.model is not None else None,
'problem_optimization': self.application.problem_optimization,
'stream': True,
'temperature': model_setting.get('temperature') if 'temperature' in model_setting else None,
'max_tokens': model_setting.get('max_tokens') if 'max_tokens' in model_setting else None,
'search_mode': self.application.dataset_setting.get(
'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding',
'no_references_setting': self.application.dataset_setting.get(
2 changes: 2 additions & 0 deletions apps/application/urls.py
@@ -19,6 +19,8 @@
path('application/<str:application_id>/statistics/chat_record_aggregate_trend',
views.ApplicationStatistics.ChatRecordAggregateTrend.as_view()),
path('application/<str:application_id>/model', views.Application.Model.as_view()),
path('application/<str:application_id>/model/<str:model_id>', views.Application.Model.Operate.as_view()),
path('application/<str:application_id>/other-config', views.Application.Model.OtherConfig.as_view()),
path('application/<str:application_id>/hit_test', views.Application.HitTest.as_view()),
path('application/<str:application_id>/api_key', views.Application.ApplicationKey.as_view()),
path("application/<str:application_id>/api_key/<str:api_key_id>",
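
A hedged client-side sketch of the two new routes (the base URL, '/api' prefix, and auth header are assumptions for illustration; IDs are placeholders):

import requests

BASE = 'http://localhost:8080/api'                 # assumed deployment prefix
HEADERS = {'Authorization': '<token>'}             # TokenAuth credential, placeholder
application_id, model_id = '<application_id>', '<model_id>'

# GET the extra parameter descriptors for a model, merged with app overrides
fields = requests.get(f'{BASE}/application/{application_id}/model/{model_id}',
                      headers=HEADERS).json()

# PUT per-application overrides for temperature / max_tokens
requests.put(f'{BASE}/application/{application_id}/other-config',
             headers=HEADERS, json={'temperature': 0.3, 'max_tokens': 2048})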
22 changes: 22 additions & 0 deletions apps/application/views/application_views.py
@@ -187,6 +187,28 @@ def get(self, request: Request, application_id: str):
data={'application_id': application_id,
'user_id': request.user.id}).list_model(request.query_params.get('model_type')))

class Operate(APIView):
authentication_classes = [TokenAuth]

@swagger_auto_schema(operation_summary="获取应用参数设置其他字段",
operation_id="获取应用参数设置其他字段",
tags=["应用/会话"])
def get(self, request: Request, application_id: str, model_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id, 'model_id': model_id}).get_other_file_list())

class OtherConfig(APIView):
authentication_classes = [TokenAuth]

@swagger_auto_schema(operation_summary="获取应用参数设置其他字段",
operation_id="获取应用参数设置其他字段",
tags=["应用/会话"])
def put(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id}).save_other_config(request.data))

class Profile(APIView):
authentication_classes = [TokenAuth]

8 changes: 4 additions & 4 deletions apps/setting/models_provider/__init__.py
@@ -13,7 +13,7 @@
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants


-def get_model_(provider, model_type, model_name, credential):
+def get_model_(provider, model_type, model_name, credential, **kwargs):
"""
    Get a model instance
    @param provider: provider
@@ -25,17 +25,17 @@ def get_model_(provider, model_type, model_name, credential):
model = get_provider(provider).get_model(model_type, model_name,
json.loads(
rsa_long_decrypt(credential)),
-                                              streaming=True)
+                                              streaming=True, **kwargs)
return model


-def get_model(model):
+def get_model(model, **kwargs):
"""
    Get a model instance
    @param model: database Model instance object
    @return: model instance
"""
-    return get_model_(model.provider, model.model_type, model.model_name, model.credential)
+    return get_model_(model.provider, model.model_type, model.model_name, model.credential, **kwargs)


def get_provider(provider):
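
A minimal runnable sketch of the **kwargs threading these signatures now allow (FakeProvider stands in for a real provider class, and credential decryption is skipped; names here are illustrative only):

class FakeProvider:
    def get_model(self, model_type, model_name, credential, **kwargs):
        # streaming plus any optional sampling params (temperature, max_tokens) land here
        return {'model': model_name, **kwargs}

def get_model_(provider, model_type, model_name, credential, **kwargs):
    return provider.get_model(model_type, model_name, credential,
                              streaming=True, **kwargs)

print(get_model_(FakeProvider(), 'LLM', 'claude-3', credential={},
                 temperature=0.3, max_tokens=2048))
# -> {'model': 'claude-3', 'streaming': True, 'temperature': 0.3, 'max_tokens': 2048}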
7 changes: 7 additions & 0 deletions apps/setting/models_provider/base_model_provider.py
@@ -108,6 +108,13 @@ def encryption_dict(self, model_info: Dict[str, object]):
"""
pass

def get_other_fields(self, model_name):
"""
        Get other fields
:return:
"""
pass

@staticmethod
def encryption(message: str):
"""
@@ -21,6 +21,7 @@
VolcanicEngineModelProvider
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from setting.models_provider.impl.xinference_model_provider.xinference_model_provider import XinferenceModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider

@@ -40,3 +41,4 @@ class ModelProvideConstants(Enum):
model_tencent_provider = TencentModelProvider()
model_aws_bedrock_provider = BedrockModelProvider()
model_local_provider = LocalModelProvider()
model_xinference_provider = XinferenceModelProvider()
@@ -31,13 +31,56 @@ def _get_aws_bedrock_icon_path():


def _initialize_model_info():
-    model_info_list = [_create_model_info(
-        'amazon.titan-text-premier-v1:0',
-        'Titan Text Premier is the powerful and advanced model of the Titan Text family, designed to deliver superior performance across a wide range of enterprise applications. With its cutting-edge capabilities, it offers improved accuracy and outstanding results, making it an excellent choice for organizations seeking first-class text-processing solutions.',
-        ModelTypeConst.LLM,
-        BedrockLLMModelCredential,
-        BedrockModel
-    ),
+    model_info_list = [
_create_model_info(
'anthropic.claude-v2:1',
            'An update to Claude 2, featuring double the context window plus improved reliability, hallucination rates, and evidence-based accuracy on long documents and in RAG contexts.',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-v2',
            'Anthropic\'s highly capable model for a wide range of tasks, from sophisticated dialogue and creative content generation to detailed instruction following.',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-3-haiku-20240307-v1:0',
            'Claude 3 Haiku is Anthropic\'s fastest, most compact model, with near-instant responsiveness. It answers simple queries and requests quickly, letting customers build seamless AI experiences that mimic human interaction. Claude 3 Haiku can process images and return text output, and offers a 200K context window.',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-3-sonnet-20240229-v1:0',
            'Anthropic\'s Claude 3 Sonnet strikes the ideal balance between intelligence and speed, particularly for enterprise workloads. It delivers maximum utility at a lower price than competing products and is engineered to be a dependable choice for deploying AI at scale.',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-3-5-sonnet-20240620-v1:0',
            'Claude 3.5 Sonnet raises the industry bar for intelligence, outperforming competitor models and Claude 3 Opus across a wide range of evaluations, with the speed and cost-effectiveness of our mid-tier models.',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'anthropic.claude-instant-v1',
            'A faster, more affordable, yet still very capable model that handles a range of tasks including casual dialogue, text analysis, summarization, and document question answering.',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'amazon.titan-text-premier-v1:0',
            'Titan Text Premier is the powerful and advanced model of the Titan Text family, designed to deliver superior performance across a wide range of enterprise applications. With its cutting-edge capabilities, it offers improved accuracy and outstanding results, making it an excellent choice for organizations seeking first-class text-processing solutions.',
ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel
),
_create_model_info(
'amazon.titan-text-lite-v1',
            'Amazon Titan Text Lite is a lightweight, efficient model, ideal for fine-tuning on English-language tasks such as summarization and copywriting, where customers need a smaller, more cost-effective, and highly customizable model',
@@ -59,7 +102,7 @@ def _initialize_model_info():
_create_model_info(
'mistral.mistral-7b-instruct-v0:2',
            'A 7B dense transformer that is fast to deploy and easy to customize. Small yet powerful across a variety of use cases. Supports English and code, with a 32k context window.',
-            ModelTypeConst.EMBEDDING,
+            ModelTypeConst.LLM,
BedrockLLMModelCredential,
BedrockModel),
_create_model_info(
@@ -60,3 +60,16 @@ def encryption_dict(self, model: Dict[str, object]):
region_name = forms.TextInputField('Region Name', required=True)
access_key_id = forms.TextInputField('Access Key ID', required=True)
secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)

def get_other_fields(self, model_name):
return {
'max_tokens': {
'value': 1024,
'min': 1,
'max': 8192,
'step': 1,
            'label': 'Maximum output tokens',
'precision': 0,
            'tooltip': 'Specifies the maximum number of tokens the model can generate'
}
}
@@ -1,12 +1,18 @@
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional, Iterator
from langchain_community.chat_models import BedrockChat
from langchain_core.messages import BaseMessage, get_buffer_string
from common.config.tokenizer_manage_config import TokenizerManage
from langchain_community.chat_models.bedrock import ChatPromptAdapter
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage, get_buffer_string, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from setting.models_provider.base_model_provider import MaxKBBaseModel


class BedrockModel(MaxKBBaseModel, BedrockChat):

@staticmethod
def is_cache_model():
return False

def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
streaming: bool = False, **kwargs):
super().__init__(model_id=model_id, region_name=region_name,
@@ -15,21 +21,52 @@ def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
@classmethod
def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
**model_kwargs) -> 'BedrockModel':
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']

return cls(
model_id=model_name,
region_name=model_credential['region_name'],
credentials_profile_name=model_credential['credentials_profile_name'],
streaming=model_kwargs.pop('streaming', False),
-            **model_kwargs
+            **optional_params
)

def _get_num_tokens(self, content: str) -> int:
"""Helper method to count tokens in a string."""
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(content))

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
return sum(self._get_num_tokens(get_buffer_string([message])) for message in messages)

def get_num_tokens(self, text: str) -> int:
return self._get_num_tokens(text)

def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
provider = self._get_provider()
prompt, system, formatted_messages = None, None, None

if provider == "anthropic":
system, formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
else:
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
)

for chunk in self._prepare_input_and_invoke_stream(
prompt=prompt,
system=system,
messages=formatted_messages,
stop=stop,
run_manager=run_manager,
**kwargs,
):
delta = chunk.text
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
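
The filtering in new_instance, forwarding only the optional parameters that were actually provided, generalizes to a small helper (hypothetical, not part of this diff):

def pick_optional(model_kwargs, allowed=('max_tokens', 'temperature')):
    """Keep only the allowed keys whose values are not None."""
    return {k: model_kwargs[k] for k in allowed if model_kwargs.get(k) is not None}

assert pick_optional({'max_tokens': None, 'temperature': 0.3}) == {'temperature': 0.3}
assert pick_optional({'streaming': True}) == {}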