From a6fe5ec95f024ab9cdb3701da74241b120f360e9 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Tue, 2 Jul 2024 20:59:37 +0800
Subject: [PATCH 1/2] [e2e] use testset.py to specify testset and testing more
 in e2e (#384)

* use `testset.py` to specify testset

* testing `profiler.py` and `gen_brt_tests.py` in e2e

* generate `testcase.json` in `gen_brt_tests.py`
---
 tests/build_and_test_e2e.sh           | 12 ++++
 tests/numerical_test/gen_brt_tests.py | 34 ++++++----
 tests/numerical_test/main.py          | 91 +-------------------------
 tests/numerical_test/profiler.py      |  4 +-
 tests/numerical_test/testset.py       | 93 +++++++++++++++++++++++++++
 5 files changed, 131 insertions(+), 103 deletions(-)
 create mode 100644 tests/numerical_test/testset.py

diff --git a/tests/build_and_test_e2e.sh b/tests/build_and_test_e2e.sh
index e6f7cf624..5d4d3ef35 100755
--- a/tests/build_and_test_e2e.sh
+++ b/tests/build_and_test_e2e.sh
@@ -21,6 +21,18 @@ pip3 install $ROOT_PROJ_DIR/runtime/python/dist/*.whl --force-reinstall
 pip3 install $ROOT_PROJ_DIR/frontends/torch-frontend/build/torch-frontend/python/dist/*.whl --force-reinstall
 pip3 install -r $ROOT_PROJ_DIR/frontends/torch-frontend/torch-requirements.txt
 pip3 install flash_attn==2.5.3
+
+# numerical test
 python3 tests/numerical_test/main.py --target all
 rm -rf ./local_test
+
+# profiler test
+python3 tests/numerical_test/profiler.py $ROOT_PROJ_DIR/tests/numerical_test/mlir_tests/cpu_ops/add.mlir --target cpu
+python3 tests/numerical_test/profiler.py $ROOT_PROJ_DIR/tests/numerical_test/mlir_tests/ops/add.mlir --target cuda
+rm -rf ./local_profiling
+
+# generate compatibility test
+python3 tests/numerical_test/gen_brt_tests.py
+rm -rf ./local_golden
+
 popd

diff --git a/tests/numerical_test/gen_brt_tests.py b/tests/numerical_test/gen_brt_tests.py
index 3b90e4562..5369eae0f 100644
--- a/tests/numerical_test/gen_brt_tests.py
+++ b/tests/numerical_test/gen_brt_tests.py
@@ -19,10 +19,12 @@ import sys
 import shutil
 import traceback
+import json

 import numpy as np
 from execute import MLIRDataGenerator
 from reporting import TestResult, report_results
+from testset import CPU_MLIR_TEST_DIR, CPU_MLIR_TEST_SET, CPU_XFAIL_SET
 import byteir

 parser = argparse.ArgumentParser()
@@ -47,11 +49,6 @@
     help="Byre serialization target version")
 args = parser.parse_args()

-# Unsupported ops
-EXCLUDE_MLIR_CPU_TESTS = [
-    "custom_call_tf_UpperBound.mlir",
-    "rng.mlir",
-]

 def gen_golden_mlir(mhlo_file, target, golden_dir, num=2):
     """
@@ -66,6 +63,8 @@ def save_np_data(fpath: str, data):

     file_base_name = os.path.basename(mhlo_file).split(".")[0]
     unique_name = file_base_name + "." + target
+    json_relative_dir_path = "./" + os.path.basename(golden_dir) + "/" + unique_name
+    json_result = {unique_name : {}}
     try:
         data_generator = MLIRDataGenerator(mhlo_file, target)
         func_name = data_generator.entry_func_name
@@ -77,6 +76,8 @@ def save_np_data(fpath: str, data):
         if data_generator.need_special_inputs():
             num = 1

+        input_file_path = []
+        output_file_path = []
         for idx in range(0, num):
             np_inputs = data_generator.generate_np_inputs()

@@ -88,11 +89,16 @@ def save_np_data(fpath: str, data):
             # dump to local file
             save_np_data(WORK_FOLDER + f"/inputs.{str(idx)}.npz", np_inputs)
             save_np_data(WORK_FOLDER + f"/outputs.{str(idx)}.npz", golden_outputs)
+            input_file_path.append(json_relative_dir_path + f"/inputs.{str(idx)}.npz")
+            output_file_path.append(json_relative_dir_path + f"/outputs.{str(idx)}.npz")

             del np_inputs, golden_outputs
+        json_result[unique_name].update({"golden_inputs": input_file_path})
+        json_result[unique_name].update({"golden_outputs": output_file_path})

         # byteir compile
         output_mlir_file_name = f"{WORK_FOLDER}/{unique_name}.rt.mlirbc"
+        json_result[unique_name].update({"brt_entry_file" : json_relative_dir_path + f"/{unique_name}.rt.mlirbc"})
         byteir.compile(
             mhlo_file, output_mlir_file_name, entry_func=func_name, target=target
         )
@@ -116,7 +122,7 @@ def save_np_data(fpath: str, data):
             runtime_error=None,
             numerical_error=None,
             performance_result=None,
-        )
+        ), None

     res = TestResult(
         unique_name=unique_name,
@@ -126,13 +132,11 @@ def save_np_data(fpath: str, data):
         performance_result=None,
     )

-    return res
-
+    return res, json_result

 def gen_mlir_cpu_golden():
-    directory = os.path.dirname(os.path.realpath(__file__))
-    directory = directory + "/mlir_tests/cpu_ops"
+    directory = CPU_MLIR_TEST_DIR
     cpu_target = "cpu"
     os.makedirs(args.output_dir, exist_ok=True)
     golden_dir = f"{args.output_dir}/CPU_BYRE_{args.byre_serial_version.replace('.', '_')}"
@@ -146,16 +150,22 @@
             continue
         f = os.path.join(directory, filename)
         # checking if it is a file
-        if os.path.isfile(f) and filename not in EXCLUDE_MLIR_CPU_TESTS:
+        if os.path.isfile(f) and filename in (CPU_MLIR_TEST_SET - CPU_XFAIL_SET):
             mlir_tests.append(f)

     results = []
+    byre_version_str = f"byre{args.byre_serial_version}"
+    json_results = {"cpu" : {byre_version_str : {}}}
     for test in mlir_tests:
         fpath = test
-        res = gen_golden_mlir(fpath,
+        res, key_value = gen_golden_mlir(fpath,
                               cpu_target,
                               golden_dir)
         results.append(res)
+        if key_value is not None:
+            json_results["cpu"][byre_version_str].update(key_value)
+    with open(f"{args.output_dir}/testcase.json", 'w') as f:
+        json.dump(json_results, f, indent=4)

     return results

diff --git a/tests/numerical_test/main.py b/tests/numerical_test/main.py
index 078aae682..a5103459e 100644
--- a/tests/numerical_test/main.py
+++ b/tests/numerical_test/main.py
@@ -24,95 +24,8 @@
     GLOBAL_TORCH_TEST_REGISTRY,
     GLOBAL_TORCH_TEST_REGISTRY_NAMES,
 )
-from torch_e2e_testing.test_suite import register_all_torch_tests
-
-register_all_torch_tests()
-
-CUR_DIR = os.path.dirname(os.path.abspath(__file__))
-
-
-def _get_test_files_from_dir(directory):
-    test_files = []
-    for filename in os.listdir(directory):
-        if filename.startswith("."):
-            continue
-        if os.path.isfile(os.path.join(directory, filename)):
-            test_files.append(filename)
-    return test_files
-
-
-##### CPU TEST SET #######
-CPU_MLIR_TEST_DIR = os.path.join(CUR_DIR, "mlir_tests", "cpu_ops")
-CPU_MLIR_TEST_SET = set(_get_test_files_from_dir(CPU_MLIR_TEST_DIR))
-CPU_TORCH_TEST_SET = set()
-CPU_XFAIL_SET = {
-    "custom_call_tf_UpperBound.mlir",
-    "rng.mlir",
-}
-
-CPU_ALL_SET = (CPU_MLIR_TEST_SET | CPU_TORCH_TEST_SET) - CPU_XFAIL_SET
-
-##### CUDA TEST SET #######
-CUDA_MLIR_TEST_DIR = os.path.join(CUR_DIR, "mlir_tests", "ops")
-CUDA_MLIR_TEST_SET = set(_get_test_files_from_dir(CUDA_MLIR_TEST_DIR))
-CUDA_TORCH_TEST_SET = set(GLOBAL_TORCH_TEST_REGISTRY_NAMES)
-CUDA_XFAIL_SET = {
-    "bmm_rcr.mlir",
-    "bmm_rrc.mlir",
-    "bmm_rrr_add_f16.mlir",
-    "bmm_rrr_f16.mlir",
-    "bmm_rrr_permute_f16.mlir",
-    "bmm_rrr_permute_f32.mlir",
-    "layernorm.mlir",
-    "softmax.mlir",
-    "transpose102.mlir",
-    "transpose1023.mlir",
-    "transpose120.mlir",
-    "transpose1203.mlir",
-    "transpose2013.mlir",
-    "transpose120.mlir",
-}
-
-CUDA_ALL_SET = (CUDA_MLIR_TEST_SET | CUDA_TORCH_TEST_SET) - CUDA_XFAIL_SET
-
-##### CUDA AIT TEST SET #######
-CUDA_AIT_MLIR_TEST_SET = {
-    "bmm_rcr.mlir",
-    "bmm_rrc.mlir",
-    "bmm_rrr_add_f16.mlir",
-    "bmm_rrr_f16.mlir",
-    "bmm_rrr_permute_f16.mlir",
-    "bmm_rrr_permute_f32.mlir",
-    "gemm_crr_f16.mlir",
-    "gemm_rrr_f16.mlir",
-    "gemm_rrr_f32.mlir",
-    "layernorm.mlir",
-    "softmax.mlir",
-    "transpose2d.mlir",
-    "transpose102.mlir",
-    "transpose1023.mlir",
-    "transpose120.mlir",
-    "transpose1203.mlir",
-    "transpose2013.mlir",
-    "transpose120.mlir",
-}
-CUDA_AIT_TORCH_TEST_SET = {
-    "MatmulF16Module_basic",
-    "MatmulTransposeModule_basic",
-    "MatmulF32Module_basic",
-    "BatchMatmulF32Module_basic",
-    "BatchMatmulAddF32Module_basic",
-}
-CUDA_AIT_SM80PLUS_SET = {
-    "gemm_rrr_f32.mlir",
-    "bmm_rrr_permute_f16.mlir",
-    "bmm_rrr_permute_f32.mlir",
-    "MatmulF32Module_basic",
-    "BatchMatmulF32Module_basic",
-    "BatchMatmulAddF32Module_basic",
-}
-
-CUDA_AIT_ALL_SET = CUDA_AIT_MLIR_TEST_SET | CUDA_AIT_TORCH_TEST_SET
+from testset import CPU_MLIR_TEST_DIR, CUDA_MLIR_TEST_DIR
+from testset import CPU_ALL_SET, CUDA_ALL_SET, CUDA_AIT_ALL_SET, CUDA_AIT_SM80PLUS_SET

 ##### TEST SET CONFIG #######
 TEST_SET = {
diff --git a/tests/numerical_test/profiler.py b/tests/numerical_test/profiler.py
index d1de5025f..df63893fb 100644
--- a/tests/numerical_test/profiler.py
+++ b/tests/numerical_test/profiler.py
@@ -21,8 +21,8 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("input_mlir_path")
-    parser.add_argument("--workdir", type=str, default="./profiling", help="workspace directory")
-    parser.add_argument("--name", type=str, default="model")
+    parser.add_argument("--workdir", type=str, default="./local_profiling", help="workspace directory")
+    parser.add_argument("--name", type=str, default=None)
     parser.add_argument("--target", type=str, default="cuda", choices=["cpu", "cuda", "cuda_with_ait"])
     parser.add_argument("--mode", type=str, default="profile", choices=["numerical", "profile"])
     parser.add_argument("-v", "--verbose", default=False, action="store_true")
diff --git a/tests/numerical_test/testset.py b/tests/numerical_test/testset.py
new file mode 100644
index 000000000..02585154e
--- /dev/null
+++ b/tests/numerical_test/testset.py
@@ -0,0 +1,93 @@
+import os
+from torch_e2e_testing.registry import (
+    GLOBAL_TORCH_TEST_REGISTRY,
+    GLOBAL_TORCH_TEST_REGISTRY_NAMES,
+)
+from torch_e2e_testing.test_suite import register_all_torch_tests
+
+register_all_torch_tests()
+
+CUR_DIR = os.path.dirname(os.path.abspath(__file__))
+
+def _get_test_files_from_dir(directory):
+    test_files = []
+    for filename in os.listdir(directory):
+        if filename.startswith("."):
+            continue
+        if os.path.isfile(os.path.join(directory, filename)):
+            test_files.append(filename)
+    return test_files
+
+
+##### CPU TEST SET #######
+CPU_MLIR_TEST_DIR = os.path.join(CUR_DIR, "mlir_tests", "cpu_ops")
+CPU_MLIR_TEST_SET = set(_get_test_files_from_dir(CPU_MLIR_TEST_DIR))
+CPU_TORCH_TEST_SET = set()
+CPU_XFAIL_SET = {
+    "custom_call_tf_UpperBound.mlir",
+    "rng.mlir",
+}
+
+CPU_ALL_SET = (CPU_MLIR_TEST_SET | CPU_TORCH_TEST_SET) - CPU_XFAIL_SET
+
+##### CUDA TEST SET #######
+CUDA_MLIR_TEST_DIR = os.path.join(CUR_DIR, "mlir_tests", "ops")
+CUDA_MLIR_TEST_SET = set(_get_test_files_from_dir(CUDA_MLIR_TEST_DIR))
+CUDA_TORCH_TEST_SET = set(GLOBAL_TORCH_TEST_REGISTRY_NAMES)
+CUDA_XFAIL_SET = {
+    "bmm_rcr.mlir",
+    "bmm_rrc.mlir",
+    "bmm_rrr_add_f16.mlir",
+    "bmm_rrr_f16.mlir",
+    "bmm_rrr_permute_f16.mlir",
+    "bmm_rrr_permute_f32.mlir",
+    "layernorm.mlir",
+    "softmax.mlir",
+    "transpose102.mlir",
+    "transpose1023.mlir",
+    "transpose120.mlir",
+    "transpose1203.mlir",
+    "transpose2013.mlir",
+    "transpose120.mlir",
+}
+
+CUDA_ALL_SET = (CUDA_MLIR_TEST_SET | CUDA_TORCH_TEST_SET) - CUDA_XFAIL_SET
+
+##### CUDA AIT TEST SET #######
+CUDA_AIT_MLIR_TEST_SET = {
+    "bmm_rcr.mlir",
+    "bmm_rrc.mlir",
+    "bmm_rrr_add_f16.mlir",
+    "bmm_rrr_f16.mlir",
+    "bmm_rrr_permute_f16.mlir",
+    "bmm_rrr_permute_f32.mlir",
+    "gemm_crr_f16.mlir",
+    "gemm_rrr_f16.mlir",
+    "gemm_rrr_f32.mlir",
+    "layernorm.mlir",
+    "softmax.mlir",
+    "transpose2d.mlir",
+    "transpose102.mlir",
+    "transpose1023.mlir",
+    "transpose120.mlir",
+    "transpose1203.mlir",
+    "transpose2013.mlir",
+    "transpose120.mlir",
+}
+CUDA_AIT_TORCH_TEST_SET = {
+    "MatmulF16Module_basic",
+    "MatmulTransposeModule_basic",
+    "MatmulF32Module_basic",
+    "BatchMatmulF32Module_basic",
+    "BatchMatmulAddF32Module_basic",
+}
+CUDA_AIT_SM80PLUS_SET = {
+    "gemm_rrr_f32.mlir",
+    "bmm_rrr_permute_f16.mlir",
+    "bmm_rrr_permute_f32.mlir",
+    "MatmulF32Module_basic",
+    "BatchMatmulF32Module_basic",
+    "BatchMatmulAddF32Module_basic",
+}
+
+CUDA_AIT_ALL_SET = CUDA_AIT_MLIR_TEST_SET | CUDA_AIT_TORCH_TEST_SET

From 754b3cb104e8305b63ca89b9737f096d05549311 Mon Sep 17 00:00:00 2001
From: Chenhui Huang
Date: Tue, 2 Jul 2024 23:51:10 +0800
Subject: [PATCH 2/2] [e2e] add compatibility test (#381)

- as title

usage:
```
python main.py --testdir your_testcase_dir
```
---
 tests/compatibilty_test/execute.py   | 103 +++++++++++++++++++++++++++
 tests/compatibilty_test/main.py      |  86 ++++++++++++++++++++++
 tests/compatibilty_test/reporting.py |  45 ++++++++++++
 3 files changed, 234 insertions(+)
 create mode 100644 tests/compatibilty_test/execute.py
 create mode 100644 tests/compatibilty_test/main.py
 create mode 100644 tests/compatibilty_test/reporting.py

diff --git a/tests/compatibilty_test/execute.py b/tests/compatibilty_test/execute.py
new file mode 100644
index 000000000..4e50b6cf3
--- /dev/null
+++ b/tests/compatibilty_test/execute.py
@@ -0,0 +1,103 @@
+# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import brt
+from brt.utils import brt_dtype_to_torch_dtype
+
+import torch
+import numpy as np
+import os
+import re
+
+from reporting import TestResult
+
+
+class BRTBackend:
+
+    def __init__(self, device, brt_file_path):
+        _stream = None
+        self.device = None
+        if device == "CPU":
+            self.session = brt.Session(device=device.upper())
+            self.device = "cpu"
+            _stream = None
+        else:
+            raise NotImplementedError(
+                f"Compatibility test for {device} not implemented")
+
+        self.session.load(brt_file_path)
+        self.req = self.session.new_request_context(_stream)
+
+    def _check(self, result, golden, atol=1e-06):
+        return torch.allclose(result, golden, atol=atol)
+
+    def _generate_torch_outputs(self):
+        outputs = []
+        for offset in self.session.get_output_arg_offsets():
+            outputs.append(
+                torch.empty(self.session.get_static_shape(offset),
+                            dtype=brt_dtype_to_torch_dtype(
+                                self.session.get_data_type(offset)),
+                            device=self.device))
+        return outputs
+
+    def compare(self, inputs, goldens):
+        outputs = self._generate_torch_outputs()
+        assert len(self.session.get_input_arg_offsets()) == len(inputs)
+        assert len(self.session.get_output_arg_offsets()) == len(outputs)
+        assert len(outputs) == len(goldens)
+        for offset, arg in zip(self.session.get_input_arg_offsets(), inputs):
+            assert list(self.session.get_static_shape(offset)) == list(
+                arg.shape)
+            self.req.bind_arg(offset, arg.data_ptr())
+        for offset, ret in zip(self.session.get_output_arg_offsets(), outputs):
+            assert list(self.session.get_static_shape(offset)) == list(
+                ret.shape)
+            self.req.bind_arg(offset, ret.data_ptr())
+        self.req.finish_io_binding()
+        self.req.run()
+        self.req.sync()
+        return all(self._check(o, g) for o, g in zip(outputs, goldens))
+
+
+def run_and_check_mlir(target, name, inp_files, out_files, byre_file):
+
+    _device = None
+    if target == "cpu":
+        _device = "CPU"
+
+    brt_backend = BRTBackend(device=_device, brt_file_path=byre_file)
+
+    cmp_res = []
+    for idx, (input_file, target_file) in enumerate(zip(inp_files, out_files)):
+        inp = np.load(input_file, allow_pickle=True)
+        inp = [
+            torch.from_numpy(inp[f]).contiguous().to(_device.lower())
+            for f in inp.files
+        ]
+        tgt = np.load(target_file, allow_pickle=True)
+        tgt = [
+            torch.from_numpy(tgt[f]).contiguous().to(_device.lower())
+            for f in tgt.files
+        ]
+        if brt_backend.compare(inp, tgt):
+            cmp_res.append(TestResult(name + str(idx), numerical_error=None))
+        else:
+            cmp_res.append(
+                TestResult(
+                    name + str(idx),
+                    numerical_error=
+                    f"input is {input_file}, output does not match {target_file}"))
+
+    return cmp_res

diff --git a/tests/compatibilty_test/main.py b/tests/compatibilty_test/main.py
new file mode 100644
index 000000000..70f31f6bc
--- /dev/null
+++ b/tests/compatibilty_test/main.py
@@ -0,0 +1,86 @@
+# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+#!/usr/bin/python3
+import argparse
+import os
+import sys
+import json
+
+from reporting import report_results
+
+from execute import run_and_check_mlir
+"""
+Usage:
+    This directory implements the compatibility test framework. One should pass a test dir which contains:
+    (1) a subdir for each test case and a JSON config file named `testcase.json`
+    (2) byre compilation artifacts named as {model_name}/{model_name}.rt.mlirbc
+    (3) several inputs and goldens named as inputs.{num}.npz and outputs.{num}.npz
+"""
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--testdir",
+        type=str,
+        default=None,
+        help="Directory containing test cases",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def run(testdir):
+
+    def extract_name_from_testdir(testdir):
+        return os.path.basename(testdir)
+
+    result = []
+    conf_file = os.path.join(testdir, "testcase.json")
+    if not os.path.exists(conf_file):
+        raise RuntimeError(f"test case config file {conf_file} not found")
+    with open(conf_file, "r", encoding='utf-8') as f:
+        json_conf = json.load(f)
+    for target, data in json_conf.items():
+        for byre_version, cases in data.items():
+            for name, files in cases.items():
+                input_files = files["golden_inputs"]
+                input_files = [os.path.join(testdir, f) for f in input_files]
+                golden_files = files["golden_outputs"]
+                golden_files = [os.path.join(testdir, f) for f in golden_files]
+                byre_file = files["brt_entry_file"]
+                byre_file = os.path.join(testdir, byre_file)
+                if len(input_files) != len(golden_files):
+                    raise RuntimeError(
+                        f"num of inputs({len(input_files)}) and goldens({len(golden_files)}) not equal in {name}"
+                    )
+                if not os.path.exists(byre_file):
+                    raise RuntimeError(f"byre file {byre_file} not found")
+                result += run_and_check_mlir(target, name, input_files,
+                                             golden_files, byre_file)
+    return result
+
+
+def main():
+    args = parse_args()
+
+    results = run(args.testdir)
+
+    failed = report_results(results)
+    sys.exit(1 if failed else 0)
+
+
+if __name__ == "__main__":
+    main()

diff --git a/tests/compatibilty_test/reporting.py b/tests/compatibilty_test/reporting.py
new file mode 100644
index 000000000..0ec14108d
--- /dev/null
+++ b/tests/compatibilty_test/reporting.py
@@ -0,0 +1,45 @@
+# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== + +from typing import NamedTuple, Optional, List + + +class TestResult(NamedTuple): + unique_name: str + numerical_error: Optional[str] + + +def report_results(results: List[TestResult]): + fail_case = [] + pass_case = [] + for result in results: + if result.numerical_error is not None: + fail_case.append([ + result.unique_name, "numerical failed: " + result.unique_name + + "\n" + result.numerical_error + ]) + else: + pass_case.append(result) + pass_case.sort(key=lambda x: x.unique_name) + fail_case.sort(key=lambda x: x[0]) + + print(f"\n****** PASS tests - {len(pass_case)} tests") + for test in pass_case: + print(test.unique_name, " --- PASS") + for test in fail_case: + print(test[1]) + print(f"\n****** FAILED tests - {len(fail_case)} tests") + for test in fail_case: + print(test[0]) + return len(fail_case) > 0
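
For reference, the `testcase.json` consumed by this framework is the one `gen_brt_tests.py` from PATCH 1/2 writes next to its golden data. A minimal sketch of the layout, assuming a single case generated from `add.mlir` and a placeholder `--byre_serial_version` of `1.0.0` (both hypothetical; the real file carries one entry per generated test):

```
{
    "cpu": {
        "byre1.0.0": {
            "add.cpu": {
                "golden_inputs": [
                    "./CPU_BYRE_1_0_0/add.cpu/inputs.0.npz",
                    "./CPU_BYRE_1_0_0/add.cpu/inputs.1.npz"
                ],
                "golden_outputs": [
                    "./CPU_BYRE_1_0_0/add.cpu/outputs.0.npz",
                    "./CPU_BYRE_1_0_0/add.cpu/outputs.1.npz"
                ],
                "brt_entry_file": "./CPU_BYRE_1_0_0/add.cpu/add.cpu.rt.mlirbc"
            }
        }
    }
}
```

`tests/compatibilty_test/main.py` walks exactly this nesting (target, then byre version, then case, then file lists) and resolves every path relative to `--testdir`.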