[runtime] add BRTBackend, add byteir.compile_from_string #434

Merged · 5 commits · Aug 26, 2024
2 changes: 1 addition & 1 deletion compiler/python/byteir/__init__.py
@@ -13,4 +13,4 @@
# ==============================================================================

from ._mlir_libs._byteir import *
from .compile import compile, DebugType
from .compile import compile, compile_from_string, DebugType
53 changes: 40 additions & 13 deletions compiler/python/byteir/compile.py
@@ -1,7 +1,7 @@
from enum import Enum
from pathlib import Path
import os
from shutil import copymode
from typing import Union

from . import ir
from .passmanager import PassManager
@@ -366,9 +366,8 @@ def _compile_cpu(
if (module.operation.get_asm() != deserialized_module.operation.get_asm()):
raise ValueError("module asm has be changed after byre serialization")


def compile(
input_file_path: str,
def compile_from_string(
input_string_or_bytes: Union[str, bytes],
output_file_path: str,
entry_func: str = "main",
target: str = "cuda",
@@ -391,18 +390,12 @@ def compile(
gpu_arch_num = int(gpu_arch[3:])
if enable_tf32:
assert gpu_arch_num >= 80, "1xtf32 only support on gpu >= sm_80"
print(f"Compiling {os.path.basename(input_file_path)} to {gpu_arch}")
print(f"[ByteIR] Compiling to {gpu_arch}")
elif _device == "cpu":
print(f"Compiling {os.path.basename(input_file_path)} to {cpu_arch}")
print(f"[ByteIR] Compiling to {cpu_arch}")

### load from .mlir or .mlirbc
from byteir._mlir_libs._stablehlo import deserialize_portable_artifact
context = ir.Context()
if input_file_path.endswith(".mlirbc"):
module_bytes = deserialize_portable_artifact(open(input_file_path, "rb").read())
module = ir.Module.parse(module_bytes, context)
else:
module = ir.Module.parse(open(input_file_path, "r").read(), context)
module = ir.Module.parse(input_string_or_bytes, context)
_print_verbose(module, "// IR Dump Input MLIR:") if verbose else ...

### legalize stablehlo to mhlo
@@ -444,3 +437,37 @@ def compile
_compile_fn(compile_options)
else:
raise NotImplementedError("not implemented target: {}".format(target))

def compile(
input_file_path: str,
output_file_path: str,
entry_func: str = "main",
target: str = "cuda",
gpu_arch: str = "local",
cpu_arch: str = "x86_64",
byre_serial_version: str = "1.0.0",
verbose: bool = False,
enable_tf32: bool = False,
parallelism: int = 1,
disable_byteir_ait_cache: bool = False,
**kwargs,
) -> None:
### load from .mlir or .mlirbc
from byteir._mlir_libs._stablehlo import deserialize_portable_artifact
if input_file_path.endswith(".mlirbc"):
module_bytes = deserialize_portable_artifact(open(input_file_path, "rb").read())
else:
module_bytes = open(input_file_path, "r").read()

compile_from_string(module_bytes,
output_file_path=output_file_path,
entry_func=entry_func,
target=target,
gpu_arch=gpu_arch,
cpu_arch=cpu_arch,
byre_serial_version=byre_serial_version,
verbose=verbose,
enable_tf32=enable_tf32,
parallelism=parallelism,
disable_byteir_ait_cache=disable_byteir_ait_cache,
**kwargs)
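
For illustration, the refactor splits the API so that compile() stays the file-based entry point while compile_from_string() accepts an in-memory module. A minimal, hypothetical sketch (the module text, output paths, and a built CPU target are assumptions, not part of this PR):

import byteir

# Hypothetical StableHLO module text, e.g. produced by a frontend exporter.
mlir_text = """
func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  %0 = stablehlo.add %arg0, %arg1 : tensor<4xf32>
  return %0 : tensor<4xf32>
}
"""

# New: parse and compile directly from a string (or serialized bytes).
byteir.compile_from_string(mlir_text, "add.rt.mlir", entry_func="main", target="cpu")

# Unchanged call site: compile() still reads .mlir/.mlirbc from disk and now
# forwards to compile_from_string under the hood.
byteir.compile("add.stablehlo.mlir", "add.rt.mlir", entry_func="main", target="cpu")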
144 changes: 144 additions & 0 deletions runtime/python/brt/backend.py
@@ -0,0 +1,144 @@
import brt
from brt.utils import brt_dtype_to_torch_dtype
import torch

import time

# BRTBackend for static shape and single device
class BRTBackend:
def __init__(self, byre_file_path, device):
assert device == "cuda" or device == "cpu"
if device == "cuda":
from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete
_allocator_alloc = caching_allocator_alloc
_allocator_delete = caching_allocator_delete
_stream = torch.cuda.current_stream()._as_parameter_.value
else:
_allocator_alloc = None
_allocator_delete = None
_stream = None
self.session = brt.Session(
device=device.upper(),
alloc_func=_allocator_alloc,
free_func=_allocator_delete,
)
self.session.load(byre_file_path)
self.req = self.session.new_request_context(_stream)
self.device = device

# for static shape model, just cache shape and dtype info
self.input_arg_offsets = self.session.get_input_arg_offsets()
self.input_shapes = []
self.input_dtypes = []
for offset in self.input_arg_offsets:
self.input_shapes.append(self.session.get_static_shape(offset))
self.input_dtypes.append(brt_dtype_to_torch_dtype(self.session.get_data_type(offset)))
self.output_arg_offsets = self.session.get_output_arg_offsets()
self.output_shapes = []
self.output_dtypes = []
for offset in self.output_arg_offsets:
self.output_shapes.append(self.session.get_static_shape(offset))
self.output_dtypes.append(brt_dtype_to_torch_dtype(self.session.get_data_type(offset)))

def _check_shape_dtype(self, tensors, shapes, dtypes):
assert len(tensors) == len(shapes)
assert len(tensors) == len(dtypes)
for tensor, shape, dtype in zip(tensors, shapes, dtypes):
assert list(shape) == list(tensor.shape)
assert dtype == tensor.dtype

def _bind_inputs(self, inputs):
inputOffsetAndData = []
for offset, input in zip(self.input_arg_offsets, inputs):
inputOffsetAndData.append((offset, input.data_ptr()))
self.req.bind_args(inputOffsetAndData)

def _bind_outputs(self, outputs):
outputOffsetAndData = []
for offset, output in zip(self.output_arg_offsets, outputs):
outputOffsetAndData.append((offset, output.data_ptr()))
self.req.bind_args(outputOffsetAndData)

def run(self, inputs, check=True):
if check:
self._check_shape_dtype(inputs, self.input_shapes, self.input_dtypes)

# alloc outputs
outputs = []
for shape, dtype in zip(self.output_shapes, self.output_dtypes):
outputs.append(torch.empty(shape, dtype=dtype, device=self.device))

self._bind_inputs(inputs)
self._bind_outputs(outputs)

# run
self.req.finish_io_binding()
self.req.run()
self.req.sync()

return outputs

def profile(self, inputs, check=True, warmup_trials=10, run_trials=50):
if check:
self._check_shape_dtype(inputs, self.input_shapes, self.input_dtypes)

# alloc outputs
outputs = []
for shape, dtype in zip(self.output_shapes, self.output_dtypes):
outputs.append(torch.empty(shape, dtype=dtype, device=self.device))

self._bind_inputs(inputs)
self._bind_outputs(outputs)
self.req.finish_io_binding()

# warmup
for _ in range(warmup_trials):
self.req.run()
self.req.sync()

start = time.time()
for _ in range(run_trials):
self.req.run()
self.req.sync()
end = time.time()
avg = ((end - start) * 1000) / run_trials

return outputs, avg

def run_with_outputs(self, inputs, outputs, check=True):
if check:
self._check_shape_dtype(inputs, self.input_shapes, self.input_dtypes)
self._check_shape_dtype(outputs, self.output_shapes, self.output_dtypes)

self._bind_inputs(inputs)
self._bind_outputs(outputs)

self.req.finish_io_binding()
self.req.run()
self.req.sync()

def profile_with_outputs(self, inputs, outputs, check=True, warmup_trials=10, run_trials=50):
if check:
self._check_shape_dtype(inputs, self.input_shapes, self.input_dtypes)
self._check_shape_dtype(outputs, self.output_shapes, self.output_dtypes)

self._bind_inputs(inputs)
self._bind_outputs(outputs)
self.req.finish_io_binding()

# warmup
for _ in range(warmup_trials):
self.req.run()
self.req.sync()

start = time.time()
for _ in range(run_trials):
self.req.run()
self.req.sync()
end = time.time()
avg = ((end - start) * 1000) / run_trials

return avg


# TODO: add BRTDynamicShapeBackend and BRTNCCLBackend
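
For context, the new backend might be driven roughly as follows (a hedged sketch; the .rt.mlir path is hypothetical and a CUDA-enabled build of brt and torch is assumed):

import torch
from brt.backend import BRTBackend

# Byre file path first, then device ("cuda" or "cpu").
backend = BRTBackend("model.rt.mlir", "cuda")

# Inputs must match the cached static shapes/dtypes; run() allocates the outputs.
inputs = [torch.ones(shape, dtype=dtype, device="cuda")
          for shape, dtype in zip(backend.input_shapes, backend.input_dtypes)]
outputs = backend.run(inputs)

# profile() reuses the same bindings and returns (outputs, average latency in ms).
outputs, avg_ms = backend.profile(inputs, warmup_trials=10, run_trials=50)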
69 changes: 7 additions & 62 deletions tests/numerical_test/execute.py
@@ -15,6 +15,7 @@
from reporting import TestResult

import brt
from brt.backend import BRTBackend
import byteir
from byteir import ir
from byteir._backend_registry import get_target_device
@@ -100,62 +101,6 @@ def generate_torch_outputs(self, device="cpu") -> List[torch.Tensor]:
return outputs


class BRTBackend:
def __init__(self, device, brt_file_path):
from torch.cuda.memory import caching_allocator_alloc, caching_allocator_delete

_allocator_alloc = caching_allocator_alloc if device == "cuda" else None
_allocator_delete = caching_allocator_delete if device == "cuda" else None
_stream = (
torch.cuda.current_stream()._as_parameter_.value
if device == "cuda"
else None
)
self.session = brt.Session(
device=device.upper(),
alloc_func=_allocator_alloc,
free_func=_allocator_delete,
)
self.session.load(brt_file_path)
self.req = self.session.new_request_context(_stream)

def execute(self, inputs, outputs):
# TODO(lyq): how to support dynamic shape?
assert len(self.session.get_input_arg_offsets()) == len(inputs)
assert len(self.session.get_output_arg_offsets()) == len(outputs)
for offset, arg in zip(self.session.get_input_arg_offsets(), inputs):
assert list(self.session.get_static_shape(offset)) == list(arg.shape)
self.req.bind_arg(offset, arg.data_ptr())
for offset, ret in zip(self.session.get_output_arg_offsets(), outputs):
assert list(self.session.get_static_shape(offset)) == list(ret.shape)
self.req.bind_arg(offset, ret.data_ptr())
self.req.finish_io_binding()
self.req.run()
self.req.sync()

def profile(self, inputs, outputs, warmup_trials=10, run_trials=50):
assert len(self.session.get_input_arg_offsets()) == len(inputs)
assert len(self.session.get_output_arg_offsets()) == len(outputs)
for offset, arg in zip(self.session.get_input_arg_offsets(), inputs):
assert list(self.session.get_static_shape(offset)) == list(arg.shape)
self.req.bind_arg(offset, arg.data_ptr())
for offset, ret in zip(self.session.get_output_arg_offsets(), outputs):
assert list(self.session.get_static_shape(offset)) == list(ret.shape)
self.req.bind_arg(offset, ret.data_ptr())
self.req.finish_io_binding()

for _ in range(warmup_trials):
self.req.run()
self.req.sync()

start = time.time()
for _ in range(run_trials):
self.req.run()
self.req.sync()
end = time.time()
return ((end - start) * 1000) / run_trials


def compile_and_run_mlir(mhlo_file, target, workdir, verbose, mode="numerical", unique_name=None, **kwargs):
if unique_name is None:
unique_name = os.path.basename(mhlo_file).split(".")[0] + "." + target
@@ -195,7 +140,7 @@ def compile_and_run_mlir(mhlo_file, target, workdir, verbose, mode="numerical",
# brt runtime
try:
cur_device = get_target_device(target)
brt_backend = BRTBackend(cur_device, output_mlir_file_name)
brt_backend = BRTBackend(output_mlir_file_name, cur_device)

torch_inputs = []
for np_input in np_inputs:
@@ -204,9 +149,9 @@ def compile_and_run_mlir(mhlo_file, target, workdir, verbose, mode="numerical",
torch_outputs = data_generator.generate_torch_outputs(cur_device)

if mode == "numerical":
brt_backend.execute(torch_inputs, torch_outputs)
brt_backend.run_with_outputs(torch_inputs, torch_outputs)
else:
avg_time = brt_backend.profile(torch_inputs, torch_outputs)
avg_time = brt_backend.profile_with_outputs(torch_inputs, torch_outputs)
return TestResult(
unique_name=unique_name,
compilation_error=None,
@@ -299,11 +244,11 @@ def compile_and_run_torch(test, target, workdir, verbose, mode="numerical"):

# runtime
try:
brt_backend = BRTBackend(cur_device, output_mlir_file_name)
brt_backend = BRTBackend(output_mlir_file_name, cur_device)
if mode == "numerical":
brt_backend.execute(torch_inputs, torch_outputs)
brt_backend.run_with_outputs(torch_inputs, torch_outputs)
else:
avg_time = brt_backend.profile(torch_inputs, torch_outputs)
avg_time = brt_backend.profile_with_outputs(torch_inputs, torch_outputs)
return TestResult(
unique_name=unique_name,
compilation_error=None,
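
Put together, the refactored test driver exercises the shared backend with pre-allocated outputs, roughly like this (a condensed sketch; torch_inputs, torch_outputs, output_mlir_file_name, and cur_device stand in for the variables already used in execute.py):

from brt.backend import BRTBackend

backend = BRTBackend(output_mlir_file_name, cur_device)

# Numerical mode: write results into pre-allocated tensors (generate_torch_outputs).
backend.run_with_outputs(torch_inputs, torch_outputs)

# Profiling mode: same bindings, returns the mean latency in milliseconds.
avg_time = backend.profile_with_outputs(torch_inputs, torch_outputs)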