Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

新增4个FC工具组件 #191

Merged
merged 3 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions appbuilder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def check_version(self):
from .core.components.extract_table.component import ExtractTableFromDoc
from .core.components.doc_parser.doc_parser import DocParser, ParserConfig
from .core.components.doc_splitter.doc_splitter import DocSplitter
from .core.components.retriever.bes_retriever import BESRetriever
from .core.components.retriever.bes_retriever import BESVectorStoreIndex
from .core.components.retriever.bes.bes_retriever import BESRetriever
from .core.components.retriever.bes.bes_retriever import BESVectorStoreIndex
from .core.components.dish_recognize.component import DishRecognition
from .core.components.translate.component import Translation
from .core.components.animal_recognize.component import AnimalRecognition
Expand All @@ -81,7 +81,6 @@ def check_version(self):
from .core.components.image_understand.component import ImageUnderstand
from .core.components.mix_card_ocr.component import MixCardOCR


from appbuilder.core.message import Message
from appbuilder.core.agent import AgentRuntime
from appbuilder.core.user_session import UserSession
Expand All @@ -90,7 +89,6 @@ def check_version(self):

from appbuilder.core.utils import get_model_list


from .core._exception import (
BadRequestException,
ForbiddenException,
Expand Down
81 changes: 80 additions & 1 deletion appbuilder/core/components/asr/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@

r"""ASR component.
"""
import os
import uuid
import json

import proto
import requests

from appbuilder.core import utils
from appbuilder.core.component import Component
from appbuilder.core.message import Message
from appbuilder.core._exception import AppBuilderServerException
from appbuilder.core._exception import AppBuilderServerException, InvalidRequestArgumentError
from appbuilder.core._client import HTTPClient
from appbuilder.core.components.asr.model import ShortSpeechRecognitionRequest, ShortSpeechRecognitionResponse, \
ASRInMsg, ASROutMsg
Expand All @@ -44,6 +49,41 @@ class ASR(Component):
out = asr.run(msg)
print(out.content) # eg: {"result": ["北京科技馆。"]}
"""
# Identifier and version this component is registered under.
name = "asr"
version = "v1"

# Function-call manifest: a JSON-Schema style description of the tool and the
# arguments accepted when it is invoked as a function.
manifests = [
    {
        "name": "asr",
        "description": "对于输入的语音文件进行识别,输出语音识别结果。",
        "parameters": {
            "type": "object",
            "properties": {
                "file_url": {
                    "type": "string",
                    "description": "输入语音文件的url,根据url获取到语音文件",
                },
                "file_name": {
                    "type": "string",
                    "description": "待识别语音文件名,用于生成获取语音的url",
                },
            },
            # A call is valid when it supplies at least one of the two arguments.
            "anyOf": [
                {"required": ["file_url"]},
                {"required": ["file_name"]},
            ],
        },
    },
]

@HTTPClient.check_param
def run(self, message: Message, audio_format: str = "pcm", rate: int = 16000,
timeout: float = None, retry: int = 0) -> Message:
Expand Down Expand Up @@ -120,3 +160,42 @@ def _check_service_error(request_id: str, data: dict):
service_err_code=data["err_no"],
service_err_message=data["err_msg"]
)

def tool_eval(self, name: str, streaming: bool, **kwargs):
    """
    Run speech recognition when the component is invoked as a function-call tool.

    The audio location is resolved from ``kwargs``: ``file_url`` is used when
    present; otherwise the base name of ``file_name`` is looked up in the
    ``file_urls`` mapping. The audio is downloaded with ``requests.get`` and
    submitted with format "pcm", rate 16000 and dev_pid "80001".

    Args:
        name: tool name of the function call (not read by this method).
        streaming: when true, two text messages are yielded -- the JSON
            result with ``visible_scope`` 'llm', then an empty text with
            ``visible_scope`` 'user'. When false, the JSON string is returned.
        **kwargs: ``file_url``, ``file_name`` and ``file_urls`` as described
            above.

    Raises:
        InvalidRequestArgumentError: if neither ``file_url`` nor ``file_name``
            is given, or ``file_urls`` has no entry for the file name.

    Note:
        Because the body contains ``yield``, calling this method always
        creates a generator. Nothing runs until it is iterated, and in the
        non-streaming case the JSON string is carried as the value of the
        ``StopIteration`` that ends the iteration.
    """
    file_url = kwargs.get("file_url", None)
    if not file_url:
        file_urls = kwargs.get("file_urls", {})
        file_path = kwargs.get("file_name", None)
        if not file_path:
            raise InvalidRequestArgumentError("file name is not set")
        file_name = os.path.basename(file_path)
        file_url = file_urls.get(file_name, None)
        if not file_url:
            # Report the file name that was looked up; file_url is always
            # empty on this branch and would render as "None".
            raise InvalidRequestArgumentError(f"file {file_name} url does not exist")
    req = ShortSpeechRecognitionRequest()
    req.cuid = str(uuid.uuid4())
    req.dev_pid = "80001"
    # NOTE(review): no timeout is passed and the HTTP status is not checked,
    # so a stalled server blocks here and an error page would be sent as audio.
    req.speech = requests.get(file_url).content
    req.format = "pcm"
    req.rate = 16000
    result = proto.Message.to_dict(self._recognize(req))
    results = {
        "识别结果": " \n".join(result["result"])
    }
    res = json.dumps(results, ensure_ascii=False, indent=4)
    if streaming:
        yield {
            "type": "text",
            "text": res,
            "visible_scope": 'llm',
        }
        yield {
            "type": "text",
            "text": "",
            "visible_scope": 'user',
        }
    else:
        return res
72 changes: 71 additions & 1 deletion appbuilder/core/components/general_ocr/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
r"""general ocr component."""
import base64
import json
import os.path

from appbuilder.core import utils
from appbuilder.core._client import HTTPClient


from appbuilder.core._exception import AppBuilderServerException
from appbuilder.core._exception import AppBuilderServerException, InvalidRequestArgumentError
from appbuilder.core.component import Component
from appbuilder.core.components.general_ocr.model import *
from appbuilder.core.message import Message
Expand All @@ -44,6 +46,40 @@ class GeneralOCR(Component):
print(out.content)

"""
# Identifier and version this component is registered under.
name = "general_ocr"
version = "v1"

# Function-call manifest: a JSON-Schema style description of the tool and the
# arguments accepted when it is invoked as a function.
manifests = [
    {
        "name": "general_ocr",
        "description": "提供更高精度的通用文字识别能力,能够识别图片中的文字",
        "parameters": {
            "type": "object",
            "properties": {
                "img_url": {
                    "type": "string",
                    "description": "待识别图片的url,根据该url能够获取图片",
                },
                "img_name": {
                    "type": "string",
                    "description": "待识别图片的文件名,用于生成图片url",
                },
            },
            # A call is valid when it supplies at least one of the two arguments.
            "anyOf": [
                {"required": ["img_url"]},
                {"required": ["img_name"]},
            ],
        },
    },
]

@HTTPClient.check_param
def run(self, message: Message, timeout: float = None, retry: int = 0) -> Message:
Expand Down Expand Up @@ -110,3 +146,37 @@ def _check_service_error(request_id: str, data: dict):
service_err_code=data.get("error_code"),
service_err_message=data.get("error_msg")
)

def tool_eval(self, name: str, streaming: bool, **kwargs):
    """
    Run general OCR when the component is invoked as a function-call tool.

    The image location comes from ``kwargs``: ``img_url`` when present,
    otherwise the base name of ``img_name`` is looked up in the ``file_urls``
    mapping. The recognized ``words`` entries are joined with " \\n" and
    serialized as a JSON object.

    Args:
        name: tool name of the function call (not read by this method).
        streaming: when true, yield the JSON text with ``visible_scope``
            'llm' followed by an empty text with ``visible_scope`` 'user';
            when false, return the JSON string.
        **kwargs: ``img_url``, ``img_name`` and ``file_urls``.

    Raises:
        InvalidRequestArgumentError: if neither ``img_url`` nor ``img_name``
            is given, or ``file_urls`` has no entry for the image name.

    Note:
        The body contains ``yield``, so this is a generator function: the
        non-streaming result is carried as the ``StopIteration`` value.
    """
    image_url = kwargs.get("img_url")
    if not image_url:
        # Fall back to resolving the URL from the uploaded-file mapping.
        image_path = kwargs.get("img_name")
        if not image_path:
            raise InvalidRequestArgumentError("file name is not set")
        base_name = os.path.basename(image_path)
        url_lookup = kwargs.get("file_urls", {})
        image_url = url_lookup.get(base_name)
        if not image_url:
            raise InvalidRequestArgumentError(f"file {base_name} url does not exist")

    request = GeneralOCRRequest(url=image_url)
    ocr_dict = proto.Message.to_dict(self._recognize(request))
    recognized_lines = [entry["words"] for entry in ocr_dict["words_result"]]
    payload = json.dumps(
        {"识别结果": " \n".join(recognized_lines)},
        ensure_ascii=False,
        indent=4,
    )
    if not streaming:
        return payload
    yield {"type": "text", "text": payload, "visible_scope": 'llm'}
    yield {"type": "text", "text": "", "visible_scope": 'user'}
80 changes: 79 additions & 1 deletion appbuilder/core/components/object_recognize/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

import base64
import json
import os

from appbuilder.core import utils
from appbuilder.core._client import HTTPClient
from appbuilder.core.component import Component
from appbuilder.core.message import Message
from appbuilder.core._exception import AppBuilderServerException
from appbuilder.core._exception import AppBuilderServerException, InvalidRequestArgumentError
from appbuilder.core.components.object_recognize.model import *


Expand All @@ -41,6 +43,40 @@ class ObjectRecognition(Component):
print(out.content)

"""
# Identifier and version this component is registered under.
name = "object_recognition"
version = "v1"

# Function-call manifest: a JSON-Schema style description of the tool and the
# arguments accepted when it is invoked as a function.
manifests = [
    {
        "name": "object_recognition",
        "description": "提供通用物体及场景识别能力,即对于输入的一张图片,输出图片中的多个物体及场景标签。",
        "parameters": {
            "type": "object",
            "properties": {
                "img_url": {
                    "type": "string",
                    "description": "待识别图片的url,根据该url能够获取图片",
                },
                "img_name": {
                    "type": "string",
                    "description": "待识别图片的文件名,用于生成图片url",
                },
            },
            # A call is valid when it supplies at least one of the two arguments.
            "anyOf": [
                {"required": ["img_url"]},
                {"required": ["img_name"]},
            ],
        },
    },
]

@HTTPClient.check_param
def run(self, message: Message, timeout: float = None, retry: int = 0) -> Message:
Expand Down Expand Up @@ -108,3 +144,45 @@ def _check_service_error(request_id: str, data: dict):
service_err_code=data.get("error_code"),
service_err_message=data.get("error_msg")
)

def tool_eval(self, name: str, streaming: bool, **kwargs):
    """
    Run object recognition when the component is invoked as a function-call tool.

    The image location comes from ``kwargs``: ``img_url`` when present,
    otherwise the base name of ``img_name`` is looked up in the ``file_urls``
    mapping. Each kept recognition entry is reported with its keyword, score
    and root category, and the list is serialized as JSON.

    Args:
        name: tool name of the function call (not read by this method).
        streaming: when true, yield the JSON text with ``visible_scope``
            'llm' followed by an empty text with ``visible_scope`` 'user';
            when false, return the JSON string.
        **kwargs: ``img_url``, ``img_name``, ``file_urls`` and an optional
            ``score_threshold`` (default 0.5).

    Raises:
        InvalidRequestArgumentError: if neither ``img_url`` nor ``img_name``
            is given, or ``file_urls`` has no entry for the image name.

    Note:
        The body contains ``yield``, so this is a generator function: the
        non-streaming result is carried as the ``StopIteration`` value.
    """
    image_url = kwargs.get("img_url")
    if not image_url:
        # Fall back to resolving the URL from the uploaded-file mapping.
        image_path = kwargs.get("img_name")
        if not image_path:
            raise InvalidRequestArgumentError("file name is not set")
        base_name = os.path.basename(image_path)
        url_lookup = kwargs.get("file_urls", {})
        image_url = url_lookup.get(base_name)
        if not image_url:
            raise InvalidRequestArgumentError(f"file {base_name} url does not exist")

    threshold = kwargs.get("score_threshold", 0.5)
    request = ObjectRecognitionRequest(url=image_url)
    recognized = proto.Message.to_dict(self._recognize(request))

    kept = []
    for entry in recognized["result"]:
        # Entries below the threshold are dropped only once something has
        # been kept, so the first entry is reported whatever its score.
        if entry["score"] < threshold and kept:
            continue
        kept.append(
            {
                "物品名称": entry["keyword"],
                "置信度": entry["score"],
                "所属类别": entry["root"],
            }
        )

    payload = json.dumps(kept, ensure_ascii=False, indent=4)
    if not streaming:
        return payload
    yield {"type": "text", "text": payload, "visible_scope": 'llm'}
    yield {"type": "text", "text": "", "visible_scope": 'user'}
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def _init_client(self, instance_id, account, api_key):
from pymochow.auth.bce_credentials import AppBuilderCredentials

gateway = os.getenv("GATEWAY_URL") if os.getenv("GATEWAY_URL") else GATEWAY_URL

appbuilder_token = os.getenv("APPBUILDER_TOKEN")

config = Configuration(
credentials=AppBuilderCredentials(account, api_key, appbuilder_token),
endpoint=gateway,
Expand Down
Loading
Loading