From 5d44663cbe1eb86e28f5bc670bced4573897b0d7 Mon Sep 17 00:00:00 2001
From: "akmkhale@r01u27fcp"
Date: Mon, 1 Apr 2024 21:58:19 -0700
Subject: [PATCH 1/2] Add --aten-transform flag to the command line args

---
 projects/pt1/e2e_testing/main.py            |  6 +++
 .../configs/torchdynamo.py                  | 39 +++++++++++++++++++
 .../python/torch_mlir_e2e_test/framework.py |  7 +++-
 3 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py
index 957c8b9584ab..3ec53b0352e5 100644
--- a/projects/pt1/e2e_testing/main.py
+++ b/projects/pt1/e2e_testing/main.py
@@ -113,10 +113,16 @@ def _get_argparse():
                         default=False,
                         action="store_true",
                         help="Enable debug timings collection.")
+    parser.add_argument("--aten-transform",
+                        default=False,
+                        action="store_true",
+                        help="Replace aten.add/mul.Tensor with aten.add/mul.Scalar, for ResNet-like models.")
     return parser

 def main():
     args = _get_argparse().parse_args()
+    if args.aten_transform:
+        args.dump.append("aten-transform")
     opts = TestOptions(dumps=args.dump, use_kernels=args.use_kernels,
                        debug_timer=args.enable_timer)
     all_test_unique_names = set(
diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
index e0639c41237c..2d2165e6c9bc 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
@@ -61,6 +61,37 @@ def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool:
             return False
     return True

+
+def transform(gm: torch.fx.GraphModule):
+    print("============transform============")
+    # Modify gm.graph
+    for node in gm.graph.nodes:
+        # Check if we're calling a function (e.g.
+        # torch.add)
+        if node.op == 'call_function':
+            # The target attribute is the function
+            # that call_function calls, e.g.:
+            # call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {})
+            if node.target == torch.ops.aten.add.Tensor:
+                if len(node.args) != 2 or node.kwargs != {}:
+                    print("skipping --- node: ", node, "args: ", node.args, " kwargs: ", node.kwargs)
+                elif not isinstance(node.args[1], torch.fx.node.Node):
+                    node.target = torch.ops.aten.add.Scalar
+                    print("node: ", node, "args: ", node.args, " kwargs: ", node.kwargs)
+                    print("argtypes: ", type(node.args[0]), type(node.args[1]))
+            if node.target == torch.ops.aten.mul.Tensor:
+                if len(node.args) != 2 or node.kwargs != {}:
+                    print("skipping --- node: ", node, "args: ", node.args, " kwargs: ", node.kwargs)
+                elif not isinstance(node.args[1], torch.fx.node.Node):
+                    node.target = torch.ops.aten.mul.Scalar
+                    print("node: ", node, "args: ", node.args)
+                    # node.target = torch.mul
+
+    gm.graph.lint()  # Does some checks to make sure the Graph is well-formed
+    # Recompile the forward() method of `gm` from its Graph
+    gm.recompile()
+    print("============transform============")
+
 # Replaces torch.aten.add.Tensor/torch.aten.mul.Tensor to
 # torch.aten.add.Scalar/torch.aten.mul.Scalar in case of Scalar argument
 # Cannot be done on earlier stage, e.g. in _FXGraphImporter as it
@@ -136,6 +167,14 @@ def my_aot_autograd_backend(gm: torch.fx.GraphModule,
             with open(f"{model._get_name()}.{symbol}-fx-graph.txt", "w") as f:
                 print(gm.graph, file=f)

+        if opts.is_dump_enabled("aten-transform"):
+            transform(gm)
+            if opts.is_dump_enabled("fx-graph"):
+                with open(f"{model._get_name()}.{symbol}-fx-graph-xformed.txt", "w") as f:
+                    print(gm.graph, file=f)
+                with open(f"{model._get_name()}.{symbol}-fx-graph-xformed.py", "w") as f:
+                    print(gm.code, file=f)
+
         nonlocal mlir_module
         *_, model_name, nth_graph = get_aot_compilation_context()
         mlir_module = import_fx_graph_as_func(gm.graph, model_name)
diff --git a/projects/pt1/python/torch_mlir_e2e_test/framework.py b/projects/pt1/python/torch_mlir_e2e_test/framework.py
index b4041629d22d..7c77dfe50187 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/framework.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/framework.py
@@ -167,7 +167,7 @@ def wrapper_debug_timer(*args, **kwargs):
 class TestOptions:
     """Test run options."""

-    dump_choices = ["all", "fx-graph", "torch-mlir", "linalg-mlir", "llvm-mlir", "torch-mlir-lowering", "linalg-mlir-lowering", "obj"]
+    dump_choices = ["all", "fx-graph", "aten-transform", "torch-mlir", "linalg-mlir", "llvm-mlir", "torch-mlir-lowering", "linalg-mlir-lowering", "obj"]

     def __init__(self, *, dumps: List[str] = [], use_kernels=False, debug_timer=False, use_omp=True):
         self.dumps = {opt for opt in dumps}
@@ -176,7 +176,10 @@ def __init__(self, *, dumps: List[str] = [], use_kernels=False, debug_timer=Fals
         self.use_omp = use_omp

     def is_dump_enabled(self, dump: str):
-        return dump in self.dumps or "all" in self.dumps
+        if dump != "aten-transform":
+            return dump in self.dumps or "all" in self.dumps
+        else:
+            return dump in self.dumps

     def is_debug_timer_enabled(self):
         return self.debug_timer

From bcc85ecb77a65430cd164371a210d9df1b1c1399 Mon Sep 17 00:00:00 2001
From: "akmkhale@r01u27fcp"
Date: Mon, 1 Apr 2024 22:17:55 -0700
Subject: [PATCH 2/2] Add ResNext model

---
 .../test_suite/vision_models.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/vision_models.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/vision_models.py
index 34f3b9c697fc..f3ebeca7cd65 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/vision_models.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/vision_models.py
@@ -105,3 +105,17 @@ def forward(self, img):
 @register_test_case(module_factory=lambda: MobilenetV3Module())
 def MobilenetV3Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 3, 224, 224))
+
+
+def ResNext():
+    model = models.resnext50_32x4d()
+    model.eval()
+    return model
+
+@register_test_case(module_factory=lambda: ResNext())
+def ResNext_basic(module, tu: TestUtils):
+    # out = module.forward(tu.randint(1, 11, high=13000))
+    out = module.forward(tu.rand(1, 3, 224, 224))
+    # model.forward(input_ids=input_ids.input_ids, attention_mask=input_ids.attention_mask, output_hidden_states=False, use_cache=False)
+    # print("gen tokens: ", gen_tokens)
+    return out
\ No newline at end of file
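
For reference, the rewrite that transform() applies can be reproduced outside the e2e harness. The sketch below is a minimal standalone version assuming only stock PyTorch; it traces with make_fx rather than the TorchDynamo/AOT Autograd pipeline that the patch hooks into (make_fx likewise lowers to torch.ops.aten.* overload targets), and the toy function f is purely illustrative:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return x * 2.0 + 1

    # Trace to an FX GraphModule whose nodes target aten overloads,
    # e.g. aten.mul.Tensor / aten.add.Tensor with a scalar second operand.
    gm = make_fx(f)(torch.randn(4))
    print(gm.graph)

    for node in gm.graph.nodes:
        if node.op == "call_function" and len(node.args) == 2 and not node.kwargs:
            # Retarget only when the second operand is a Python scalar,
            # not another graph node.
            if not isinstance(node.args[1], torch.fx.node.Node):
                if node.target == torch.ops.aten.add.Tensor:
                    node.target = torch.ops.aten.add.Scalar
                elif node.target == torch.ops.aten.mul.Tensor:
                    node.target = torch.ops.aten.mul.Scalar

    gm.graph.lint()   # check the rewritten graph is still well-formed
    gm.recompile()    # regenerate forward() from the mutated graph
    print(gm.graph)   # same graph, now targeting aten.add.Scalar / aten.mul.Scalar

Note that is_dump_enabled() deliberately excludes "aten-transform" from the "all" wildcard in framework.py, so the rewrite runs only when requested explicitly via --aten-transform.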