Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into gemm-e2e
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxinyu committed Jul 2, 2024
2 parents 93e0780 + 754b3cb commit b5f4c1a
Show file tree
Hide file tree
Showing 8 changed files with 371 additions and 136 deletions.
12 changes: 12 additions & 0 deletions tests/build_and_test_e2e.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ pip3 install $ROOT_PROJ_DIR/runtime/python/dist/*.whl --force-reinstall
pip3 install $ROOT_PROJ_DIR/frontends/torch-frontend/build/torch-frontend/python/dist/*.whl --force-reinstall
pip3 install -r $ROOT_PROJ_DIR/frontends/torch-frontend/torch-requirements.txt
pip3 install flash_attn==2.5.3

# numerical test
python3 tests/numerical_test/main.py --target all
rm -rf ./local_test

# profiler test
python3 tests/numerical_test/profiler.py $ROOT_PROJ_DIR/tests/numerical_test/mlir_tests/cpu_ops/add.mlir --target cpu
python3 tests/numerical_test/profiler.py $ROOT_PROJ_DIR/tests/numerical_test/mlir_tests/ops/add.mlir --target cuda
rm -rf ./local_profiling

# generate compatibility test
python3 tests/numerical_test/gen_brt_tests.py
rm -rf ./local_golden

popd
103 changes: 103 additions & 0 deletions tests/compatibilty_test/execute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import brt
from brt.utils import brt_dtype_to_torch_dtype

import torch
import numpy as np
import os
import re

from reporting import TestResult


class BRTBackend:
    """Runs a loaded BRT session and compares its outputs with goldens.

    Only the "CPU" device is accepted; any other value raises
    NotImplementedError.
    """

    def __init__(self, device, brt_file_path):
        """Create the session, load ``brt_file_path`` and open a request.

        Args:
            device: Device name; must be exactly "CPU".
            brt_file_path: Path handed to ``Session.load``.

        Raises:
            NotImplementedError: If ``device`` is anything other than "CPU".
        """
        self.device = None
        if device != "CPU":
            raise NotImplementedError(
                f"Compatible test for {device} not implement")

        self.session = brt.Session(device=device.upper(), )
        # Torch-side device string used when allocating output tensors.
        self.device = "cpu"
        # The CPU path passes no stream to the request context.
        stream = None

        self.session.load(brt_file_path)
        self.req = self.session.new_request_context(stream)

    def _check(self, result, golden, atol=1e-06):
        """Return whether ``result`` is close to ``golden`` within ``atol``."""
        is_close = torch.allclose(result, golden, atol=atol)
        return is_close

    def _generate_torch_outputs(self):
        """Allocate one uninitialised tensor per session output argument."""
        session = self.session

        def allocate(offset):
            shape = session.get_static_shape(offset)
            dtype = brt_dtype_to_torch_dtype(session.get_data_type(offset))
            return torch.empty(shape, dtype=dtype, device=self.device)

        return [allocate(offset) for offset in session.get_output_arg_offsets()]

    def compare(self, inputs, goldens):
        """Run the session on ``inputs`` and compare against ``goldens``.

        Args:
            inputs: Tensors bound, in order, to the session's input arguments.
            goldens: Expected tensors, in order, for the session's outputs.

        Returns:
            True when every output passes ``_check`` against its golden.
        """
        session = self.session
        results = self._generate_torch_outputs()

        assert len(session.get_input_arg_offsets()) == len(inputs)
        assert len(session.get_output_arg_offsets()) == len(results)
        assert len(results) == len(goldens)

        # Inputs are bound first, then outputs; the offsets of each group are
        # only queried when that group is about to be bound.
        bindings = (
            (session.get_input_arg_offsets, inputs),
            (session.get_output_arg_offsets, results),
        )
        for get_offsets, tensors in bindings:
            for offset, tensor in zip(get_offsets(), tensors):
                assert list(session.get_static_shape(offset)) == list(
                    tensor.shape)
                self.req.bind_arg(offset, tensor.data_ptr())

        self.req.finish_io_binding()
        self.req.run()
        self.req.sync()

        # Stop at the first mismatch, exactly like all() over a generator.
        for produced, expected in zip(results, goldens):
            if not self._check(produced, expected):
                return False
        return True


def _load_npz_tensors(npz_path, device):
    """Load every array in ``npz_path`` as a contiguous tensor on ``device``.

    The archive is opened in a ``with`` block so its file handle is closed as
    soon as the arrays have been read.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code; only feed it trusted golden files.
    with np.load(npz_path, allow_pickle=True) as archive:
        return [
            torch.from_numpy(archive[key]).contiguous().to(device)
            for key in archive.files
        ]


def run_and_check_mlir(target, name, inp_files, out_files, byre_file):
    """Run ``byre_file`` on each golden input and compare with its golden output.

    Args:
        target: Target name from the test config; only "cpu" is supported.
        name: Base name of the case; the pair index is appended per result.
        inp_files: Paths of ``.npz`` archives holding the input arrays.
        out_files: Paths of ``.npz`` archives holding the expected outputs,
            paired positionally with ``inp_files``.
        byre_file: Path of the file loaded into the BRT session.

    Returns:
        A list with one ``TestResult`` per (input, output) pair;
        ``numerical_error`` is None for a matching pair.

    Raises:
        NotImplementedError: If ``target`` is not "cpu".
    """
    if target != "cpu":
        # Fail with the real target name instead of forwarding device=None.
        raise NotImplementedError(
            f"Compatible test for {target} not implement")

    device = "CPU"
    torch_device = device.lower()
    brt_backend = BRTBackend(device=device, brt_file_path=byre_file)

    cmp_res = []
    for idx, (input_file, target_file) in enumerate(zip(inp_files, out_files)):
        inp = _load_npz_tensors(input_file, torch_device)
        tgt = _load_npz_tensors(target_file, torch_device)
        if brt_backend.compare(inp, tgt):
            error = None
        else:
            error = f"input is {input_file}, output not match {target_file}"
        cmp_res.append(TestResult(name + str(idx), numerical_error=error))

    return cmp_res
86 changes: 86 additions & 0 deletions tests/compatibilty_test/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

#!/usr/bin/python3
import argparse
import os
import sys
import json

from reporting import report_results

from execute import run_and_check_mlir
"""
Usage:
This directory implements the code for the compatibility test framework. One should pass a test dir which contains:
(1) subdirs for each test case and a json conf file named `testcase.json`
(2) byre compilation artifacts named as {model_name}/{model_name}.rt.mlir
(3) several inputs and goldens named as inputs.{num}.npz and outputs.{num}.npz
"""


def parse_args():
    """Parse the command line of the compatibility test runner.

    Returns:
        The parsed namespace; ``testdir`` holds the test-case directory.

    ``--testdir`` is required: the runner joins it with "testcase.json", so a
    missing value would otherwise only fail later with a TypeError.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--testdir",
        type=str,
        required=True,
        help="Directory has test cases",
    )
    return parser.parse_args()


def run(testdir):
    """Execute every case listed in ``<testdir>/testcase.json``.

    The config is a nested mapping ``target -> byre version -> case name ->
    files``, where ``files`` provides "golden_inputs", "golden_outputs" and
    "brt_entry_file" as paths relative to ``testdir``.

    Args:
        testdir: Directory containing ``testcase.json`` and the case files.

    Returns:
        The concatenated results returned by ``run_and_check_mlir`` for all
        cases, in config order.

    Raises:
        RuntimeError: If the config file or a byre file is missing, or a case
            lists a different number of inputs and goldens.
    """
    conf_file = os.path.join(testdir, "testcase.json")
    if not os.path.exists(conf_file):
        raise RuntimeError(f"test case config file {conf_file} not found")
    with open(conf_file, "r", encoding='utf-8') as f:
        json_conf = json.load(f)

    result = []
    for target, versions in json_conf.items():
        # The byre version key only groups cases; it is not needed to run them.
        for cases in versions.values():
            for name, files in cases.items():
                input_files = [
                    os.path.join(testdir, p) for p in files["golden_inputs"]
                ]
                golden_files = [
                    os.path.join(testdir, p) for p in files["golden_outputs"]
                ]
                byre_file = os.path.join(testdir, files["brt_entry_file"])
                if len(input_files) != len(golden_files):
                    raise RuntimeError(
                        f"num of inputs({len(input_files)}) and goldens({len(golden_files)}) not eq in {name}"
                    )
                if not os.path.exists(byre_file):
                    raise RuntimeError(f"byre file {byre_file} not found")
                result += run_and_check_mlir(target, name, input_files,
                                             golden_files, byre_file)
    return result


def main():
    """Run the tests under ``--testdir`` and exit with 1 if any case failed."""
    cli = parse_args()
    any_failed = report_results(run(cli.testdir))
    exit_code = 1 if any_failed else 0
    sys.exit(exit_code)


# Run the test driver only when this file is executed as a script.
if __name__ == "__main__":
    main()
45 changes: 45 additions & 0 deletions tests/compatibilty_test/reporting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import NamedTuple, Optional, List


class TestResult(NamedTuple):
    """Outcome of a single test case, as consumed by ``report_results``."""
    # Name used to identify and sort the case in the printed report.
    unique_name: str
    # None when the case passed; otherwise a message describing the mismatch.
    numerical_error: Optional[str]


def report_results(results: List[TestResult]):
    """Print a pass/fail summary and report whether anything failed.

    Passing cases are listed first (sorted by name), then the detailed
    message of every failing case, then the sorted names of the failures.

    Args:
        results: Results to summarise; a case counts as failed when its
            ``numerical_error`` is not None.

    Returns:
        True if at least one case failed, otherwise False.
    """
    passed = []
    failed = []
    for item in results:
        if item.numerical_error is None:
            passed.append(item)
            continue
        detail = ("numerical failed: " + item.unique_name + "\n" +
                  item.numerical_error)
        failed.append([item.unique_name, detail])

    passed.sort(key=lambda entry: entry.unique_name)
    failed.sort(key=lambda entry: entry[0])

    print(f"\n****** PASS tests - {len(passed)} tests")
    for entry in passed:
        print(entry.unique_name, " --- PASS")
    for _, detail in failed:
        print(detail)
    print(f"\n****** FAILED tests - {len(failed)} tests")
    for failed_name, _ in failed:
        print(failed_name)
    return len(failed) > 0
34 changes: 22 additions & 12 deletions tests/numerical_test/gen_brt_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
import sys
import shutil
import traceback
import json

import numpy as np
from execute import MLIRDataGenerator
from reporting import TestResult, report_results
from testset import CPU_MLIR_TEST_DIR, CPU_MLIR_TEST_SET, CPU_XFAIL_SET
import byteir

parser = argparse.ArgumentParser()
Expand All @@ -47,11 +49,6 @@
help="Byre serialization target version")
args = parser.parse_args()

# Unsupported ops
EXCLUDE_MLIR_CPU_TESTS = [
"custom_call_tf_UpperBound.mlir",
"rng.mlir",
]

def gen_golden_mlir(mhlo_file, target, golden_dir, num=2):
"""
Expand All @@ -66,6 +63,8 @@ def save_np_data(fpath: str, data):

file_base_name = os.path.basename(mhlo_file).split(".")[0]
unique_name = file_base_name + "." + target
json_relative_dir_path = "./" + os.path.basename(golden_dir) + "/" + unique_name
json_result = {unique_name : {}}
try:
data_generator = MLIRDataGenerator(mhlo_file, target)
func_name = data_generator.entry_func_name
Expand All @@ -77,6 +76,8 @@ def save_np_data(fpath: str, data):
if data_generator.need_special_inputs():
num = 1

input_file_path = []
output_file_path = []
for idx in range(0, num):
np_inputs = data_generator.generate_np_inputs()

Expand All @@ -88,11 +89,16 @@ def save_np_data(fpath: str, data):
# dump to local file
save_np_data(WORK_FOLDER + f"/inputs.{str(idx)}.npz", np_inputs)
save_np_data(WORK_FOLDER + f"/outputs.{str(idx)}.npz", golden_outputs)
input_file_path.append(json_relative_dir_path + f"/inputs.{str(idx)}.npz")
output_file_path.append(json_relative_dir_path + f"/outputs.{str(idx)}.npz")

del np_inputs, golden_outputs
json_result[unique_name].update({"golden_inputs": input_file_path})
json_result[unique_name].update({"golden_outputs": output_file_path})

# byteir compile
output_mlir_file_name = f"{WORK_FOLDER}/{unique_name}.rt.mlirbc"
json_result[unique_name].update({"brt_entry_file" : json_relative_dir_path + f"/{unique_name}.rt.mlirbc"})
byteir.compile(
mhlo_file, output_mlir_file_name, entry_func=func_name, target=target
)
Expand All @@ -116,7 +122,7 @@ def save_np_data(fpath: str, data):
runtime_error=None,
numerical_error=None,
performance_result=None,
)
), None

res = TestResult(
unique_name=unique_name,
Expand All @@ -126,13 +132,11 @@ def save_np_data(fpath: str, data):
performance_result=None,
)

return res

return res, json_result


def gen_mlir_cpu_golden():
directory = os.path.dirname(os.path.realpath(__file__))
directory = directory + "/mlir_tests/cpu_ops"
directory = CPU_MLIR_TEST_DIR
cpu_target = "cpu"
os.makedirs(args.output_dir, exist_ok=True)
golden_dir = f"{args.output_dir}/CPU_BYRE_{args.byre_serial_version.replace('.', '_')}"
Expand All @@ -146,16 +150,22 @@ def gen_mlir_cpu_golden():
continue
f = os.path.join(directory, filename)
# checking if it is a file
if os.path.isfile(f) and filename not in EXCLUDE_MLIR_CPU_TESTS:
if os.path.isfile(f) and filename in (CPU_MLIR_TEST_SET - CPU_XFAIL_SET):
mlir_tests.append(f)

results = []
byre_version_str = f"byre{args.byre_serial_version}"
json_results = {"cpu" : {byre_version_str : {}}}
for test in mlir_tests:
fpath = test
res = gen_golden_mlir(fpath,
res, key_value = gen_golden_mlir(fpath,
cpu_target,
golden_dir)
results.append(res)
if key_value is not None:
json_results["cpu"][byre_version_str].update(key_value)
with open(f"{args.output_dir}/testcase.json", 'w') as f:
json.dump(json_results, f, indent=4)
return results


Expand Down
Loading

0 comments on commit b5f4c1a

Please sign in to comment.