Skip to content

Commit

Permalink
code in plai compiler convertion
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00-pl committed Jan 3, 2025
1 parent 8d57eb6 commit 8cd0731
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
4 changes: 3 additions & 1 deletion plai/pl_torch_compiler/plnn_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ def __call__(self, gm: fx.GraphModule, example_inputs: Tuple[torch.Tensor, ...])
self.graph.add_argument(new_node)
elif node.op == 'get_attr':
raise NotImplementedError("get_attr is not supported")
else:
elif node.op in ('call_method', 'call_module', 'call_function'):
new_node = converter.convert_node(node, self.node_mapping)
self.graph.add_node(new_node)
else:
raise ValueError(f"Unsupported op: {node.op}")

self.node_mapping_dict[node] = new_node

Expand Down
9 changes: 1 addition & 8 deletions plai/pl_torch_compiler/torch_to_plai_convertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ def get_converter(self, target) -> Callable[[list, dict, Location], Node]:
return self.node_converter_dict.get(func_name)

def convert_node(self, node: fx.Node, node_mapping: Callable[[fx.Node], Any]) -> Node:
if node.op == 'placeholder':
assert isinstance(node.target, str)
return core_dialect.Placeholder(NamedLocation(node.target))
elif node.op == 'call_method':
if node.op == 'call_method':
raise NotImplementedError("call_method is not supported")
elif node.op == 'call_module':
raise NotImplementedError("call_module is not supported")
Expand All @@ -65,9 +62,5 @@ def convert_node(self, node: fx.Node, node_mapping: Callable[[fx.Node], Any]) ->
attrs = {k: node_mapping(v) for k, v in node.kwargs.items()}
converter = self.get_converter(node.target)
return converter(args, attrs, NamedLocation(node.name))
elif node.op == 'get_attr':
raise NotImplementedError("get_attr is not supported")
elif node.op == 'output':
raise ValueError("Do not put output node in the middle of the graph")
else:
raise ValueError(f"Unsupported op: {node.op}")

0 comments on commit 8cd0731

Please sign in to comment.