Skip to content

Commit

Permalink
using name
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00-pl committed Dec 17, 2024
1 parent c64cea3 commit f58c953
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions plai/pl_torch_compiler/plnn_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch._ops import OpOverload

from plai.core import core_dialect
from plai.core.location import NamedLocation, DummyLocation
from plai.core.location import NamedLocation
from plai.core.module import Graph, Node
from plai.dialect import aten_dialect, torch_dialect

Expand Down Expand Up @@ -48,34 +48,34 @@ def torch_node_to_core_node(node: fx.Node, node_mapping: Callable[[fx.Node], Any
args = [node_mapping(arg) for arg in node.args]
attrs = {k: v for k, v in node.kwargs.items()}
if func_name == 'aten::view':
return aten_dialect.View(args[0], args[1], DummyLocation())
return aten_dialect.View(args[0], args[1], NamedLocation(node.name))
elif func_name == 'aten::detach':
return aten_dialect.Relu(args[0], DummyLocation())
return aten_dialect.Relu(args[0], NamedLocation(node.name))
elif func_name == 'aten::t':
return core_dialect.Transpose(args[0], DummyLocation())
return core_dialect.Transpose(args[0], NamedLocation(node.name))
elif func_name == 'aten::addmm':
return aten_dialect.AddMm(
args[0], args[1], args[2],
attrs.get('beta', 1), attrs.get('alpha', 1),
DummyLocation()
NamedLocation(node.name)
)
elif func_name == 'aten::mm':
return aten_dialect.Mm(args[0], args[1], DummyLocation())
return aten_dialect.Mm(args[0], args[1], NamedLocation(node.name))
elif func_name == 'aten::relu':
return aten_dialect.Relu(args[0], DummyLocation())
return aten_dialect.Relu(args[0], NamedLocation(node.name))
elif func_name == 'aten::max.dim':
keepdim = args[2] if len(args) == 3 else False
return aten_dialect.Max(args[0], args[1], keepdim, DummyLocation())
return aten_dialect.Max(args[0], args[1], keepdim, NamedLocation(node.name))
elif func_name == 'aten::sum.dim_IntList':
return aten_dialect.Sum(args[0], args[1], args[2], DummyLocation())
return aten_dialect.Sum(args[0], args[1], args[2], NamedLocation(node.name))
elif func_name == 'aten::threshold_backward':
return aten_dialect.ThresholdBackward(args[0], args[1], args[2], DummyLocation())
return aten_dialect.ThresholdBackward(args[0], args[1], args[2], NamedLocation(node.name))
elif func_name == 'torch._C._nn.linear':
return torch_dialect.Linear(args[0], args[1], args[2], DummyLocation())
return torch_dialect.Linear(args[0], args[1], args[2], NamedLocation(node.name))
elif func_name == 'torch.relu':
return torch_dialect.Relu(args[0], DummyLocation())
return torch_dialect.Relu(args[0], NamedLocation(node.name))
elif func_name == '_operator.getitem':
return torch_dialect.GetItem(args[0], args[1], DummyLocation())
return torch_dialect.GetItem(args[0], args[1], NamedLocation(node.name))
else:
raise NotImplementedError(f"Unsupported function: {func_name}")
elif node.op == 'get_attr':
Expand Down

0 comments on commit f58c953

Please sign in to comment.