Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

组件标准化单测框架更新: 更新系统变量,增加tool_eval参数和manifests匹配性检查 #672

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 150 additions & 96 deletions python/tests/component_check.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import os
import json
import os
import inspect
import time
from jsonschema import validate, ValidationError, SchemaError
from jsonschema import validate
from pydantic import BaseModel
from typing import Generator
from appbuilder.utils.func_utils import Singleton
from appbuilder.tests.component_schemas import type_to_json_schemas
from appbuilder.utils.json_schema_to_model import json_schema_to_pydantic_model
Expand Down Expand Up @@ -40,12 +38,15 @@ def register_rule(self, rule_name: str, rule_obj: RuleBase):
def remove_rule(self, rule_name: str):
del self.rules[rule_name]

def notify(self, component_cls) -> tuple[bool, list]:
def notify(self, component_cls, component_case) -> tuple[bool, list]:
check_pass = True
check_details = {}
reasons = []
for rule_name, rule_obj in self.rules.items():
res = rule_obj.check(component_cls)
if rule_name == "ToolEvalOutputJsonRule":
res = rule_obj.check(component_cls, component_case)
else:
res = rule_obj.check(component_cls)
check_details[rule_name] = res
if res.check_result == False:
check_pass = False
Expand All @@ -63,53 +64,40 @@ class ManifestValidRule(RuleBase):
def __init__(self, **kwargs):
super().__init__()
self.rule_name = "ManifestValidRule"
self.component_tool_eval_cases = kwargs.get("component_tool_eval_cases", {})

def check(self, component_cls) -> CheckInfo:
def check(self, component_obj) -> CheckInfo:
check_pass_flag = True
invalid_details = []
component_cls_name = component_cls.__name__
if component_cls_name not in self.component_tool_eval_cases:
invalid_details.append("{} 没有添加测试case到 component_tool_eval_cases 中".format(component_cls_name))
else:
component_case = self.component_tool_eval_cases[component_cls_name]()
envs = component_case.envs()
os.environ.update(envs)
init_args = component_case.init_args()

try:
component_obj = component_cls(**init_args)
if not hasattr(component_obj, "manifests"):
raise ValueError("No manifests found")
manifests = component_obj.manifests
# NOTE(暂时检查manifest中的第一个mainfest)
if not manifests or len(manifests) == 0:
raise ValueError("No manifests found")
manifest = manifests[0]
tool_name = manifest['name']
tool_desc = manifest['description']
schema = manifest["parameters"]
schema["title"] = tool_name
# 第一步,将json schema转换为pydantic模型
pydantic_model = json_schema_to_pydantic_model(schema, tool_name)
check_to_json = pydantic_model.schema_json()
json_to_dict = json.loads(check_to_json)

if "properties" in schema:
properties = schema["properties"]
for key, value in properties.items():
if "type" not in value:
invalid_details.append("\'type' must be in properties item: {}".format(key))
if "description" not in value:
invalid_details.append("\'description' must be in properties item: {}".format(key))

except Exception as e:
print(e)
check_pass_flag = False
invalid_details.append(str(e))

for env in envs:
os.environ.pop(env)
try:
if not hasattr(component_obj, "manifests"):
raise ValueError("No manifests found")
manifests = component_obj.manifests
# NOTE(暂时检查manifest中的第一个mainfest)
if not manifests or len(manifests) == 0:
raise ValueError("No manifests found")
manifest = manifests[0]
tool_name = manifest['name']
tool_desc = manifest['description']
schema = manifest["parameters"]
schema["title"] = tool_name
# 第一步,将json schema转换为pydantic模型
pydantic_model = json_schema_to_pydantic_model(schema, tool_name)
check_to_json = pydantic_model.schema_json()
json_to_dict = json.loads(check_to_json)

if "properties" in schema:
properties = schema["properties"]
for key, value in properties.items():
if "type" not in value:
invalid_details.append("\'type' must be in properties item: {}".format(key))
if "description" not in value:
invalid_details.append("\'description' must be in properties item: {}".format(key))

except Exception as e:
print(e)
check_pass_flag = False
invalid_details.append(str(e))

if len(invalid_details) > 0:
check_pass_flag = False
Expand Down Expand Up @@ -137,14 +125,14 @@ def __init__(self):
self.rule_name = "MainfestMatchToolEvalRule"


def check(self, component_cls) -> CheckInfo:
def check(self, component_obj) -> CheckInfo:
check_pass_flag = True
invalid_details = []

try:
if not hasattr(component_cls, "manifests"):
if not hasattr(component_obj, "manifests"):
raise ValueError("No manifests found")
manifests = component_cls.manifests
manifests = component_obj.manifests
# NOTE(暂时检查manifest中的第一个mainfest)
if not manifests or len(manifests) == 0:
raise ValueError("No manifests found")
Expand All @@ -158,7 +146,7 @@ def check(self, component_cls) -> CheckInfo:
# 交互检查
tool_eval_input_params = []
print("required_params: {}".format(manifest_var))
signature = inspect.signature(component_cls.tool_eval)
signature = inspect.signature(component_obj.tool_eval)
ileagal_params = []
for param_name, param in signature.parameters.items():
if param_name == 'kwargs' or param_name == 'args' or param_name == 'self':
Expand Down Expand Up @@ -193,10 +181,6 @@ def check(self, component_cls) -> CheckInfo:
check_detail=",".join(invalid_details))






class ToolEvalInputNameRule(RuleBase):
"""
检查tool_eval的输入参数中,是否包含系统保留的输入名称
Expand All @@ -222,10 +206,15 @@ def __init__(self):
"_sys_custom_variables",
"_sys_thought_model_config",
"_sys_rag_model_config",
"_sys_parent_span_id",
"_sys_span_id",
"_sys_memory",
"_sys_code_execution_endpoint",
"_sys_session_id"
]

def check(self, component_cls) -> CheckInfo:
tool_eval_signature = inspect.signature(component_cls.__init__)
def check(self, component_obj) -> CheckInfo:
tool_eval_signature = inspect.signature(component_obj.tool_eval)
params = tool_eval_signature.parameters
invalid_details = []
check_pass_flag = True
Expand All @@ -250,7 +239,6 @@ class ToolEvalOutputJsonRule(RuleBase):
def __init__(self, **kwargs):
super().__init__()
self.rule_name = 'ToolEvalOutputJsonRule'
self.component_tool_eval_cases = kwargs.get("component_tool_eval_cases")

def _check_pre_format(self, outputs):
invalid_details = []
Expand Down Expand Up @@ -351,42 +339,26 @@ def _check_text_and_code(self, component_case, output_dict):
else:
return []

def check(self, component_cls) -> CheckInfo:
def check(self, component_obj, component_case) -> CheckInfo:
invalid_details = []
component_cls_name = component_cls.__name__

if component_cls_name not in self.component_tool_eval_cases:
invalid_details.append("{} 没有添加测试case到 component_tool_eval_cases 中".format(component_cls_name))
else:
component_case = self.component_tool_eval_cases[component_cls_name]()

envs = {}
if hasattr(component_case, "envs"):
envs = component_case.envs()
os.environ.update(envs)

input_dict = component_case.inputs()
init_args = component_case.init_args()
component_obj = component_cls(**init_args)
output_json_schemas = component_case.schemas()

try:
stream_output_dict = {"text": "", "oral_text":"", "code": ""}
stream_outputs = component_obj.tool_eval(**input_dict)
for stream_output in stream_outputs:
iter_invalid_detail = self._check_jsonschema(stream_output.model_dump(), output_json_schemas)
invalid_details.extend(iter_invalid_detail)
iter_output_dict = self._gather_iter_outputs(stream_output)
stream_output_dict["text"] += iter_output_dict["text"]
stream_output_dict["oral_text"] += iter_output_dict["oral_text"]
stream_output_dict["code"] += iter_output_dict["code"]
if len(invalid_details) == 0:
invalid_details.extend(self._check_text_and_code(component_case, stream_output_dict))
except Exception as e:
invalid_details.append("ToolEval执行失败: {}".format(e))

for env in envs:
os.environ.pop(env)
input_dict = component_case.inputs()
output_json_schemas = component_case.schemas()

try:
stream_output_dict = {"text": "", "oral_text":"", "code": ""}
stream_outputs = component_obj.tool_eval(**input_dict)
for stream_output in stream_outputs:
iter_invalid_detail = self._check_jsonschema(stream_output.model_dump(), output_json_schemas)
invalid_details.extend(iter_invalid_detail)
iter_output_dict = self._gather_iter_outputs(stream_output)
stream_output_dict["text"] += iter_output_dict["text"]
stream_output_dict["oral_text"] += iter_output_dict["oral_text"]
stream_output_dict["code"] += iter_output_dict["code"]
if len(invalid_details) == 0:
invalid_details.extend(self._check_text_and_code(component_case, stream_output_dict))
except Exception as e:
invalid_details.append("ToolEval执行失败: {}".format(e))

if len(invalid_details) > 0:
return CheckInfo(
Expand All @@ -400,6 +372,88 @@ def check(self, component_cls) -> CheckInfo:
check_detail="")


def register_component_check_rule(rule_name: str, rule_cls: type[RuleBase], init_args=None):
    """Instantiate a rule class and register it on the component checker.

    Args:
        rule_name: Key under which the rule is stored in the checker.
        rule_cls: The rule *class* (not an instance); it is instantiated here.
        init_args: Optional keyword arguments forwarded to ``rule_cls``.
            Defaults to ``None`` (no arguments) so that callers passing only
            ``rule_name`` and ``rule_cls`` keep working, and so that no
            mutable default dict is shared between calls.
    """
    component_checker = ComponentCheckBase()
    component_checker.register_rule(rule_name, rule_cls(**(init_args or {})))

def check_component_with_retry(component_import_res_tuple):
    """Run all registered checks on one component, retrying failed attempts.

    A failed attempt (checks not passed, or an exception raised while building
    or checking the component) is retried up to two more times. Every failed
    attempt contributes one entry to the returned list, including attempts
    that were followed by a successful retry.

    Args:
        component_import_res_tuple (tuple): ``(component, import_res,
            component_case_cls)`` where ``component`` is the name used in
            reports, ``import_res`` is a dict with the keys ``"import_error"``
            and ``"obj"`` (the component class), and ``component_case_cls``
            builds the test case providing ``envs()`` and ``init_args()``.

    Returns:
        list: Dicts with the keys ``"Component Name"`` and ``"Error Message"``;
        empty when the first attempt passes.
    """
    component, import_res, component_case_cls = component_import_res_tuple
    component_check_base = ComponentCheckBase()

    error_data = []
    max_retries = 2  # a failing component is retried at most this many times

    # An import failure cannot be fixed by retrying: report it once and stop.
    if import_res["import_error"] != "":
        error_data.append({"Component Name": component, "Error Message": import_res["import_error"]})
        print("组件名称:{} 错误信息:{}".format(component, import_res["import_error"]))
        return error_data

    component_cls = import_res["obj"]

    for attempt in range(max_retries + 1):
        component_case = component_case_cls()
        envs = component_case.envs()
        # Remember what was there before so the environment can be restored
        # exactly, instead of dropping variables that pre-dated this call.
        previous_envs = {name: os.environ.get(name) for name in envs}
        os.environ.update(envs)

        try:
            # Construction happens inside the try block so that a failing
            # constructor is reported and retried like any other failure, and
            # so that the environment is always restored afterwards.
            component_obj = component_cls(**component_case.init_args())
            pass_check, reasons = component_check_base.notify(component_obj, component_case)
            if pass_check:
                break
            # Drop duplicate reasons while keeping their original order, so
            # the reported message is deterministic.
            message = ", ".join(dict.fromkeys(reasons))
        except Exception as e:
            message = str(e)
        finally:
            for name, old_value in previous_envs.items():
                if old_value is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = old_value

        error_data.append({"Component Name": component, "Error Message": message})
        print("组件名称:{} 错误信息:{}".format(component, message))
        if attempt < max_retries:
            print("组件名称:{} 将重试,当前尝试次数:{}".format(component, attempt + 1))

    return error_data

def write_error_data(txt_file_path, error_df, error_stats):
    """Write the component error report to a tab-separated text file.

    The file starts with a header line, followed by one line per row of
    ``error_df``, a blank line, and a summary with one line per entry of
    ``error_stats``.

    Args:
        txt_file_path: Path of the report file; it is created or overwritten.
        error_df: Table of errors. Only ``iterrows()`` is used, and each row
            must support lookup of ``'Component Name'`` and
            ``'Error Message'``.
        error_stats (dict): Maps an error message to its occurrence count.
    """
    # The report contains Chinese text, so the encoding is pinned to UTF-8
    # rather than left to the platform's locale default.
    with open(txt_file_path, 'w', encoding='utf-8') as file:
        file.write("Component Name\tError Message\n")
        file.writelines(
            f"{row['Component Name']}\t{row['Error Message']}\n"
            for _, row in error_df.iterrows()
        )
        file.write("\n错误统计信息:\n")
        file.writelines(
            f"错误信息: {error}, 出现次数: {count}\n"
            for error, count in error_stats.items()
        )
    print(f"\n错误信息已写入: {txt_file_path}")


# Register the built-in rules at import time. The names are significant:
# ComponentCheckBase.notify compares the rule name against
# "ToolEvalOutputJsonRule" to decide whether to pass the test case as well.
register_component_check_rule(
    rule_name="ManifestValidRule",
    rule_cls=ManifestValidRule,
)
register_component_check_rule(
    rule_name="MainfestMatchToolEvalRule",
    rule_cls=MainfestMatchToolEvalRule,
)
register_component_check_rule(
    rule_name="ToolEvalInputNameRule",
    rule_cls=ToolEvalInputNameRule,
)
register_component_check_rule(
    rule_name="ToolEvalOutputJsonRule",
    rule_cls=ToolEvalOutputJsonRule,
)
Loading
Loading