diff --git a/appbuilder/core/components/text_to_image/component.py b/appbuilder/core/components/text_to_image/component.py index 660aa14c3..f5017c875 100644 --- a/appbuilder/core/components/text_to_image/component.py +++ b/appbuilder/core/components/text_to_image/component.py @@ -18,6 +18,7 @@ import json import math +from typing import Optional from appbuilder.core.component import Component from appbuilder.core.message import Message from appbuilder.core._client import HTTPClient @@ -72,38 +73,45 @@ def run( width: int = 1024, height: int = 1024, image_num: int = 1, - timeout: float = None, - retry: int = 0, - request_id: str = None + image: Optional[str] = None, + url: Optional[str] = None, + pdf_file: Optional[str] = None, + pdf_file_num: Optional[str] = None, + change_degree: Optional[int] = None, + text_content: Optional[str] = None, + task_time_out: Optional[int]= None, + text_check: Optional[int] = 1, + request_id: Optional[str] = None ): - """ - 输入文本并返回生成的图片url。 - - 参数: - message (obj:`Message`): 输入消息,用于模型的主要输入内容。这是一个必需的参数。举例: Message(content={"prompt": "上海的经典风景"}) - width (int,可选): 图片宽度,支持:512x512、640x360、360x640、1024x1024、1280x720、720x1280、2048x2048、2560x1440、1440x2560。 - height (int, 可选): 图片高度,支持:512x512、640x360、360x640、1024x1024、1280x720、720x1280、2048x2048、2560x1440、1440x2560。 - image_num (int, 可选): 生成图片数量,默认一张,支持生成 1-8 张。 - timeout (float, 可选): 请求的超时时间。 - retry (int, 可选): 请求的重试次数。 + headers = self._http_client.auth_header() + headers["Content-Type"] = "application/json" + api_url = self._http_client.service_url("/v1/bce/aip/ernievilg/v1/txt2imgv2") + + req = Text2ImageSubmitRequest( + prompt=message.content["prompt"], + width=width, + height=height, + image_num=image_num, + image=image, + url=url, + pdf_file=pdf_file, + pdf_file_num=pdf_file_num, + change_degree=change_degree, + text_content=text_content, + task_time_out=task_time_out, + text_check=text_check + ) + response = self.http_client.session.post(api_url, json=req.model_dump(), headers=headers, timeout=None) + self._http_client.check_response_header(response) + data = response.json() + resp= Text2ImageSubmitResponse(**data) - 返回: - obj:`Message`: 输出生成图片的url。举例: Message(content={"img_urls": ["xxx"]})。 - """ - inp = Text2ImageInMessage(**message.content) - text2ImageSubmitRequest = Text2ImageSubmitRequest() - text2ImageSubmitRequest.prompt = inp.prompt - text2ImageSubmitRequest.width = width - text2ImageSubmitRequest.height = height - text2ImageSubmitRequest.image_num = image_num - text2ImageSubmitResponse = self.submitText2ImageTask(text2ImageSubmitRequest, request_id=request_id) - taskId = text2ImageSubmitResponse.data.primary_task_id + taskId = resp.data.task_id if taskId is not None: task_request_time = 1 while True: - request = Text2ImageQueryRequest() - request.task_id = taskId + request = Text2ImageQueryRequest(task_id=taskId) text2ImageQueryResponse = self.queryText2ImageData(request, request_id=request_id) if text2ImageQueryResponse.data.task_progress is not None: task_progress = float(text2ImageQueryResponse.data.task_progress) @@ -143,19 +151,16 @@ def submitText2ImageTask( obj:`Text2ImageSubmitResponse`: 接口返回的输出消息。 """ url = self.http_client.service_url("/v1/bce/aip/ernievilg/v1/txt2imgv2") - data = Text2ImageSubmitRequest.to_json(request) + data = request.model_dump() headers = self.http_client.auth_header(request_id) headers['content-type'] = 'application/json' if retry != self.http_client.retry.total: self.http_client.retry.total = retry - response = self.http_client.session.post(url, data=data, headers=headers, timeout=timeout) + response = self.http_client.session.post(url, json=data, headers=headers, timeout=timeout) self.http_client.check_response_header(response) data = response.json() self.http_client.check_response_json(data) - request_id = self.http_client.response_request_id(response) - self.__class__.check_service_error(request_id, data) - response = Text2ImageSubmitResponse.from_json(payload=json.dumps(data)) - response.request_id = request_id + response = Text2ImageSubmitResponse(**data) return response def queryText2ImageData( @@ -191,8 +196,7 @@ def queryText2ImageData( self.http_client.check_response_json(data) request_id = self.http_client.response_request_id(response) self.__class__.check_service_error(request_id, data) - response = Text2ImageQueryResponse.from_json(payload=json.dumps(data)) - response.request_id = request_id + response = Text2ImageQueryResponse(**data) return response def extract_img_urls(self, response: Text2ImageQueryResponse): diff --git a/appbuilder/core/components/text_to_image/model.py b/appbuilder/core/components/text_to_image/model.py index ccca09b99..cb087dced 100644 --- a/appbuilder/core/components/text_to_image/model.py +++ b/appbuilder/core/components/text_to_image/model.py @@ -17,260 +17,74 @@ """ from typing import MutableSequence, List +from typing import Optional, Union import proto -from pydantic import BaseModel +from pydantic import BaseModel, Field -class Text2ImageSubmitRequest(proto.Message): - r"""文生图提交任务的请求体。 - - 参数: - prompt(str): - 生图的文本描述。仅支持中文、日常标点符号。不支持英文,特殊符号,限制 200 字。 - width(int): - 图片宽度,支持:512x512、640x360、360x640、1024x1024、1280x720、720x1280、2048x2048、2560x1440、1440x2560。 - height(int): - 图片高度,支持:512x512、640x360、360x640、1024x1024、1280x720、720x1280、2048x2048、2560x1440、1440x2560. - image_num(int): - 生成图片数量,默认一张,支持生成 1-8 张。 - image(string): - 参考图,需 base64 编码,大小不超过 10M,最短边至少 15px,最长边最大 8192px,支持jpg/jpeg/png/bmp 格式。 - 优先级:image > url > pdf_file,当image 字段存在时,url、pdf_file 字段失效。 - url(str): - 参考图完整 url,url 长度不超过 1024 字节,url 对应的图片需 base64 编码,大小不超过 10M,最短边至少 15px, - 最长边最大8192px,支持 jpg/jpeg/png/bmp 格式。优先级:image > url > pdf_file,当image 字段存在时,url 字段失效请注意关闭 URL 防盗链。 - pdf_file(string): - 参考图 PDF 文件,base64 编码,大小不超过10M,最短边至少 15px,最长边最大 8192px 。 - 优先级:image > url > pdf_file,当image 字段存在时,url、pdf_file 字段失效。 - pdf_file_num(str): - 需要识别的 PDF 文件的对应页码,当pdf_file 参数有效时,识别传入页码的对应页面内容,若不传入,则默认识别第 1 页。 - change_degree(int): - 参考图影响因子,支持 1-10 内;数值越大参考图影响越大。 - """ - prompt: str = proto.Field( - proto.STRING, - number=1, - ) - width: int = proto.Field( - proto.INT32, - number=2, - ) - height: int = proto.Field( - proto.INT32, - number=3, - ) - image_num: int = proto.Field( - proto.INT32, - number=4, - optional=True, - ) - image: str = proto.Field( - proto.STRING, - number=5, - optional=True, - ) - url: str = proto.Field( - proto.STRING, - number=6, - optional=True, - ) - pdf_file: str = proto.Field( - proto.STRING, - number=7, - optional=True, - ) - pdf_file_num: str = proto.Field( - proto.STRING, - number=8, - optional=True, - ) - change_degree: int = proto.Field( - proto.INT32, - number=9, - optional=True, - ) - - -class Text2ImageQueryRequest(proto.Message): - r"""文生图生成结果查询请求体。 - - 参数: - task_id(int): - 从提交请求的提交接口的返回值中获取,可使用task_id 查询总任务。 - """ - task_id: int = proto.Field( - proto.INT64, - number=1, - ) +class Text2ImageSubmitRequest(BaseModel): + prompt: str = Field(default='') + width: int = Field(default=1024) + height: int = Field(default=1024) + image_num: int = Field(default=1, ge=1, le=8) + image: Optional[str] = Field(default="") + url: Optional[str] = Field(default="") + pdf_file: Optional[str] = Field(default="") + pdf_file_num: Optional[str] = Field(default="") + change_degree: Optional[int] = None + text_content: Optional[str] = None + task_time_out: Optional[int] = None + text_check: Optional[int] = None -class Text2ImageSubmitResponse(proto.Message): - r"""文生图任务提交接口返回体。 +class Text2ImageSubmitErrorDetail(BaseModel): + msg: Optional[str] + word: Optional[object] - 参数: - request_id(str): - 网关层的请求ID。 - log_id(str): - 算子层请求唯一标识码。 - data(Text2ImageSubmitData): - 任务提交接口返回数据。 - """ - request_id: str = proto.Field( - proto.STRING, - number=1, - ) - log_id: int = proto.Field( - proto.INT64, - number=2, - ) - data: "Text2ImageSubmitData" = proto.Field( - proto.MESSAGE, - number=3, - message="Text2ImageSubmitData", - ) +class Text2ImageSubmitResponseData(BaseModel): + primary_task_id: Optional[int] = None + task_id: Optional[str] = None -class Text2ImageQueryResponse(proto.Message): - r"""文生图任务结果查询接口返回体。. +class Text2ImageSubmitResponse(BaseModel): + log_id: Optional[int] = None + data: Optional[Text2ImageSubmitResponseData] = Text2ImageSubmitResponseData() + error_msg: Optional[str] = None + error_detail: Optional[Text2ImageSubmitErrorDetail] = None + error_code: Optional[int] = None - 参数: - request_id(str): - Request ID of gateway layer. - log_id(str): - Request ID of service layer. - data(Text2ImageQueryData): - Text to Image query response data . - """ - request_id: str = proto.Field( - proto.STRING, - number=1, - ) - log_id: int = proto.Field( - proto.INT64, - number=2, - ) +class Text2ImageQueryRequest(BaseModel): + task_id: Optional[str] - data: "Text2ImageQueryData" = proto.Field( - proto.MESSAGE, - number=3, - message="Text2ImageQueryData", - ) +class FinalImage(BaseModel): + img_url: Optional[str] = None + height: Optional[int] = None + width: Optional[int] = None + img_approve_conclusion: Optional[str] = None -class Text2ImageSubmitData(proto.Message): - r"""文生图提交任务接口返回体数据。 - 参数: - primary_task_id(str): - 生成图片任务long类型 id,与“task_id”参数输出相同,该 id 可用于查询任务状态。 - task_id(str): - 生成图片任务string类型 id,与“primary_task_id”参数输出相同,该 id 可用于查询任务状态。 - """ - primary_task_id: int = proto.Field( - proto.INT64, - number=1, - ) +class SubTaskResult(BaseModel): + sub_task_status: Optional[str] = None + sub_task_progress_detail: Union[int, float, None] = None + sub_task_progress: Union[float, int, None] = None + sub_task_error_code: Optional[int] = None + final_image_list: Optional[list[FinalImage]] = None - task_id: str = proto.Field( - proto.STRING, - number=2, - ) +class Text2ImageQueryResponseData(BaseModel): + task_id: Optional[int] = None + task_status: Optional[str] = None + task_progress_detail: Union[float, int, None] = None + task_progress: Union[float, int, None] = None + sub_task_result_list: Optional[list[SubTaskResult]] = None -class Text2ImageQueryData(proto.Message): - r"""文生图任务查询接口返回体数据。 - 参数: - task_id(int): - 任务 ID. - task_status(str): - 计算总状态。有 INIT(初始化),WAIT(排队中), RUNNING(生成中), FAILED(失败), SUCCESS(成功)四种状态,只有 SUCCESS 为成功状态。 - task_progress(float): - 图片生成总进度,0到1之间的浮点数表示进度,0为未处理完,1为处理完成。 - sub_task_result_list(Text2ImageSubTaskResultList): - 子任务生成结果列表。 - """ - task_id: int = proto.Field( - proto.INT64, - number=1, - ) - task_status: str = proto.Field( - proto.STRING, - number=2, - ) - task_progress: float = proto.Field( - proto.FLOAT, - number=3, - ) - - sub_task_result_list: MutableSequence["Text2ImageSubTaskResultList"] = proto.RepeatedField( - proto.MESSAGE, - number=4, - message="Text2ImageSubTaskResultList", - ) - - -class Text2ImageSubTaskResultList(proto.Message): - r"""文生图子任务结果列表。 - 参数: - sub_task_status(int): - 单风格图片状态。有 INIT(初始化),WAIT(排队中), RUNNING(生成中), FAILED(失败), SUCCESS(成功)四种状态,只有 SUCCESS 为成功状态。 - sub_task_progress(float): - 单任务图片生成进度,0到1之间的浮点数表示进度,0为未处理完,1为处理完成。 - sub_task_error_code(str): - 单风格任务错误码。0:正常;501:文本黄反拦截;201:模型生图失败。 - final_image_list(Text2ImageFinalImageList): - 单风格任务产出的最终图列表。 - """ - sub_task_status: str = proto.Field( - proto.STRING, - number=1, - ) - sub_task_progress: float = proto.Field( - proto.FLOAT, - number=2, - ) - sub_task_error_code: int = proto.Field( - proto.INT32, - number=3, - ) - final_image_list: MutableSequence["Text2ImageFinalImageList"] = proto.RepeatedField( - proto.MESSAGE, - number=4, - message="Text2ImageFinalImageList", - ) - - -class Text2ImageFinalImageList(proto.Message): - r"""文生图单风格任务产出的最终图列表。 - 参数: - img_approve_conclusion(str): - 图片机审结果,"block":输出图片违规;"review": 输出图片疑似违规;"pass": 输出图片未发现问题。 - img_url(str): - 图片所在 BOS http 地址,默认 1 小时失效。 - height(int): - 图片像素信息-高度。 - width(int): - 图片像素信息-宽度。 - """ - img_approve_conclusion: str = proto.Field( - proto.STRING, - number=1, - ) - img_url: str = proto.Field( - proto.STRING, - number=2, - ) - width: int = proto.Field( - proto.INT32, - number=3, - ) - height: int = proto.Field( - proto.INT32, - number=4, - ) +class Text2ImageQueryResponse(BaseModel): + log_id: Union[str, int, None] = None + data: Optional[Text2ImageQueryResponseData] = Text2ImageQueryResponseData() class Text2ImageInMessage(BaseModel): diff --git a/appbuilder/tests/test_text_to_image.py b/appbuilder/tests/test_text_to_image.py index ed48fbc43..5c55aeb67 100644 --- a/appbuilder/tests/test_text_to_image.py +++ b/appbuilder/tests/test_text_to_image.py @@ -2,7 +2,7 @@ import os import appbuilder from appbuilder.core.components.text_to_image.model import (Text2ImageSubmitRequest, Text2ImageSubmitResponse, - Text2ImageQueryRequest, Text2ImageQueryResponse) + Text2ImageQueryRequest, Text2ImageQueryResponse, SubTaskResult) from appbuilder.core._exception import RiskInputException @@ -33,6 +33,7 @@ def test_run(self): """ inp = appbuilder.Message(content={"prompt": "上海的经典风景"}) out = self.text2Image.run(inp) + print(out) self.assertIsNotNone(out) self.assertIsInstance(out, appbuilder.Message) @@ -67,8 +68,9 @@ def test_queryText2ImageData(self): None """ - request = Text2ImageQueryRequest() - request.task_id = '123456' + request = Text2ImageQueryRequest( + task_id = "123456", + ) response = self.text2Image.queryText2ImageData(request) self.assertIsNotNone(response) self.assertIsInstance(response, Text2ImageQueryResponse) @@ -86,7 +88,8 @@ def test_extract_img_urls(self): """ response = Text2ImageQueryResponse() response.data.task_progress = 1.0 - response.data.sub_task_result_list = [{'final_image_list': [{'img_url': 'http://example.com'}]}] + response.data.task_progress_detail = 0.5 + response.data.sub_task_result_list = [SubTaskResult(**{'sub_task_progress_detail':0.8, 'final_image_list': [{'img_url': 'http://example.com'}]})] img_urls = self.text2Image.extract_img_urls(response) self.assertEqual(img_urls, ['http://example.com'])