Skip to content

Commit

Permalink
Add transform function
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Mar 15, 2024
1 parent 391d073 commit 090493a
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,39 @@ def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool:
return True


def transform(gm: torch.fx.GraphModule):

print("============transform============")
# Modify gm.graph
for node in gm.graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
# call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {})
if node.target == torch.ops.aten.add.Tensor:
if len(node.args) != 2 or node.kwargs != {}:
print("skipping --- node: ", node, "args: ", node.args, " kwargs: ", node.kwargs)
elif not isinstance(node.args[1], torch.fx.node.Node):
node.target = torch.ops.aten.add.Scalar
print("node: ", node, "args: ", node.args, " kwargs: ", node.kwargs)
print("argtypes: ", type(node.args[0]), type(node.args[1]))
if node.target == torch.ops.aten.mul.Tensor:
if len(node.args) != 2 or node.kwargs != {}:
print("skipping --- node: ", node, "args: ", node.args, " kwargs: ", node.kwargs)
elif not isinstance(node.args[1], torch.fx.node.Node):
node.target = torch.ops.aten.mul.Scalar
print("node: ", node, "args: ", node.args)
# node.target = torch.mul

gm.graph.lint() # Does some checks to make sure the

# Recompile the forward() method of `gm` from its Graph
gm.recompile()
print("============transform============")


def jit(
model: torch.nn.Module,
example_args: _example_args,
Expand Down Expand Up @@ -100,6 +133,8 @@ def my_aot_autograd_backend(gm: torch.fx.GraphModule,
if opts.is_dump_enabled("fx-graph"):
with open(f"{model._get_name()}.{symbol}-fx-graph.txt", "w") as f:
print(gm.graph, file=f)

transform(gm)

nonlocal mlir_module
*_, model_name, nth_graph = get_aot_compilation_context()
Expand Down

0 comments on commit 090493a

Please sign in to comment.