Commit d3e4753
try to remove multiple output.
0x00-pl committed Nov 29, 2024
1 parent 478f3ec commit d3e4753
Showing 6 changed files with 48 additions and 62 deletions.
68 changes: 23 additions & 45 deletions plai/core/module.py
@@ -1,29 +1,14 @@
 from abc import ABC, abstractmethod
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Self

 from plai.core.location import Location


-class Value:
-    def __init__(self, node: Optional['Node'], type_notation=None):
-        self.node = node
-        self.type_notation = type_notation
-
-    def __str__(self):
-        return f"{self.type_notation}"
-
-    def owner(self):
-        return self.node
-
-
 class Node(ABC):
-    def __init__(self, operands: List[Value], attrs: dict, loc: Location = None):
+    def __init__(self, operands: List['Node'], attrs: dict, loc: Location = None):
         self.operands = operands
         self.attrs = attrs
         self.loc = loc
-        self.outputs = self.build_outputs()
-
-    subclass_dict = {}

     @classmethod
     def get_namespace(cls):
@@ -36,6 +21,8 @@ def get_op_name(cls):
         else:
             return f'{cls.get_namespace()}.{cls.__name__.lower()}'

+    subclass_dict = {}
+
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
         op_name = cls.get_op_name()
Expand All @@ -54,51 +41,42 @@ def build(op_name: str, args: list, attrs: dict, loc: Location = None):
assert op_cls.build is not Node.build, f"Class {op_cls.__name__} must override the build method."
return op_cls.build(op_name, args, attrs, loc)

def build_outputs(self):
return [Value(self)]

def get_output(self) -> Value:
assert len(self.outputs) == 1, f"Node {self} with Type {self.get_op_name()} must have exactly one output."
return self.outputs[0]

def get_outputs(self) -> List[Value]:
return self.outputs

def to_string(self, value_name_dict: Dict[Value, str]):
return f'{self.get_op_name()}({", ".join(value_name_dict[i] for i in self.operands)}) {self.attrs if self.attrs else ""}'
def to_string(self, node_name_dict: Dict['Node', str]):
return f'{self.get_op_name()}({", ".join(node_name_dict[i] for i in self.operands)}) ' \
f'{self.attrs if self.attrs else ""}'

@staticmethod
def static_to_string(node: 'Node', value_name_dict: Dict[Value, str]):
def static_to_string(node: 'Node', node_name_dict: Dict['Node', str]):
if node is None:
return 'None'
return node.to_string(value_name_dict)
return node.to_string(node_name_dict)


class Graph:
def __init__(self, name=''):
self.name = name
self.arguments: List[Value] = []
self.arguments: List[Node] = []
self.nodes: List[Node] = []
self.outputs: List[Value] = []
self.outputs: List[Node] = []

def add_argument(self):
return self.arguments.append(Value(None))
def add_argument(self, node: Node):
return self.arguments.append(node)

def add_output(self, value: Value):
# value maybe is None
self.outputs.append(value)
def add_output(self, node: Node):
# node maybe is None
self.outputs.append(node)

def add_node(self, node: Node):
self.nodes.append(node)

def __str__(self):
value_name_dict: Dict[Optional[Value], str] = {None: 'None'}
value_name_dict = value_name_dict | {node: f'arg{idx}' for idx, node in enumerate(self.arguments)}
value_name_dict = value_name_dict | {node: f'v{idx}' for idx, node in enumerate(self.nodes)}
node_name_dict: Dict[Optional[Node], str] = {None: 'None'}
node_name_dict = node_name_dict | {node: f'arg{idx}' for idx, node in enumerate(self.arguments)}
node_name_dict = node_name_dict | {node: f'v{idx}' for idx, node in enumerate(self.nodes)}

result = f'Graph {self.name}({", ".join(value_name_dict[i] for i in self.arguments)}): \n'
result = f'Graph {self.name}({", ".join(node_name_dict[i] for i in self.arguments)}): \n'
for idx, node in enumerate(self.nodes):
args_str = ', '.join(value_name_dict[i] for i in node.get_outputs())
result += f' {idx}: {args_str} = {node.to_string(value_name_dict)}\n'
result += f' output ({", ".join(value_name_dict[i] for i in self.outputs)})\n'
name = node_name_dict[node]
result += f' {idx}: {name} = {node.to_string(node_name_dict)}\n'
result += f' output ({", ".join(node_name_dict[i] for i in self.outputs)})\n'
return result
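Note: this diff removes the Value indirection entirely — operands, graph arguments, and graph outputs are all Nodes now. Below is a minimal sketch of the refactored API (illustrative only, not part of the commit; the Input and Add ops are hypothetical, and it assumes get_namespace is the only hook a concrete Node subclass must supply):

    from plai.core import module
    from plai.core.location import Location


    class Input(module.Node):
        # Hypothetical source op with no operands, standing in for a
        # real placeholder node. Not part of the repo.
        @classmethod
        def get_namespace(cls):
            return 'demo'

        def __init__(self, loc: Location = None):
            super().__init__([], {}, loc)


    class Add(module.Node):
        # Hypothetical binary op: its operands are other Nodes, not Values.
        @classmethod
        def get_namespace(cls):
            return 'demo'

        def __init__(self, lhs: module.Node, rhs: module.Node, loc: Location = None):
            super().__init__([lhs, rhs], {}, loc)


    g = module.Graph('example')
    x, y = Input(), Input()
    g.add_argument(x)   # arguments are Nodes now, not empty Value slots
    g.add_argument(y)
    s = Add(x, y)
    g.add_node(s)
    g.add_output(s)     # outputs are Nodes too
    print(g)
    # Expected shape of the output, roughly:
    # Graph example(arg0, arg1):
    #     0: v0 = demo.add(arg0, arg1)
    #     output (v0)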
3 changes: 0 additions & 3 deletions plai/dialect/aten_dialect.py
@@ -67,9 +67,6 @@ def build(op_name: str, args: list, attrs: dict, loc: Location = None):
         assert op_name == 'max'
         return Max(args[0], attrs['dim'], attrs['keepdim'], loc)

-    def build_outputs(self):
-        return [module.Value(self), module.Value(self)]
-

 class ThresholdBackward(AtenNode):
     def __init__(self, grad_output: module.Node, arg: module.Node, threshold: float, loc: Location = None):
10 changes: 10 additions & 0 deletions plai/dialect/torch_dialect.py
@@ -12,6 +12,16 @@ def get_namespace(cls):
         return 'torch.nn'


+class GetItem(TorchNode):
+    def __init__(self, arg: module.Node, index: module.Node, loc: Location = None):
+        super().__init__([arg], {'index': index}, loc)
+
+    @staticmethod
+    def build(op_name: str, args: list, attrs: dict, loc: Location = None):
+        assert op_name == 'getitem'
+        return GetItem(args[0], attrs['index'], loc)
+
+
 class Linear(TorchNode):
     def __init__(self, arg: module.Node, weight: module.Node, bias: module.Node, loc: Location = None):
         """
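For context on why GetItem is needed (an illustrative snippet, not part of the commit): torch.max with a dim argument returns a (values, indices) pair, and torch.fx records the subsequent tuple indexing as a call to _operator.getitem — the call that plnn_compiler lowers to this new node below:

    import torch
    from torch import fx


    def f(x):
        return torch.max(x, dim=0)[0]   # keep only the values


    traced = fx.symbolic_trace(f)
    print(traced.graph)
    # The printed graph contains, roughly:
    #   max_1   <- call_function target=torch.max (x, dim=0)
    #   getitem <- call_function target=_operator.getitem (max_1, 0)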
25 changes: 13 additions & 12 deletions plai/pl_torch_compiler/plnn_compiler.py
@@ -75,7 +75,7 @@ def torch_node_to_core_node(node: fx.Node, node_mapping: Callable[[fx.Node], Any
         elif func_name == 'torch.relu':
             return torch_dialect.Relu(args[0], DummyLocation())
         elif func_name == '_operator.getitem':
-            raise NotImplementedError("_operator.getitem is not supported")
+            return torch_dialect.GetItem(args[0], args[1], DummyLocation())
         else:
             raise NotImplementedError(f"Unsupported function: {func_name}")
     elif node.op == 'get_attr':
Expand All @@ -91,17 +91,17 @@ def __init__(self):
self.graph = Graph('main_graph')
self.node_mapping_dict: Dict[torch.fx.Node, Node] = {}

def node_mapping(self, value):
if isinstance(value, Tuple):
return tuple(self.node_mapping(v) for v in value)
elif isinstance(value, List):
return [self.node_mapping(v) for v in value]
elif isinstance(value, dict):
return {k: self.node_mapping(v) for k, v in value.items()}
elif isinstance(value, torch.fx.Node):
return self.node_mapping_dict[value]
def node_mapping(self, node):
if isinstance(node, Tuple):
return tuple(self.node_mapping(v) for v in node)
elif isinstance(node, List):
return [self.node_mapping(v) for v in node]
elif isinstance(node, dict):
return {k: self.node_mapping(v) for k, v in node.items()}
elif isinstance(node, torch.fx.Node):
return self.node_mapping_dict[node]
else:
return value
return node

def __call__(self, gm: fx.GraphModule, example_inputs: Tuple[torch.Tensor, ...]) -> Callable:
# 遍历计算图中的所有节点并收集信息
Expand All @@ -115,7 +115,8 @@ def __call__(self, gm: fx.GraphModule, example_inputs: Tuple[torch.Tensor, ...])
self.graph.add_output(self.node_mapping(i))
new_node = None
elif node.op == 'placeholder':
new_node = self.graph.add_argument()
new_node = core_dialect.Placeholder(NamedLocation(node.target))
self.graph.add_argument(new_node)
elif node.op == 'get_attr':
raise NotImplementedError("get_attr is not supported")
else:
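A small sketch of what the new placeholder branch builds — one explicit Placeholder node per fx placeholder, stored directly in the argument list. The import paths for core_dialect and NamedLocation below are guesses; both names appear only inside the diff above:

    from plai.core.module import Graph
    from plai.core.location import NamedLocation   # assumed location
    from plai.dialect import core_dialect          # assumed module path

    g = Graph('main_graph')
    arg = core_dialect.Placeholder(NamedLocation('x'))
    g.add_argument(arg)   # the argument list now stores a Node
    g.add_output(arg)     # identity graph: returns its own input
    print(g)              # Graph main_graph(arg0): ... output (arg0)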
2 changes: 1 addition & 1 deletion tests/module_pool/simple_nn.py
@@ -16,7 +16,7 @@ def forward(self, x):


 def check_torch_compile_forward(model, compiled_model):
-    input_data = torch.randn(1, 10)
+    input_data = torch.randn(2, 5)
     expected_output = model(input_data)
     actual_output = compiled_model(input_data)
     assert torch.allclose(expected_output, actual_output), "Output mismatch between compiled and original model"
2 changes: 1 addition & 1 deletion tests/test_torch_compile_mutiple_output.py
@@ -7,7 +7,7 @@


 def test_torch_plnn_compile_mutiple_output():
-    model = lambda x: torch.max(x, dim=0)
+    model = lambda x: torch.max(x, dim=0)[0]
     custom_compiler = plnn_compiler.CustomCompiler()
     aot_backend = aot_autograd(fw_compiler=make_boxed_compiler(custom_compiler), bw_compiler=None)
     compiled_model = torch.compile(model, backend=aot_backend)
