disable dynamo test
qingyunqu committed Nov 22, 2024
1 parent 6a231df commit 7f8ded6
Showing 4 changed files with 26 additions and 10 deletions.
24 changes: 18 additions & 6 deletions frontends/torch-frontend/third_party/patches/fx_importer.patch
@@ -1,5 +1,5 @@
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
-index 99c8d3cf..b64e6caf 100644
+index 4692d049..7e5a3cb9 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -54,6 +54,10 @@ from torch._subclasses import (
@@ -13,7 +13,19 @@ index 99c8d3cf..b64e6caf 100644
from torch.fx import (
Graph,
GraphModule,
-@@ -1614,6 +1618,61 @@ class GraphNodeImporter:
+@@ -1446,6 +1450,11 @@ class GraphNodeImporter:
+ elif isinstance(target, TorchOpOverload):
+     # Dispatch to an ATen op.
+     self._import_torch_op_overload(loc, node)
++ # ref: https://github.com/pytorch/pytorch/blob/main/torch/_ops.py#L1015
++ elif isinstance(target, torch._ops.OpOverloadPacket):
++     # Retrieve the OpOverload from node.meta and dispatch an ATen op.
++     assert "original_aten" in node.meta and node.meta["original_aten"] is not None
++     self._import_torch_op_overload(loc, node, node.meta["original_aten"])
+ elif isinstance(target, HigherOrderOperator):
+     self._import_hop(loc, node, target)
+ else:
+@@ -1615,6 +1624,61 @@ class GraphNodeImporter:
for i, value in enumerate(operation.results):
self.bind_node_value(node, value, i + bind_none)

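The new branch above covers FX graphs (e.g. Inductor post-grad graphs) whose call_function targets are OpOverloadPacket objects rather than concrete overloads; the concrete OpOverload is recovered from the "original_aten" metadata that such graphs record. A minimal standalone sketch of that rule (the helper name is ours, not the importer's):

import torch
from torch.fx import Node


def resolve_overload(node: Node) -> torch._ops.OpOverload:
    """Return a concrete OpOverload for a call_function node."""
    if isinstance(node.target, torch._ops.OpOverload):
        return node.target  # already concrete, e.g. aten.add.Tensor
    if isinstance(node.target, torch._ops.OpOverloadPacket):
        # Inductor-produced graphs stash the pre-decomposition op here.
        overload = node.meta.get("original_aten")
        assert overload is not None, "packet target without original_aten meta"
        return overload
    raise NotImplementedError(f"unhandled target: {node.target!r}")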
@@ -75,7 +87,7 @@ index 99c8d3cf..b64e6caf 100644
def _import_torch_op_overload(
self,
loc: Location,
-@@ -1655,24 +1714,30 @@ class GraphNodeImporter:
+@@ -1656,24 +1720,30 @@ class GraphNodeImporter:
self._multi_result_nodes.add(node)

# Unroll operands from formal parameters, args and kwargs.
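The enlarged hunk above (24 to 30 lines) touches _import_torch_op_overload, which, per the new call site self._import_torch_op_overload(loc, node, node.meta["original_aten"]), now accepts an explicit overload in addition to unrolling operands from the op schema. A hedged sketch of the implied shape: the parameter name, defaults, and unrolling details are assumptions, not the patch's code.

from typing import Optional

import torch
from torch.fx import Node


def _import_torch_op_overload(
    self, loc, node: Node, concrete_target: Optional[torch._ops.OpOverload] = None
):
    # Prefer the caller-supplied overload (node.meta["original_aten"]); fall
    # back to the node's own target, which must then be a concrete OpOverload.
    op = concrete_target if concrete_target is not None else node.target
    assert isinstance(op, torch._ops.OpOverload)

    # "Unroll operands from formal parameters, args and kwargs."
    operands = []
    for i, formal in enumerate(op._schema.arguments):
        if formal.kwarg_only:
            operands.append(node.kwargs.get(formal.name, formal.default_value))
        elif i < len(node.args):
            operands.append(node.args[i])
        else:
            operands.append(formal.default_value)
    # ... result-type inference and _emit_operation(...) follow in the importer.
    return operands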
@@ -122,7 +134,7 @@ index 99c8d3cf..b64e6caf 100644

operation = _emit_operation(
mlir_op_name, result_types=result_types, operands=operands, loc=loc
-@@ -2057,6 +2122,8 @@ def _make_vtensor_literal_op(
+@@ -2058,6 +2128,8 @@ def _make_vtensor_literal_op(
) -> Operation:
mapping = py_attr_tracker.track(tensor)
if mapping.is_empty:
@@ -131,7 +143,7 @@ index 99c8d3cf..b64e6caf 100644
# check support for bfloat16
assert not (
tensor.dtype == torch.bfloat16 and ml_dtypes is None
-@@ -2072,11 +2139,17 @@ def _make_vtensor_literal_op(
+@@ -2073,11 +2145,17 @@ def _make_vtensor_literal_op(
# detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw
# buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
# desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
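The context comments describe why the importer goes Tensor -> list -> numpy under FakeTensorMode, with bfloat16 guarded behind the optional ml_dtypes package. A standalone illustration of that round-trip (assumes ml_dtypes is installed):

import ml_dtypes
import numpy as np
import torch

t = torch.full((2, 2), 1.5, dtype=torch.bfloat16)
# detach()/.numpy() can fail under FakeTensorMode, hence Tensor -> list -> numpy;
# numpy has no native bfloat16, so ml_dtypes supplies the dtype.
arr = np.array(t.tolist(), dtype=ml_dtypes.bfloat16)
print(arr.dtype)  # bfloat16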
@@ -151,7 +163,7 @@ index 99c8d3cf..b64e6caf 100644
try:
dtype = tensor.dtype
element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
-@@ -2170,9 +2243,10 @@ def _emit_operation(
+@@ -2171,9 +2249,10 @@ def _emit_operation(
# which haven't been generated by torch_ods_gen.py.
context = loc.context
if not context.is_registered_operation(mlir_op_name):
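The last hunk adjusts how _emit_operation treats op names that are not registered with the MLIR context. The check itself uses torch-mlir's bundled MLIR python bindings; the fallback branch below is illustrative only, since the patched behavior is not visible in this diff.

from torch_mlir import ir


def check_registered(ctx: ir.Context, mlir_op_name: str) -> None:
    # Ops generated by torch_ods_gen.py are registered; anything else lands here.
    if not ctx.is_registered_operation(mlir_op_name):
        raise NotImplementedError(f"unregistered operation: {mlir_op_name}")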
2 changes: 1 addition & 1 deletion tests/numerical_test/main.py
@@ -141,7 +141,7 @@ def main():

results = []
if args.target == "all":
-for target in ["cpu", "cuda", "cuda_with_ait", "dynamo"]:
+for target in ["cpu", "cuda", "cuda_with_ait"]:#, "dynamo"]:
results += run(target, args.filter, args.workdir)
else:
results += run(args.target, args.filter, args.workdir, mode=args.mode, verbose=args.verbose)
7 changes: 4 additions & 3 deletions tests/numerical_test/torch_dynamo_e2e_testing/backend.py
@@ -18,7 +18,8 @@
import brt
import byteir

-from torch_frontend import compile, DebugType, BYTEIR_CUSTOM_OPS, GENERIC_CUSTOM_OPS
+import torch_frontend
+from torch_frontend import BYTEIR_CUSTOM_OPS, GENERIC_CUSTOM_OPS
from torch_frontend import (
list_decomposed_ops,
preprocess_fx_graph,
@@ -94,8 +95,8 @@ def byteir_compile_fx_inner(
compile_type = "stablehlo"
backend_legal_ops = BYTEIR_CUSTOM_OPS + GENERIC_CUSTOM_OPS
with maybe_disable_fake_tensor_mode():
-compiled_graph = compile(
-    fx_graph, inputs, compile_type, backend_legal_ops=backend_legal_ops
+compiled_graph = torch_frontend.compile_dynamo_model(
+    fx_graph, compile_type, backend_legal_ops=backend_legal_ops, verbose=True
)
# print(compiled_graph)
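The backend now lowers through torch_frontend.compile_dynamo_model, which takes the FX module and the output dialect directly instead of example inputs. A condensed sketch of the updated call, with the surrounding fake-tensor and runtime setup elided:

import torch_frontend
from torch_frontend import BYTEIR_CUSTOM_OPS, GENERIC_CUSTOM_OPS


def lower_to_stablehlo(fx_graph):
    # Mirrors the call in byteir_compile_fx_inner after this commit.
    backend_legal_ops = BYTEIR_CUSTOM_OPS + GENERIC_CUSTOM_OPS
    return torch_frontend.compile_dynamo_model(
        fx_graph,
        "stablehlo",
        backend_legal_ops=backend_legal_ops,
        verbose=True,
    )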

3 changes: 3 additions & 0 deletions (fourth changed file: flash attention dynamo tests; path not captured in this view)
@@ -1,6 +1,7 @@
import torch
from ..backend import byteir_compile_fx

+# ==============================================================================

class FlashAttnModel(torch.nn.Module):
def __init__(self):
@@ -41,6 +42,7 @@ def test_flash_attn_unit():
torch.testing.assert_close(k.grad, k_clone.grad)
torch.testing.assert_close(v.grad, v_clone.grad)

+# ==============================================================================

class FlashAttnFunctionalModel(torch.nn.Module):
def __init__(self):
@@ -79,6 +81,7 @@ def test_flash_attn_functional_unit():
torch.testing.assert_close(k.grad, k_clone.grad)
torch.testing.assert_close(v.grad, v_clone.grad)

+# ==============================================================================

class FlashAttnKVCacheModel(torch.nn.Module):
def __init__(self):
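All three test classes follow the same pattern: run the module eagerly and through the compiled path on cloned inputs, then compare outputs and the q/k/v gradients (the asserts shown above). A schematic version; the model body is inferred from the class name, and torch.compile stands in for the byteir_compile_fx wiring:

import torch
import torch.nn.functional as F


class FlashAttnFunctionalModel(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


def run_check(device="cuda"):
    model = FlashAttnFunctionalModel().to(device)
    q, k, v = (
        torch.randn(2, 8, 128, 64, device=device, requires_grad=True)
        for _ in range(3)
    )
    q2, k2, v2 = (t.detach().clone().requires_grad_(True) for t in (q, k, v))

    model(q, k, v).sum().backward()        # eager reference
    compiled = torch.compile(model)        # stand-in for byteir_compile_fx
    compiled(q2, k2, v2).sum().backward()  # compiled path

    torch.testing.assert_close(q.grad, q2.grad)
    torch.testing.assert_close(k.grad, k2.grad)
    torch.testing.assert_close(v.grad, v2.grad)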