[Bufferization] Enable OneShot #10

Open
wants to merge 8 commits into base: cpu-proto
Add debug timer and some quality of life options
kurapov-peter authored and Devjiu committed Dec 1, 2023
commit cc4aa384f39bff621cad24fe1fae52dc1ac3cd9e
16 changes: 15 additions & 1 deletion projects/pt1/e2e_testing/main.py
@@ -61,6 +61,10 @@ def _get_argparse():
parser.add_argument("-f", "--filter", default=".*", help="""
Regular expression specifying which tests to include in this run.
""")
parser.add_argument("-l", "--list_tests",
default=False,
action="store_true",
help="List all available tests and exit.")
parser.add_argument("-v", "--verbose",
default=False,
action="store_true",
@@ -97,11 +101,15 @@ def _get_argparse():
default=False,
action="store_true",
help="Enable linalg ops replacement with runtime library kernel calls.")
parser.add_argument("--enable-timer",
default=False,
action="store_true",
help="Enable debug timings collection.")
return parser

def main():
args = _get_argparse().parse_args()
opts = TestOptions(dumps=args.dump, use_kernels=args.use_kernels)
opts = TestOptions(dumps=args.dump, use_kernels=args.use_kernels, debug_timer=args.enable_timer)

all_test_unique_names = set(
test.unique_name for test in GLOBAL_TEST_REGISTRY)
@@ -142,6 +150,12 @@ def main():

do_not_attempt = set(args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []).union(crashing_set)
available_tests = [test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt]

if args.list_tests is True:
for test in available_tests:
print(test.unique_name)
sys.exit(0)

if args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed is not None:
for arg in args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed:
if arg not in all_test_unique_names:
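For reference, a minimal standalone sketch (not part of the diff) of how the two new store_true flags behave once parsed; argparse maps --enable-timer to the attribute enable_timer used in main() above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-l", "--list_tests", default=False, action="store_true",
                    help="List all available tests and exit.")
parser.add_argument("--enable-timer", default=False, action="store_true",
                    help="Enable debug timings collection.")

args = parser.parse_args(["--list_tests", "--enable-timer"])
print(args.list_tests)    # True
print(args.enable_timer)  # True; the dash becomes an underscore in the attribute name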
100 changes: 49 additions & 51 deletions projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
@@ -29,7 +29,7 @@
recursively_convert_to_numpy,
recursively_convert_from_numpy,
)
from torch_mlir_e2e_test.framework import TestConfig, TestOptions, Trace, TraceItem
from torch_mlir_e2e_test.framework import TestConfig, TestOptions, Trace, TraceItem, DebugTimer

DUMPS_ENABLED = True

@@ -150,54 +150,52 @@ def compile(self, program: torch.nn.Module) -> torch.nn.Module:

def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = []
for item in trace:
module = jit(artifact,
item.inputs,
item.symbol,
self.opts,
output_type="linalg-on-tensors")

if self.opts.is_dump_enabled("linalg-mlir"):
with open(f"{artifact._get_name()}.{item.symbol}-linalg.mlir", "w") as f:
print(module, file=f)

ir_file = f"{artifact._get_name()}.{item.symbol}-linalg-to-llvm.txt" if self.opts.is_dump_enabled(
"linalg-mlir-lowering") else None
module = self.backend.compile(module, ir_file)

if self.opts.is_dump_enabled("llvm-mlir"):
with open(f"{artifact._get_name()}.{item.symbol}-llvm.mlir", "w") as f:
print(module, file=f)

#_dump_repr_to_file(module, 'linalg.mlir')
#module = self.backend.compile(module)
#_dump_repr_to_file(module, 'llvm.mlir')

backend_module = self.backend.load(module)
params = {
**dict(artifact.named_parameters(remove_duplicate=False)),
**dict(artifact.named_buffers(remove_duplicate=False)),
}
params_flat, params_spec = pytree.tree_flatten(params)
params_flat = list(params_flat)
with torch.no_grad():
numpy_inputs = recursively_convert_to_numpy(params_flat +
item.inputs)
outputs = getattr(backend_module,
artifact.__class__.__name__)(*numpy_inputs)

if DUMPS_ENABLED:
print("Dumping binary module to object file...")
backend_module.ee.dump_to_object_file(f"{item.symbol}.o")
print("Done")

output = refine_result_type(outputs)

if self.opts.is_dump_enabled("obj"):
backend_module.ee.dump_to_object_file(f"{artifact._get_name()}.{item.symbol}.o")

result.append(
TraceItem(symbol=item.symbol,
inputs=item.inputs,
output=output))
timing_logger = print if self.opts.is_debug_timer_enabled() else None
with DebugTimer("TorchDynamoTestConfig.run()", logger=timing_logger):
for item in trace:
with DebugTimer("JIT", logger=timing_logger):
module = jit(artifact,
item.inputs,
item.symbol,
self.opts,
output_type="linalg-on-tensors")

if self.opts.is_dump_enabled("linalg-mlir"):
with open(f"{artifact._get_name()}.{item.symbol}-linalg.mlir", "w") as f:
print(module, file=f)

ir_file = f"{artifact._get_name()}.{item.symbol}-linalg-to-llvm.txt" if self.opts.is_dump_enabled(
"linalg-mlir-lowering") else None
with DebugTimer("Backend.compile()", logger=timing_logger):
module = self.backend.compile(module, ir_file)

if self.opts.is_dump_enabled("llvm-mlir"):
with open(f"{artifact._get_name()}.{item.symbol}-llvm.mlir", "w") as f:
print(module, file=f)

with DebugTimer("Backend.load()", logger=timing_logger):
backend_module = self.backend.load(module)
params = {
**dict(artifact.named_parameters(remove_duplicate=False)),
**dict(artifact.named_buffers(remove_duplicate=False)),
}
params_flat, params_spec = pytree.tree_flatten(params)
params_flat = list(params_flat)
with torch.no_grad():
with DebugTimer("recursively_convert_to_numpy", logger=timing_logger):
numpy_inputs = recursively_convert_to_numpy(params_flat +
item.inputs)
outputs = getattr(backend_module,
artifact.__class__.__name__)(*numpy_inputs)

with DebugTimer("refine_result_type", logger=timing_logger):
output = refine_result_type(outputs)

if self.opts.is_dump_enabled("obj"):
backend_module.ee.dump_to_object_file(f"{artifact._get_name()}.{item.symbol}.o")

result.append(
TraceItem(symbol=item.symbol,
inputs=item.inputs,
output=output))
return result
67 changes: 66 additions & 1 deletion projects/pt1/python/torch_mlir_e2e_test/framework.py
@@ -27,6 +27,9 @@
import sys
import traceback

import time
import functools

import torch
import multiprocess as mp

@@ -95,18 +98,80 @@ def clone_trace(trace: Trace) -> Trace:
# this type.
CompiledArtifact = TypeVar('CompiledArtifact')

class DebugTimer:
"""Basic debug timer
Usage examples:
1.
t = DebugTimer('MyName', logger=print)
t.start()
doStuff(...)
t.stop()

2.
@DebugTimer('run')
def run(...):
doStuff(...)

3.
with DebugTimer('withSmth'):
doStuff(...)
"""
def __init__(self, name=None, logger=print) -> None:
self.begin = None
self.elapsed = None
self.logger = logger
self.name = name

def start(self) -> None:
if self.begin is not None:
raise RuntimeError("Attempt to start a running timer.")
self.begin = time.perf_counter_ns()

def stop(self):
if self.begin is None:
raise RuntimeError("Attempt to stop a non-running timer.")
self.elapsed = time.perf_counter_ns() - self.begin
self.begin = None

self._report()
return self.elapsed

def _report(self):
if self.logger:
rep_line = "elapsed " + "{:.4f}".format(float(self.elapsed) / 10e6) + " ms"
rep_line = self.name + ": " + rep_line if self.name is not None else rep_line
self.logger(rep_line)

def __enter__(self):
self.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()

def __call__(self, func):
@functools.wraps(func)
def wrapper_debug_timer(*args, **kwargs):
with self:
return func(*args, **kwargs)
return wrapper_debug_timer

class TestOptions:
"""Test run options."""

dump_choices = ["all", "fx-graph", "torch-mlir", "linalg-mlir", "llvm-mlir", "torch-mlir-lowering", "linalg-mlir-lowering", "obj"]

def __init__(self, *, dumps: List[str] = [], use_kernels=False):
def __init__(self, *, dumps: List[str] = [], use_kernels=False, debug_timer=False):
self.dumps = {opt for opt in dumps}
self.use_kernels = use_kernels
self.debug_timer = debug_timer

def is_dump_enabled(self, dump: str):
return dump in self.dumps or "all" in self.dumps

def is_debug_timer_enabled(self):
return self.debug_timer

class TestConfig(abc.ABC):
"""The interface implemented by backends to run tests.

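A minimal standalone sketch (assuming the framework.py changes above are installed) exercising the three DebugTimer usage styles from its docstring; time.sleep stands in for real work:

import time

from torch_mlir_e2e_test.framework import DebugTimer

# 1. Explicit start/stop
t = DebugTimer("manual", logger=print)
t.start()
time.sleep(0.01)
t.stop()          # prints "manual: elapsed ... ms"

# 2. As a decorator (DebugTimer.__call__ wraps the function in a timed scope)
@DebugTimer("decorated")
def work():
    time.sleep(0.01)

work()            # prints "decorated: elapsed ... ms"

# 3. As a context manager
with DebugTimer("scoped"):
    time.sleep(0.01)   # timing is reported when the block exits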
Changes in the cpu-proto linalg-on-tensors backend (file path not shown)
@@ -10,7 +10,7 @@
from torch_mlir.execution_engine import *
from torch_mlir.runtime import *
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from torch_mlir_e2e_test.framework import TestOptions
from torch_mlir_e2e_test.framework import TestOptions, DebugTimer

from .abc import LinalgOnTensorsBackend
from .refbackend import RefBackendInvoker
@@ -118,14 +118,15 @@ def compile(self, imported_module: Module, ir_file: str = None):
An opaque, backend specific compiled artifact object that can be
passed to `load`.
"""

run_pipeline_with_repro_report(
imported_module, _build_lowering_pipeline(self._opts),
"Lowering Linalg-on-Tensors IR to LLVM with RefBackend", ir_file)
with DebugTimer('CpuProtoLinalgOnTensorsBackend.compile()', logger=print if self._opts.debug_timer else None):
run_pipeline_with_repro_report(
imported_module, _build_lowering_pipeline(self._opts),
"Lowering Linalg-on-Tensors IR to LLVM with RefBackend", ir_file)
return imported_module

def load(self, module) -> RefBackendInvoker:
"""Loads a compiled artifact into the runtime."""

return RefBackendInvoker(module,
shared_libs=_collect_shared_libs(self._opts))
with DebugTimer('CpuProtoLinalgOnTensorsBackend.load()', logger=print if self._opts.debug_timer else None):
invoker = RefBackendInvoker(module,
shared_libs=_collect_shared_libs(self._opts))
return invoker
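The logger gating used in compile() and load() above follows a simple idiom: the timer is always constructed, but it only reports when given a callable logger. A small sketch, assuming the TestOptions and DebugTimer added in framework.py:

from torch_mlir_e2e_test.framework import DebugTimer, TestOptions

opts = TestOptions(debug_timer=False)

# With debug_timer=False the logger is None: the elapsed time is still measured,
# but nothing is printed. Flipping debug_timer to True makes the same line report.
with DebugTimer("compile", logger=print if opts.is_debug_timer_enabled() else None):
    pass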
Changes in the RefBackend invoker module (file path not shown)
@@ -15,6 +15,8 @@

from .abc import LinalgOnTensorsBackend

from torch_mlir_e2e_test.framework import DebugTimer

__all__ = [
"RefBackendLinalgOnTensorsBackend",
]
@@ -80,9 +82,10 @@ def get_ctype_func(func_name):

class RefBackendInvoker:

def __init__(self, module, shared_libs=None):
def __init__(self, module, shared_libs=None, logger=None):
self.ee = ExecutionEngine(module, shared_libs=shared_libs)
self.result = None
self.logger = logger

return_funcs = get_return_funcs(module)

@@ -105,14 +108,15 @@ def consume_return_funcs(*args):
def __getattr__(self, function_name: str):

def invoke(*args):
ffi_args = []
for arg in args:
assert_arg_type_is_supported(arg.dtype)
ffi_args.append(
ctypes.pointer(
ctypes.pointer(get_unranked_memref_descriptor(arg))))

self.ee.invoke(function_name, *ffi_args)
with DebugTimer('refbackend.invoke() args conversion', logger=self.logger):
ffi_args = []
for arg in args:
assert_arg_type_is_supported(arg.dtype)
ffi_args.append(
ctypes.pointer(
ctypes.pointer(get_unranked_memref_descriptor(arg))))
with DebugTimer('ExecutionEngine.invoke()', logger=self.logger):
self.ee.invoke(function_name, *ffi_args)
result = self.result
assert result is not None, "Invocation didn't produce a result"
self.result = None
@@ -203,4 +207,4 @@ def compile(self, imported_module: Module, ir_file: str = None):

def load(self, module) -> RefBackendInvoker:
"""Loads a compiled artifact into the runtime."""
return RefBackendInvoker(module)
return RefBackendInvoker(module, shared_libs=[])
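As a usage note, a hypothetical sketch of wiring a logger into the invoker so the DebugTimer blocks inside invoke() report timings; compiled_module, forward, and numpy_inputs are placeholders for an artifact and entry point produced by an earlier compile/load flow, not part of this diff:

# With logger=None (the default) both DebugTimer blocks inside invoke() stay silent.
invoker = RefBackendInvoker(compiled_module, shared_libs=[], logger=print)
outputs = invoker.forward(*numpy_inputs)
# refbackend.invoke() args conversion: elapsed ... ms
# ExecutionEngine.invoke(): elapsed ... ms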