[Bufferization] Enable OneShot #10 (Open)

Wants to merge 8 commits into base: cpu-proto
Add --dump option for e2e tests. (#3)
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
Co-authored-by: Laurent Montigny <lmontigny@users.noreply.github.com>
ienkovich and lmontigny authored Nov 13, 2023
commit 57c01cd3cc82fe12d982ecb07adf4768e4c51691
20 changes: 18 additions & 2 deletions projects/pt1/e2e_testing/main.py
@@ -7,7 +7,7 @@
import re
import sys

from torch_mlir_e2e_test.framework import run_tests
from torch_mlir_e2e_test.framework import run_tests, TestOptions
from torch_mlir_e2e_test.reporting import report_results
from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY

@@ -77,10 +77,26 @@ def _get_argparse():
default=False,
action="store_true",
help="return exit code 0 even if the test fails to unblock pipeline")
parser.add_argument("--dump",
choices=TestOptions.dump_choices,
default=[],
action="append",
help=f"""
Available options:
"all": enable all dumps
"fx-graph": dump input FX Graph
"torch-mlir": dump generated Torch MLIR module
"linalg-mlir": dump module lowered to linalg dialect
"llvm-mlir": dump module lowered to LLVM dialect
"torch-mlir-lowering": dump after-pass results in Torch to Linalg pipeline
"linalg-mlir-lowering": dump after-pass results in Linalg to LLVM pipeline
"obj": dump compiled code to object file
""")
return parser

def main():
args = _get_argparse().parse_args()
opts = TestOptions(dumps=args.dump)

all_test_unique_names = set(
test.unique_name for test in GLOBAL_TEST_REGISTRY)
@@ -111,7 +127,7 @@ def main():
xfail_set = LTC_XFAIL_SET
crashing_set = LTC_CRASHING_SET
elif args.config == "torchdynamo":
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend(), opts=opts)
xfail_set = TORCHDYNAMO_XFAIL_SET
crashing_set = TORCHDYNAMO_CRASHING_SET
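
Pulled together, a hedged sketch of what main() now builds for the torchdynamo config when dumps are requested (the dump values are illustrative and the import paths are assumed from the project layout):

    from torch_mlir_e2e_test.framework import TestOptions
    from torch_mlir_e2e_test.configs import TorchDynamoTestConfig
    from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
        RefBackendLinalgOnTensorsBackend,
    )

    # "--dump fx-graph --dump torch-mlir" arrives here via action="append"
    # as args.dump == ["fx-graph", "torch-mlir"].
    opts = TestOptions(dumps=["fx-graph", "torch-mlir"])
    config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend(), opts=opts)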

8 changes: 4 additions & 4 deletions projects/pt1/python/torch_mlir/__init__.py
@@ -268,7 +268,7 @@ def _canon_extra_library(extra_library):
return extra_library_file_name


def _lower_mlir_module(verbose, output_type, module):
def _lower_mlir_module(verbose, output_type, module, ir_dump_file = None):
if verbose:
print("\n====================")
print("Torch Backend IR")
@@ -280,7 +280,7 @@ def _lower_mlir_module(verbose, output_type, module):
if output_type == OutputType.TOSA:
run_pipeline_with_repro_report(
module, "builtin.module(torch-backend-to-tosa-backend-pipeline)",
"Lowering Torch Backend IR -> TOSA Backend IR")
"Lowering Torch Backend IR -> TOSA Backend IR", ir_dump_file)
if verbose:
print("\n====================")
print("TOSA Backend IR")
@@ -291,7 +291,7 @@ def _lower_mlir_module(verbose, output_type, module):
run_pipeline_with_repro_report(
module,
"builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", ir_dump_file)
if verbose:
print("\n====================")
print("LINALG Backend IR")
@@ -302,7 +302,7 @@ def _lower_mlir_module(verbose, output_type, module):
run_pipeline_with_repro_report(
module,
"builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
"Lowering Torch Backend IR -> StableHLO Backend IR")
"Lowering Torch Backend IR -> StableHLO Backend IR", ir_dump_file)
if verbose:
print("\n====================")
print("StableHLO Backend IR")
25 changes: 23 additions & 2 deletions projects/pt1/python/torch_mlir/compiler_utils.py
@@ -25,9 +25,23 @@ def get_module_name_for_debug_dump(module):
class TorchMlirCompilerError(Exception):
pass

class StderrToFile:
def __init__(self, file: str):
self._file_name = file

def __enter__(self):
self._fd = os.open(self._file_name, os.O_WRONLY | os.O_CREAT | os.O_TRUNC)
self._old_stderr_fd = os.dup(2)
os.dup2(self._fd, 2)

def __exit__(self, *args):
os.dup2(self._old_stderr_fd, 2)
os.close(self._old_stderr_fd)
os.close(self._fd)

def run_pipeline_with_repro_report(module,
pipeline: str,
description: str):
description: str,
ir_dump_file: str = None):
"""Runs `pipeline` on `module`, with a nice repro report if it fails."""
module_name = get_module_name_for_debug_dump(module)
try:
@@ -38,7 +52,14 @@ def run_pipeline_with_repro_report(module,
# Lower module in place to make it ready for compiler backends.
with module.context:
pm = PassManager.parse(pipeline)
pm.run(module.operation)
if ir_dump_file is not None:
module.context.enable_multithreading(False)
pm.enable_ir_printing()
with StderrToFile(ir_dump_file):
pm.run(module.operation)
module.context.enable_multithreading(True)
else:
pm.run(module.operation)
except Exception as e:
# TODO: More robust.
# - don't arbitrarily clutter up /tmp. When a test suite has many
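
The redirection has to happen at the file-descriptor level: pm.enable_ir_printing() emits the after-pass IR from MLIR's C++ side on the process stderr (fd 2), which swapping Python's sys.stderr would not capture, and per-pass IR printing requires a single-threaded pass manager, hence the enable_multithreading(False)/(True) bracket. A standalone sketch of the same idea (the function name and file path are illustrative):

    import os

    def capture_fd2(path, fn):
        """Run fn() while process-level stderr (fd 2) is redirected to path."""
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC)
        old = os.dup(2)
        os.dup2(fd, 2)
        try:
            fn()
        finally:
            os.dup2(old, 2)   # restore the original stderr
            os.close(old)
            os.close(fd)

    # Output written directly to fd 2, including from C/C++ code, lands in the file:
    capture_fd2("stderr-dump.txt", lambda: os.write(2, b"captured\n"))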
46 changes: 40 additions & 6 deletions projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
@@ -29,7 +29,7 @@
recursively_convert_to_numpy,
recursively_convert_from_numpy,
)
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
from torch_mlir_e2e_test.framework import TestConfig, TestOptions, Trace, TraceItem

DUMPS_ENABLED = True

@@ -65,6 +65,8 @@ def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool:
def jit(
model: torch.nn.Module,
example_args: _example_args,
symbol: str,
opts: TestOptions,
output_type: Union[str, "OutputType"] = OutputType.TORCH,
backend_legal_ops: Optional[Sequence[str]] = None,
extra_library=None,
@@ -95,9 +97,18 @@ def my_aot_autograd_backend(gm: torch.fx.GraphModule,
# way of differentiating between the two.
assert not _returns_empty_tuple(gm), "encountered graph that does not return anything"

if opts.is_dump_enabled("fx-graph"):
with open(f"{model._get_name()}.{symbol}-fx-graph.txt", "w") as f:
print(gm.graph, file=f)

nonlocal mlir_module
*_, model_name, nth_graph = get_aot_compilation_context()
mlir_module = import_fx_graph_as_func(gm.graph, model_name)

if opts.is_dump_enabled("torch-mlir"):
with open(f"{model._get_name()}.{symbol}-torch.mlir", "w") as f:
print(mlir_module, file=f)

return gm

my_backend = aot_autograd(fw_compiler=my_aot_autograd_backend,
@@ -121,15 +132,18 @@ def my_aot_autograd_backend(gm: torch.fx.GraphModule,
"Lowering TorchFX IR -> Torch Backend IR",
)

return _lower_mlir_module(verbose, output_type, mlir_module)
ir_file = f"{model._get_name()}.{symbol}-torch-to-linanlg.txt" if opts.is_dump_enabled(
"torch-mlir-lowering") else None
return _lower_mlir_module(verbose, output_type, mlir_module, ir_file)


class TorchDynamoTestConfig(TestConfig):
"""TestConfig that runs the torch.nn.Module with TorchDynamo"""

def __init__(self, backend):
def __init__(self, backend, opts=TestOptions()):
super().__init__()
self.backend = backend
self.opts = opts

def compile(self, program: torch.nn.Module) -> torch.nn.Module:
return program
@@ -139,10 +153,26 @@ def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
for item in trace:
module = jit(artifact,
item.inputs,
item.symbol,
self.opts,
output_type="linalg-on-tensors")
_dump_repr_to_file(module, 'linalg.mlir')
module = self.backend.compile(module)
_dump_repr_to_file(module, 'llvm.mlir')

if self.opts.is_dump_enabled("linalg-mlir"):
with open(f"{artifact._get_name()}.{item.symbol}-linalg.mlir", "w") as f:
print(module, file=f)

ir_file = f"{artifact._get_name()}.{item.symbol}-linalg-to-llvm.txt" if self.opts.is_dump_enabled(
"linalg-mlir-lowering") else None
module = self.backend.compile(module, ir_file)

if self.opts.is_dump_enabled("llvm-mlir"):
with open(f"{artifact._get_name()}.{item.symbol}-llvm.mlir", "w") as f:
print(module, file=f)

#_dump_repr_to_file(module, 'linalg.mlir')
#module = self.backend.compile(module)
#_dump_repr_to_file(module, 'llvm.mlir')

backend_module = self.backend.load(module)
params = {
**dict(artifact.named_parameters(remove_duplicate=False)),
@@ -162,6 +192,10 @@ def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
print("Done")

output = refine_result_type(outputs)

if self.opts.is_dump_enabled("obj"):
backend_module.ee.dump_to_object_file(f"{artifact._get_name()}.{item.symbol}.o")

result.append(
TraceItem(symbol=item.symbol,
inputs=item.inputs,
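
To illustrate what the config now writes per test (the module name "AddModule" and symbol "forward" are hypothetical; the patterns follow the f-strings above):

    # Running with --dump all for a test module whose _get_name() is "AddModule"
    # and a trace item with symbol "forward" would produce roughly these files:
    expected_dumps = [
        "AddModule.forward-fx-graph.txt",         # --dump fx-graph
        "AddModule.forward-torch.mlir",           # --dump torch-mlir
        "AddModule.forward-torch-to-linalg.txt",  # --dump torch-mlir-lowering
        "AddModule.forward-linalg.mlir",          # --dump linalg-mlir
        "AddModule.forward-linalg-to-llvm.txt",   # --dump linalg-mlir-lowering
        "AddModule.forward-llvm.mlir",            # --dump llvm-mlir
        "AddModule.forward.o",                    # --dump obj
    ]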
11 changes: 11 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/framework.py
@@ -95,6 +95,17 @@ def clone_trace(trace: Trace) -> Trace:
# this type.
CompiledArtifact = TypeVar('CompiledArtifact')

class TestOptions:
"""Test run options."""

dump_choices = ["all", "fx-graph", "torch-mlir", "linalg-mlir", "llvm-mlir", "torch-mlir-lowering", "linalg-mlir-lowering", "obj"]

def __init__(self, dumps: List[str] = []):
self.dumps = {opt for opt in dumps}

def is_dump_enabled(self, dump: str):
return dump in self.dumps or "all" in self.dumps
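
A short usage sketch of the options object (the dump names below are illustrative):

    opts = TestOptions(dumps=["torch-mlir"])
    assert opts.is_dump_enabled("torch-mlir")       # explicitly requested
    assert not opts.is_dump_enabled("llvm-mlir")    # absent, so disabled
    assert TestOptions(dumps=["all"]).is_dump_enabled("llvm-mlir")  # "all" enables every dump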

class TestConfig(abc.ABC):
"""The interface implemented by backends to run tests.

@@ -182,7 +182,7 @@ class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend):
def __init__(self):
super().__init__()

def compile(self, imported_module: Module):
def compile(self, imported_module: Module, ir_file: str = None):
"""Compiles an imported module, with a flat list of functions.
The module is expected to be in linalg-on-tensors + scalar code form.
TODO: More clearly define the backend contract. Generally this will
@@ -198,7 +198,7 @@ def compile(self, imported_module: Module):

run_pipeline_with_repro_report(
imported_module, LOWERING_PIPELINE,
"Lowering Linalg-on-Tensors IR to LLVM with RefBackend")
"Lowering Linalg-on-Tensors IR to LLVM with RefBackend", ir_file)
return imported_module

def load(self, module) -> RefBackendInvoker:
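
A hedged usage sketch of the widened compile() signature (module construction is omitted and the dump file name is illustrative):

    def compile_with_lowering_dump(backend: RefBackendLinalgOnTensorsBackend,
                                   imported_module: Module,
                                   dump_path: str = "linalg-to-llvm.txt"):
        # A non-None path makes the LOWERING_PIPELINE run write its after-pass
        # IR to dump_path; passing None preserves the previous behaviour.
        compiled = backend.compile(imported_module, dump_path)
        return backend.load(compiled)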