From 8cd0731b32f0d7abc0635f2c3e25e24329cfdee9 Mon Sep 17 00:00:00 2001 From: 0x00-pl <0x00.pl@gmail.com> Date: Wed, 18 Sep 2024 01:38:35 +0800 Subject: [PATCH] code in plai compiler convertion --- plai/pl_torch_compiler/plnn_compiler.py | 4 +++- plai/pl_torch_compiler/torch_to_plai_convertion.py | 9 +-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/plai/pl_torch_compiler/plnn_compiler.py b/plai/pl_torch_compiler/plnn_compiler.py index 022d46c..3e77a43 100644 --- a/plai/pl_torch_compiler/plnn_compiler.py +++ b/plai/pl_torch_compiler/plnn_compiler.py @@ -45,9 +45,11 @@ def __call__(self, gm: fx.GraphModule, example_inputs: Tuple[torch.Tensor, ...]) self.graph.add_argument(new_node) elif node.op == 'get_attr': raise NotImplementedError("get_attr is not supported") - else: + elif node.op in ('call_method', 'call_module', 'call_function'): new_node = converter.convert_node(node, self.node_mapping) self.graph.add_node(new_node) + else: + raise ValueError(f"Unsupported op: {node.op}") self.node_mapping_dict[node] = new_node diff --git a/plai/pl_torch_compiler/torch_to_plai_convertion.py b/plai/pl_torch_compiler/torch_to_plai_convertion.py index 1da3857..6870d16 100644 --- a/plai/pl_torch_compiler/torch_to_plai_convertion.py +++ b/plai/pl_torch_compiler/torch_to_plai_convertion.py @@ -53,10 +53,7 @@ def get_converter(self, target) -> Callable[[list, dict, Location], Node]: return self.node_converter_dict.get(func_name) def convert_node(self, node: fx.Node, node_mapping: Callable[[fx.Node], Any]) -> Node: - if node.op == 'placeholder': - assert isinstance(node.target, str) - return core_dialect.Placeholder(NamedLocation(node.target)) - elif node.op == 'call_method': + if node.op == 'call_method': raise NotImplementedError("call_method is not supported") elif node.op == 'call_module': raise NotImplementedError("call_module is not supported") @@ -65,9 +62,5 @@ def convert_node(self, node: fx.Node, node_mapping: Callable[[fx.Node], Any]) -> attrs = {k: node_mapping(v) for k, v in node.kwargs.items()} converter = self.get_converter(node.target) return converter(args, attrs, NamedLocation(node.name)) - elif node.op == 'get_attr': - raise NotImplementedError("get_attr is not supported") - elif node.op == 'output': - raise ValueError("Do not put output node in the middle of the graph") else: raise ValueError(f"Unsupported op: {node.op}")