Skip to content

Commit

Permalink
[e2e] add compatible test (#381)
Browse files Browse the repository at this point in the history
- as title

usage:
```
python main.py  --testdir  your_testcase_dir
```
  • Loading branch information
YellowHCH authored Jul 2, 2024
1 parent a6fe5ec commit 754b3cb
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 0 deletions.
103 changes: 103 additions & 0 deletions tests/compatibilty_test/execute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import brt
from brt.utils import brt_dtype_to_torch_dtype

import torch
import numpy as np
import os
import re

from reporting import TestResult


class BRTBackend:

def __init__(self, device, brt_file_path):
_stream = None
self.device = None
if device == "CPU":
self.session = brt.Session(device=device.upper(), )
self.device = "cpu"
_stream = None
else:
raise NotImplementedError(
f"Compatible test for {device} not implement")

self.session.load(brt_file_path)
self.req = self.session.new_request_context(_stream)

def _check(self, result, golden, atol=1e-06):
return torch.allclose(result, golden, atol=atol)

def _generate_torch_outputs(self):
outputs = []
for offset in self.session.get_output_arg_offsets():
outputs.append(
torch.empty(self.session.get_static_shape(offset),
dtype=brt_dtype_to_torch_dtype(
self.session.get_data_type(offset)),
device=self.device))
return outputs

def compare(self, inputs, goldens):
outputs = self._generate_torch_outputs()
assert len(self.session.get_input_arg_offsets()) == len(inputs)
assert len(self.session.get_output_arg_offsets()) == len(outputs)
assert len(outputs) == len(goldens)
for offset, arg in zip(self.session.get_input_arg_offsets(), inputs):
assert list(self.session.get_static_shape(offset)) == list(
arg.shape)
self.req.bind_arg(offset, arg.data_ptr())
for offset, ret in zip(self.session.get_output_arg_offsets(), outputs):
assert list(self.session.get_static_shape(offset)) == list(
ret.shape)
self.req.bind_arg(offset, ret.data_ptr())
self.req.finish_io_binding()
self.req.run()
self.req.sync()
return all(self._check(o, g) for o, g in zip(outputs, goldens))


def run_and_check_mlir(target, name, inp_files, out_files, byre_file):

_device = None
if target == "cpu":
_device = "CPU"

brt_backend = BRTBackend(device=_device, brt_file_path=byre_file)

cmp_res = []
for idx, (input_file, target_file) in enumerate(zip(inp_files, out_files)):
inp = np.load(input_file, allow_pickle=True)
inp = [
torch.from_numpy(inp[f]).contiguous().to(_device.lower())
for f in inp.files
]
tgt = np.load(target_file, allow_pickle=True)
tgt = [
torch.from_numpy(tgt[f]).contiguous().to(_device.lower())
for f in tgt.files
]
if brt_backend.compare(inp, tgt):
cmp_res.append(TestResult(name + str(idx), numerical_error=None))
else:
cmp_res.append(
TestResult(
name + str(idx),
numerical_error=
f"input is {input_file}, output not match {target_file}"))

return cmp_res
86 changes: 86 additions & 0 deletions tests/compatibilty_test/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

#!/usr/bin/python3
import argparse
import os
import sys
import json

from reporting import report_results

from execute import run_and_check_mlir
"""
Usage:
This directory implements the code for compatibilty test framework. One should pass a test dir which contains:
(1) subdirs for each tese case and json conf file named `testcase.json`
(2) byre compilation artifacts named as {model_name}/{model_name}.rt.mlir
(3) several inputs and goldens named as inputs.{num}.npz and outputs.{num}.npz
"""


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--testdir",
type=str,
default=None,
help="Directory has test cases",
)
args = parser.parse_args()
return args


def run(testdir):

def extract_name_from_tesrdir(testdir):
return os.path.basename(testdir)

result = []
conf_file = os.path.join(testdir, "testcase.json")
if not os.path.exists(conf_file):
raise RuntimeError(f"test case config file {conf_file} not found")
with open(conf_file, "r", encoding='utf-8') as f:
json_conf = json.load(f)
for target, data in json_conf.items():
for byre_version, cases in data.items():
for name, files in cases.items():
input_files = files["golden_inputs"]
input_files = [os.path.join(testdir, f) for f in input_files]
golden_files = files["golden_outputs"]
golden_files = [os.path.join(testdir, f) for f in golden_files]
byre_file = files["brt_entry_file"]
byre_file = os.path.join(testdir, byre_file)
if len(input_files) != len(golden_files):
raise RuntimeError(
f"num of inouts({len(input_files)}) and goldens({len(golden_files)}) not eq in {name}"
)
if not os.path.exists(byre_file):
raise RuntimeError(f"byre file{byre_file} not found")
result += run_and_check_mlir(target, name, input_files,
golden_files, byre_file)
return result


def main():
args = parse_args()

results = run(args.testdir)

failed = report_results(results)
sys.exit(1 if failed else 0)


if __name__ == "__main__":
main()
45 changes: 45 additions & 0 deletions tests/compatibilty_test/reporting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import NamedTuple, Optional, List


class TestResult(NamedTuple):
unique_name: str
numerical_error: Optional[str]


def report_results(results: List[TestResult]):
fail_case = []
pass_case = []
for result in results:
if result.numerical_error is not None:
fail_case.append([
result.unique_name, "numerical failed: " + result.unique_name +
"\n" + result.numerical_error
])
else:
pass_case.append(result)
pass_case.sort(key=lambda x: x.unique_name)
fail_case.sort(key=lambda x: x[0])

print(f"\n****** PASS tests - {len(pass_case)} tests")
for test in pass_case:
print(test.unique_name, " --- PASS")
for test in fail_case:
print(test[1])
print(f"\n****** FAILED tests - {len(fail_case)} tests")
for test in fail_case:
print(test[0])
return len(fail_case) > 0

0 comments on commit 754b3cb

Please sign in to comment.