Skip to content

Commit

Permalink
[e2e] refactor generating of compatibility tests
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu committed Jul 1, 2024
1 parent 7366077 commit 1daeb2f
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 113 deletions.
85 changes: 8 additions & 77 deletions tests/numerical_test/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# limitations under the License.
# ==============================================================================

from torch_e2e_testing.framework import generate_golden_trace
from reporting import TestResult

import brt
Expand All @@ -24,7 +23,6 @@
import torch
import numpy as np
import os
import shutil
import traceback
import time
from typing import List
Expand Down Expand Up @@ -55,7 +53,11 @@ def entry_func(self):

@property
def entry_func_name(self) -> str:
return self.entry_func.name.value
return self.entry_func.name.value

def need_special_inputs(self) -> bool:
key = self.target + "@" + self.file_base_name
return key in MLIR_TEST_SPECIAL_INPUTS

def generate_np_inputs(self) -> List[np.ndarray]:
key = self.target + "@" + self.file_base_name
Expand Down Expand Up @@ -98,78 +100,6 @@ def generate_torch_outputs(self, device="cpu") -> List[torch.Tensor]:
return outputs


def gen_golden_mlir(mhlo_file, target, **kwargs):
"""
Arguements:
@param mhlo_file: Source stablehlo/mhlo file.
@param target: Target name like `cpu`,`cuda`
@param num: Numbers of generated golden in/output, default to 5.
@param mode: The data distribution of inputs.
@param low/hing: The range of generated inputs data.
"""

def save_np_data(fpath: str, data):
np.save(fpath, data)

try:
data_generator = MLIRDataGenerator(mhlo_file, target)
func_name = data_generator.entry_func_name
unique_name = os.path.basename(mhlo_file).split(".")[0]
unique_name = unique_name + "." + target
iter_number = kwargs["num"] if "num" in kwargs else 5

WORK_FOLDER = kwargs["golden_dir"] if "golden_dir" in kwargs else "./local_test"
WORK_FOLDER = WORK_FOLDER + f"/{unique_name}"
os.makedirs(WORK_FOLDER, exist_ok=True)

for idx in range(0, iter_number):
np_inputs = data_generator.generate_np_inputs()

# run golden
from mhlo_tools.ir_executor import Interpreter
interp = Interpreter.load_from_file(mhlo_file, is_stablehlo=True)
golden_outputs = interp.call_function(func_name, np_inputs)

# dump to local file
save_np_data(WORK_FOLDER + f"/input_{str(idx)}.npy", np_inputs)
save_np_data(WORK_FOLDER + f"/output_{str(idx)}.npy", golden_outputs)

del np_inputs, golden_outputs

# byteir compile
output_mlir_file_name = f"{WORK_FOLDER}/{unique_name}.rt.mlir"
byteir.compile(
mhlo_file, output_mlir_file_name, entry_func=func_name, target=target
)

# cp orininal mlir file
shutil.copy(
mhlo_file,
f"{WORK_FOLDER}/{os.path.basename(mhlo_file).split('.')[0]}.stablehlo.mlir",
)

except Exception as e:
return TestResult(
unique_name=unique_name,
compilation_error="".join(
traceback.format_exception(type(e), e, e.__traceback__)
),
runtime_error=None,
numerical_error=None,
performance_result=None,
)

res = TestResult(
unique_name=unique_name,
compilation_error=None,
runtime_error=None,
numerical_error=None,
performance_result=None,
)

return res


class BRTBackend:
def __init__(self, device, brt_file_path):
from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete
Expand Down Expand Up @@ -226,7 +156,7 @@ def profile(self, inputs, outputs, warmup_trials=10, run_trials=50):
return ((end - start) * 1000) / run_trials


def compile_and_run_mlir(mhlo_file, target, verbose, mode="numerical", workdir="./local_test", unique_name=None, **kwargs):
def compile_and_run_mlir(mhlo_file, target, workdir, verbose, mode="numerical", unique_name=None, **kwargs):
try:
data_generator = MLIRDataGenerator(mhlo_file, target)
entry_func_name = data_generator.entry_func_name
Expand Down Expand Up @@ -320,7 +250,8 @@ def compile_and_run_mlir(mhlo_file, target, verbose, mode="numerical", workdir="
)


def compile_and_run_torch(test, target, verbose, mode="numerical", workdir="./local_test"):
def compile_and_run_torch(test, target, workdir, verbose, mode="numerical"):
from torch_e2e_testing.framework import generate_golden_trace
import torch_frontend

cur_device = get_target_device(target)
Expand Down
119 changes: 91 additions & 28 deletions tests/numerical_test/gen_golden.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@
import argparse
import os
import re
from execute import gen_golden_mlir
from utils import report_results
import sys
import shutil
import traceback

import numpy as np
from execute import MLIRDataGenerator
from reporting import TestResult, report_results
import byteir

parser = argparse.ArgumentParser()
parser.add_argument("--target",
Expand All @@ -28,21 +33,93 @@
help="target device name")
parser.add_argument("-g",
"--golden",
default="/tmp/mlir_cpu_golden",
type=str,
default="./local_golden",
help="mlir test golden path")
args = parser.parse_args()

EXCLUDE_MLIR_TESTS = []

# Unsupported ops
EXCLUDE_MLIR_CPU_TESTS = [
"custom_call_tf_UpperBound.mlir",
"rng.mlir",
]

MLIR_CPU_SPECIAL_INPUTS = {
"log_plus_one.mlir": ["uniform", 0.5, 1.0],
}
def gen_golden_mlir(mhlo_file, target, golden_dir, num=5):
"""
Arguements:
@param mhlo_file: Source stablehlo/mhlo file.
@param target: Target name like `cpu`,`cuda`
@param num: Numbers of generated golden in/output, default to 5.
"""

def save_np_data(fpath: str, data):
np.save(fpath, data)

file_base_name = os.path.basename(mhlo_file).split(".")[0]
unique_name = file_base_name + "." + target
try:
data_generator = MLIRDataGenerator(mhlo_file, target)
func_name = data_generator.entry_func_name

os.makedirs(golden_dir, exist_ok=True)
WORK_FOLDER = golden_dir + f"/{unique_name}"
os.makedirs(WORK_FOLDER, exist_ok=True)

# if need special inputs, only iterate 1 time
if data_generator.need_special_inputs():
num = 1

for idx in range(0, num):
np_inputs = data_generator.generate_np_inputs()

# run golden
from mhlo_tools.ir_executor import Interpreter
interp = Interpreter.load_from_file(mhlo_file, is_stablehlo=True)
golden_outputs = interp.call_function(func_name, np_inputs)

# dump to local file
save_np_data(WORK_FOLDER + f"/inputs.{str(idx)}.npy", np_inputs)
save_np_data(WORK_FOLDER + f"/outputs.{str(idx)}.npy", golden_outputs)

del np_inputs, golden_outputs

# byteir compile
output_mlir_file_name = f"{WORK_FOLDER}/{unique_name}.rt.mlirbc"
byteir.compile(
mhlo_file, output_mlir_file_name, entry_func=func_name, target=target
)
# cp orininal mlir file
shutil.copy(
mhlo_file,
f"{WORK_FOLDER}/{file_base_name}.stablehlo.mlir",
)
# serialize to stablehlo bytecode
from byteir._mlir_libs._stablehlo import serialize_portable_artifact, get_current_version
bytes = serialize_portable_artifact(data_generator.module.operation.get_asm(), get_current_version())
with open(f"{WORK_FOLDER}/{file_base_name}.stablehlo.mlirbc", "wb") as f:
f.write(bytes)

except Exception as e:
return TestResult(
unique_name=unique_name,
compilation_error="".join(
traceback.format_exception(type(e), e, e.__traceback__)
),
runtime_error=None,
numerical_error=None,
performance_result=None,
)

res = TestResult(
unique_name=unique_name,
compilation_error=None,
runtime_error=None,
numerical_error=None,
performance_result=None,
)

return res



def gen_mlir_cpu_golden():
Expand All @@ -58,31 +135,17 @@ def gen_mlir_cpu_golden():
f = os.path.join(directory, filename)
# checking if it is a file
if os.path.isfile(f) and filename not in EXCLUDE_MLIR_CPU_TESTS:
mlir_tests.append([
f, MLIR_CPU_SPECIAL_INPUTS[filename]
if filename in MLIR_CPU_SPECIAL_INPUTS else None
])
mlir_tests.append(f)

results = []
for test in mlir_tests:
fpath = test[0]
fpath = test
cur_golden_dir = args.golden
if test[1] is None:
res = gen_golden_mlir(fpath,
cpu_target,
golden_dir=cur_golden_dir,
num=5)
else:
res = gen_golden_mlir(fpath,
cpu_target,
golden_dir=cur_golden_dir,
num=5,
mode=test[1][0],
low=test[1][1],
high=test[1][2])

res = gen_golden_mlir(fpath,
cpu_target,
cur_golden_dir,
num=5)
results.append(res)

return results


Expand Down
20 changes: 13 additions & 7 deletions tests/numerical_test/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def get_local_gpu_arch():
gpu_arch = int(gpu_arch[3:])
return gpu_arch

def run(target, filter, mode="numerical", verbose=False):
def run(target, filter, workdir, mode="numerical", verbose=False):
if target == "dynamo":
from torch_dynamo_e2e_testing.execute import run_torch_dynamo_tests
gpu_arch = get_local_gpu_arch()
Expand All @@ -150,20 +150,20 @@ def run(target, filter, mode="numerical", verbose=False):
if test in GLOBAL_TORCH_TEST_REGISTRY_NAMES:
results.append(
compile_and_run_torch(
GLOBAL_TORCH_TEST_REGISTRY[test], target, verbose, mode
GLOBAL_TORCH_TEST_REGISTRY[test], target, workdir, verbose, mode
)
)
else:
if target == "cpu":
results.append(
compile_and_run_mlir(
os.path.join(CPU_MLIR_TEST_DIR, test), target, verbose, mode
os.path.join(CPU_MLIR_TEST_DIR, test), target, workdir, verbose, mode
)
)
else:
results.append(
compile_and_run_mlir(
os.path.join(CUDA_MLIR_TEST_DIR, test), target, verbose, mode
os.path.join(CUDA_MLIR_TEST_DIR, test), target, workdir, verbose, mode
)
)
return results
Expand Down Expand Up @@ -211,7 +211,13 @@ def parse_args():
"--verbose",
default=False,
action="store_true",
help="report test results with additional detail",
help="Report test results with additional detail",
)
parser.add_argument(
"--workdir",
type=str,
default="./local_test",
help="Work directory to save compiled outputs",
)
args = parser.parse_args()
return args
Expand All @@ -223,9 +229,9 @@ def main():
results = []
if args.target == "all":
for target in ["cpu", "cuda", "cuda_with_ait", "dynamo"]:
results += run(target, args.filter)
results += run(target, args.filter, args.workdir)
else:
results += run(args.target, args.filter, mode=args.mode, verbose=args.verbose)
results += run(args.target, args.filter, args.workdir, mode=args.mode, verbose=args.verbose)

failed = report_results(results)
sys.exit(1 if failed else 0)
Expand Down
2 changes: 1 addition & 1 deletion tests/numerical_test/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()

result = compile_and_run_mlir(args.input_mlir_path, args.target, args.verbose, mode=args.mode, workdir=args.workdir, unique_name=args.name)
result = compile_and_run_mlir(args.input_mlir_path, args.target, args.workdir, args.verbose, mode=args.mode, unique_name=args.name)
failed = report_results([result])
sys.exit(1 if failed else 0)

0 comments on commit 1daeb2f

Please sign in to comment.