From 35105dc16456097e32a1d7c2f0be246e7d17b39b Mon Sep 17 00:00:00 2001 From: zhoulongchao <152879727+longchao1916@users.noreply.github.com> Date: Tue, 19 Mar 2024 14:04:21 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E4=E4=B8=AAFC=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E7=BB=84=E4=BB=B6=20(#191)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 新增4个FC工具组件 --- appbuilder/__init__.py | 6 +- appbuilder/core/components/asr/component.py | 81 ++++++++++++++++++- .../core/components/general_ocr/component.py | 72 ++++++++++++++++- .../components/object_recognize/component.py | 80 +++++++++++++++++- .../retriever/baidu_vdb/baiduvdb_retriever.py | 3 +- .../core/components/translate/component.py | 58 ++++++++++++- appbuilder/tests/test_asr.py | 13 +++ appbuilder/tests/test_general_ocr.py | 17 ++++ appbuilder/tests/test_object_recognize.py | 17 ++++ appbuilder/tests/test_translate.py | 13 +++ 10 files changed, 351 insertions(+), 9 deletions(-) diff --git a/appbuilder/__init__.py b/appbuilder/__init__.py index eb22829f3..08a37b982 100644 --- a/appbuilder/__init__.py +++ b/appbuilder/__init__.py @@ -61,8 +61,8 @@ def check_version(self): from .core.components.extract_table.component import ExtractTableFromDoc from .core.components.doc_parser.doc_parser import DocParser, ParserConfig from .core.components.doc_splitter.doc_splitter import DocSplitter -from .core.components.retriever.bes_retriever import BESRetriever -from .core.components.retriever.bes_retriever import BESVectorStoreIndex +from .core.components.retriever.bes.bes_retriever import BESRetriever +from .core.components.retriever.bes.bes_retriever import BESVectorStoreIndex from .core.components.dish_recognize.component import DishRecognition from .core.components.translate.component import Translation from .core.components.animal_recognize.component import AnimalRecognition @@ -81,7 +81,6 @@ def check_version(self): from .core.components.image_understand.component import 
ImageUnderstand from .core.components.mix_card_ocr.component import MixCardOCR - from appbuilder.core.message import Message from appbuilder.core.agent import AgentRuntime from appbuilder.core.user_session import UserSession @@ -90,7 +89,6 @@ def check_version(self): from appbuilder.core.utils import get_model_list - from .core._exception import ( BadRequestException, ForbiddenException, diff --git a/appbuilder/core/components/asr/component.py b/appbuilder/core/components/asr/component.py index dc60cd8e2..e5ce8507f 100644 --- a/appbuilder/core/components/asr/component.py +++ b/appbuilder/core/components/asr/component.py @@ -14,12 +14,17 @@ r"""ASR component. """ +import os import uuid import json +import proto +import requests + +from appbuilder.core import utils from appbuilder.core.component import Component from appbuilder.core.message import Message -from appbuilder.core._exception import AppBuilderServerException +from appbuilder.core._exception import AppBuilderServerException, InvalidRequestArgumentError from appbuilder.core._client import HTTPClient from appbuilder.core.components.asr.model import ShortSpeechRecognitionRequest, ShortSpeechRecognitionResponse, \ ASRInMsg, ASROutMsg @@ -44,6 +49,41 @@ class ASR(Component): out = asr.run(msg) print(out.content) # eg: {"result": ["北京科技馆。"]} """ + name = "asr" + version = "v1" + + manifests = [ + { + "name": "asr", + "description": "对于输入的语音文件进行识别,输出语音识别结果。", + "parameters": { + "type": "object", + "properties": { + "file_url": { + "type": "string", + "description": "输入语音文件的url,根据url获取到语音文件" + }, + "file_name": { + "type": "string", + "description": "待识别语音文件名,用于生成获取语音的url" + } + }, + "anyOf": [ + { + "required": [ + "file_url" + ] + }, + { + "required": [ + "file_name" + ] + } + ] + } + } + ] + @HTTPClient.check_param def run(self, message: Message, audio_format: str = "pcm", rate: int = 16000, timeout: float = None, retry: int = 0) -> Message: @@ -120,3 +160,42 @@ def _check_service_error(request_id: str, data: 
dict): service_err_code=data["err_no"], service_err_message=data["err_msg"] ) + + def tool_eval(self, name: str, streaming: bool, **kwargs): + """ + Function-call entry for ASR: take file_url, or look up basename(file_name) in file_urls, download the audio and emit the recognition result as a JSON string. + """ + file_url = kwargs.get("file_url", None) + if not file_url: + file_urls = kwargs.get("file_urls", {}) + file_path = kwargs.get("file_name", None) + if not file_path: + raise InvalidRequestArgumentError("file name is not set") + file_name = os.path.basename(file_path) + file_url = file_urls.get(file_name, None) + if not file_url: + raise InvalidRequestArgumentError(f"file {file_name} url does not exist") + req = ShortSpeechRecognitionRequest() + req.cuid = str(uuid.uuid4()) + req.dev_pid = "80001" + req.speech = requests.get(file_url).content + req.format = "pcm" + req.rate = 16000 + result = proto.Message.to_dict(self._recognize(req)) + results = { + "识别结果": " \n".join(item for item in result["result"]) + } + res = json.dumps(results, ensure_ascii=False, indent=4) + if streaming: + yield { + "type": "text", + "text": res, + "visible_scope": 'llm', + } + yield { + "type": "text", + "text": "", + "visible_scope": 'user', + } + else: + return res diff --git a/appbuilder/core/components/general_ocr/component.py b/appbuilder/core/components/general_ocr/component.py index 301038114..82a7be5d7 100644 --- a/appbuilder/core/components/general_ocr/component.py +++ b/appbuilder/core/components/general_ocr/component.py @@ -13,11 +13,13 @@ r"""general ocr component.""" import base64 import json +import os.path +from appbuilder.core import utils from appbuilder.core._client import HTTPClient -from appbuilder.core._exception import AppBuilderServerException +from appbuilder.core._exception import AppBuilderServerException, InvalidRequestArgumentError from appbuilder.core.component import Component from appbuilder.core.components.general_ocr.model import * from appbuilder.core.message import Message @@ -44,6 +46,40 @@ class GeneralOCR(Component): print(out.content) """ + name = "general_ocr" + version =
"v1" + + manifests = [ + { + "name": "general_ocr", + "description": "提供更高精度的通用文字识别能力,能够识别图片中的文字", + "parameters": { + "type": "object", + "properties": { + "img_url": { + "type": "string", + "description": "待识别图片的url,根据该url能够获取图片" + }, + "img_name": { + "type": "string", + "description": "待识别图片的文件名,用于生成图片url" + }, + }, + "anyOf": [ + { + "required": [ + "img_url" + ] + }, + { + "required": [ + "img_name" + ] + } + ] + } + } + ] @HTTPClient.check_param def run(self, message: Message, timeout: float = None, retry: int = 0) -> Message: @@ -110,3 +146,37 @@ def _check_service_error(request_id: str, data: dict): service_err_code=data.get("error_code"), service_err_message=data.get("error_msg") ) + + def tool_eval(self, name: str, streaming: bool, **kwargs): + """ + general_ocr for function call + """ + img_url = kwargs.get("img_url", None) + if not img_url: + file_urls = kwargs.get("file_urls", {}) + img_path = kwargs.get("img_name", None) + if not img_path: + raise InvalidRequestArgumentError("file name is not set") + img_name = os.path.basename(img_path) + img_url = file_urls.get(img_name, None) + if not img_url: + raise InvalidRequestArgumentError(f"file {img_name} url does not exist") + req = GeneralOCRRequest(url=img_url) + result = proto.Message.to_dict(self._recognize(req)) + results = { + "识别结果": " \n".join(item["words"] for item in result["words_result"]) + } + res = json.dumps(results, ensure_ascii=False, indent=4) + if streaming: + yield { + "type": "text", + "text": res, + "visible_scope": 'llm', + } + yield { + "type": "text", + "text": "", + "visible_scope": 'user', + } + else: + return res \ No newline at end of file diff --git a/appbuilder/core/components/object_recognize/component.py b/appbuilder/core/components/object_recognize/component.py index 5f5b58ba2..405fafb77 100644 --- a/appbuilder/core/components/object_recognize/component.py +++ b/appbuilder/core/components/object_recognize/component.py @@ -14,11 +14,13 @@ import base64 import json +import 
os +from appbuilder.core import utils from appbuilder.core._client import HTTPClient from appbuilder.core.component import Component from appbuilder.core.message import Message -from appbuilder.core._exception import AppBuilderServerException +from appbuilder.core._exception import AppBuilderServerException, InvalidRequestArgumentError from appbuilder.core.components.object_recognize.model import * @@ -41,6 +43,40 @@ class ObjectRecognition(Component): print(out.content) """ + name = "object_recognition" + version = "v1" + + manifests = [ + { + "name": "object_recognition", + "description": "提供通用物体及场景识别能力,即对于输入的一张图片,输出图片中的多个物体及场景标签。", + "parameters": { + "type": "object", + "properties": { + "img_url": { + "type": "string", + "description": "待识别图片的url,根据该url能够获取图片" + }, + "img_name": { + "type": "string", + "description": "待识别图片的文件名,用于生成图片url" + } + }, + "anyOf": [ + { + "required": [ + "img_url" + ] + }, + { + "required": [ + "img_name" + ] + } + ] + } + } + ] @HTTPClient.check_param def run(self, message: Message, timeout: float = None, retry: int = 0) -> Message: @@ -108,3 +144,45 @@ def _check_service_error(request_id: str, data: dict): service_err_code=data.get("error_code"), service_err_message=data.get("error_msg") ) + + def tool_eval(self, name: str, streaming: bool, **kwargs): + """ + object_recognize for function call + """ + img_url = kwargs.get("img_url", None) + if not img_url: + file_urls = kwargs.get("file_urls", {}) + img_path = kwargs.get("img_name", None) + if not img_path: + raise InvalidRequestArgumentError("file name is not set") + img_name = os.path.basename(img_path) + img_url = file_urls.get(img_name, None) + if not img_url: + raise InvalidRequestArgumentError(f"file {img_name} url does not exist") + score_threshold = kwargs.get("score_threshold", 0.5) + req = ObjectRecognitionRequest(url=img_url) + result = proto.Message.to_dict(self._recognize(req)) + results = [] + for item in result["result"]: + if item["score"] < score_threshold and 
len(results) > 0: + continue + res = { + "物品名称": item["keyword"], + "置信度": item["score"], + "所属类别": item["root"], + } + results.append(res) + res = json.dumps(results, ensure_ascii=False, indent=4) + if streaming: + yield { + "type": "text", + "text": res, + "visible_scope": 'llm', + } + yield { + "type": "text", + "text": "", + "visible_scope": 'user', + } + else: + return res \ No newline at end of file diff --git a/appbuilder/core/components/retriever/baidu_vdb/baiduvdb_retriever.py b/appbuilder/core/components/retriever/baidu_vdb/baiduvdb_retriever.py index d716f05f1..811d9e223 100644 --- a/appbuilder/core/components/retriever/baidu_vdb/baiduvdb_retriever.py +++ b/appbuilder/core/components/retriever/baidu_vdb/baiduvdb_retriever.py @@ -136,7 +136,8 @@ def _init_client(self, instance_id, account, api_key): from pymochow.auth.bce_credentials import AppBuilderCredentials gateway = os.getenv("GATEWAY_URL") if os.getenv("GATEWAY_URL") else GATEWAY_URL - + appbuilder_token = os.getenv("APPBUILDER_TOKEN") + config = Configuration( credentials=AppBuilderCredentials(account, api_key, appbuilder_token), endpoint=gateway, diff --git a/appbuilder/core/components/translate/component.py b/appbuilder/core/components/translate/component.py index ddd60508a..cb6759594 100644 --- a/appbuilder/core/components/translate/component.py +++ b/appbuilder/core/components/translate/component.py @@ -22,7 +22,7 @@ from appbuilder.core.message import Message from appbuilder.core.component import Component from appbuilder.core._client import HTTPClient -from appbuilder.core._exception import AppBuilderServerException +from appbuilder.core._exception import AppBuilderServerException, InvalidRequestArgumentError from appbuilder.core.components.translate.model import * @@ -50,6 +50,30 @@ class Translation(Component): name = "translate" version = "v1" + manifests = [ + { + "name": "translation", + "description": "文本翻译通用版工具,会根据指定的目标语言对文本进行翻译,并返回翻译后的文本。", + "parameters": { + "type": "object", + 
"properties": { + "q": { + "type": "string", + "description": "需要翻译的源文本,文本翻译工具会将该文本翻译成对应的目标语言" + }, + "to_lang": { + "type": "string", + "description": "翻译的目标语言类型,'en'表示将原文本翻译成英文, 'zh'表示将原文本翻译成中文,默认为'en'", + "enum": ["en", "zh"] + } + }, + "required": [ + "q" + ] + } + } + ] + @HTTPClient.check_param def run(self, message: Message, from_lang: str = "auto", to_lang: str = "en", timeout: float = None, retry: int = 0) -> Message: @@ -113,3 +137,35 @@ def _translate(self, request: TranslateRequest, timeout: float = None, json_str = json.dumps(data) return TranslateResponse(TranslateResponse.from_json(json_str)) + + def tool_eval(self, name: str, streaming: bool, **kwargs): + """ + translate for function call + """ + req = TranslateRequest() + text = kwargs.get("q", None) + if not text: + raise InvalidRequestArgumentError("param `q` must be set") + req.q = text + to_lang = kwargs.get("to_lang", "en") + req.to_lang = to_lang + results = proto.Message.to_dict(self._translate(req))["result"] + trans_result = results["trans_result"] + res = { + "原文本": "\n ".join(item["src"] for item in trans_result), + "翻译结果": "\n ".join(item["dst"] for item in trans_result) + } + res = json.dumps(res, ensure_ascii=False, indent=4) + if streaming: + yield { + "type": "text", + "text": res, + "visible_scope": 'llm', + } + yield { + "type": "text", + "text": "", + "visible_scope": 'user', + } + else: + return res diff --git a/appbuilder/tests/test_asr.py b/appbuilder/tests/test_asr.py index 029888656..a6ab77eb2 100644 --- a/appbuilder/tests/test_asr.py +++ b/appbuilder/tests/test_asr.py @@ -4,6 +4,7 @@ import requests import appbuilder +from appbuilder.core._exception import InvalidRequestArgumentError from appbuilder.core.components.asr.model import ShortSpeechRecognitionRequest, ShortSpeechRecognitionResponse @@ -129,6 +130,18 @@ def test_check_service_error(self): data = {'err_msg': 'No Error', 'err_no': 0} self.assertIsNone(self.asr._check_service_error("", data)) + def 
test_tool_eval_valid(self): + """测试 tool 方法对有效请求的处理。""" + result = self.asr.tool_eval(name="asr", streaming=True, file_url=self.audio_file_url) + res = [item for item in result] + self.assertNotEqual(len(res), 0) + + def test_tool_eval_invalid(self): + """测试 tool 方法对无效请求的处理。""" + with self.assertRaises(InvalidRequestArgumentError): + result = self.asr.tool_eval(name="asr", streaming=True) + next(result) + if __name__ == '__main__': unittest.main() diff --git a/appbuilder/tests/test_general_ocr.py b/appbuilder/tests/test_general_ocr.py index 3ec8779b2..b0e822ec3 100644 --- a/appbuilder/tests/test_general_ocr.py +++ b/appbuilder/tests/test_general_ocr.py @@ -15,6 +15,7 @@ import unittest import requests import appbuilder +from appbuilder.core._exception import InvalidRequestArgumentError class TestGeneralOCR(unittest.TestCase): @@ -131,6 +132,22 @@ def test_run_without_image_and_url(self): with self.assertRaises(ValueError): self.general_ocr.run(message=message) + def test_tool_eval_valid(self): + """测试 tool 方法对有效请求的处理。""" + image_url = "https://bj.bcebos.com/v1/appbuilder/general_ocr_test.png?" 
\ + "authorization=bce-auth-v1%2FALTAKGa8m4qCUasgoljdEDAzLm%2F2024-01-" \ + "11T10%3A59%3A17Z%2F-1%2Fhost%2F081bf7bcccbda5207c82a4de074628b04ae" \ + "857a27513734d765495f89ffa5f73" + result = self.general_ocr.tool_eval(name="general_ocr", streaming=True, img_url=image_url) + res = [item for item in result] + self.assertNotEqual(len(res), 0) + + def test_tool_eval_invalid(self): + """测试 tool 方法对无效请求的处理。""" + with self.assertRaises(InvalidRequestArgumentError): + result = self.general_ocr.tool_eval(name="general_ocr", streaming=True) + next(result) + if __name__ == '__main__': unittest.main() diff --git a/appbuilder/tests/test_object_recognize.py b/appbuilder/tests/test_object_recognize.py index 68f1e10f1..e022de08f 100644 --- a/appbuilder/tests/test_object_recognize.py +++ b/appbuilder/tests/test_object_recognize.py @@ -15,6 +15,7 @@ import unittest import requests import appbuilder +from appbuilder.core._exception import InvalidRequestArgumentError class TestObjectRecognize(unittest.TestCase): @@ -131,6 +132,22 @@ def test_run_without_image_and_url(self): with self.assertRaises(ValueError): self.object_recognition.run(message=message) + def test_tool_eval_valid(self): + """测试 tool 方法对有效请求的处理。""" + image_url = "https://bj.bcebos.com/v1/appbuilder/object_recognize_test.png?" 
\ + "authorization=bce-auth-v1%2FALTAKGa8m4qCUasgoljdEDAzLm%2F2024-01-" \ + "11T11%3A00%3A19Z%2F-1%2Fhost%2F2c31bf29205f61e58df661dc80af31a1dc" \ + "1ba1de0a8f072bc5a87102bd32f9e3" + result = self.object_recognition.tool_eval(name="object_recognition", streaming=True, img_url=image_url) + res = [item for item in result] + self.assertNotEqual(len(res), 0) + + def test_tool_eval_invalid(self): + """测试 tool 方法对无效请求的处理。""" + with self.assertRaises(InvalidRequestArgumentError): + result = self.object_recognition.tool_eval(name="object_recognition", streaming=True) + next(result) + if __name__ == '__main__': unittest.main() diff --git a/appbuilder/tests/test_translate.py b/appbuilder/tests/test_translate.py index 3dd63f46b..5e643b6ce 100644 --- a/appbuilder/tests/test_translate.py +++ b/appbuilder/tests/test_translate.py @@ -1,5 +1,6 @@ import unittest import appbuilder +from appbuilder.core._exception import InvalidRequestArgumentError class TestTranslationComponent(unittest.TestCase): @@ -19,6 +20,18 @@ def test_run_invalid_request(self): with self.assertRaises(ValueError): _ = self.translation(msg) + def test_tool_eval_valid(self): + """测试 tool 方法对有效请求的处理。""" + result = self.translation.tool_eval(name="translation", streaming=True, q="你好\n中国", to_lang="en") + res = [item for item in result] + self.assertNotEqual(len(res), 0) + + def test_tool_eval_invalid(self): + """测试 tool 方法对无效请求的处理。""" + with self.assertRaises(InvalidRequestArgumentError): + result = self.translation.tool_eval(name="translation", streaming=True, to_lang="en") + next(result) + if __name__ == '__main__': unittest.main()