Serialize tuple types from Node (pytorch#2405)
Summary:
bypass-github-export-checks

Pull Request resolved: pytorch#2405

In pytorch#2271, we already added
- IntList
- DoubleList
- BoolList
- ValueList

to the schema and the runtime's Value class. Their serialization was incomplete, missing two components:
1. Handling a list of values in `torch.fx.Node.args`.
2. Handling a `torch.fx.Node` whose output is a non-tensor value.

This change completes item 2.
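
For illustration only (not part of this commit), item 1 covers ops such as `aten.cat` whose arguments include a Python list of Nodes. A minimal sketch, assuming current `torch.export` behavior:

    import torch

    class CatPair(torch.nn.Module):
        def forward(self, x, y):
            # aten.cat takes a list of tensors, so the traced node's args
            # contain a Python list of Nodes.
            return torch.cat([x, y], dim=0)

    ep = torch.export.export(CatPair(), (torch.randn(2), torch.randn(2)))
    for node in ep.graph.nodes:
        if node.op == "call_function":
            print(node.target, node.args)  # e.g. aten.cat.default ([x, y],)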

We also introduce a dedicated handler for `getitem` nodes, which is required because every `call_function` node that outputs a non-tensor is followed by a special `getitem` `call_function` node that extracts an element of that output.
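
To make that pattern concrete (again, an illustrative sketch rather than part of this commit): exporting a module that uses a multi-output op such as `aten.max.dim` yields one `call_function` node producing a tuple, followed by `call_function` nodes whose target is `operator.getitem`:

    import torch

    class MaxPerDim(torch.nn.Module):
        def forward(self, x):
            # aten.max.dim returns a (values, indices) tuple.
            values, indices = torch.max(x, dim=0)
            return values + indices.to(values.dtype)

    ep = torch.export.export(MaxPerDim(), (torch.randn(4, 4),))
    for node in ep.graph.nodes:
        if node.op == "call_function":
            print(node.target, node.args)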
ghstack-source-id: 218541429
exported-using-ghexport

Reviewed By: SS-JIA

Differential Revision: D54789099

fbshipit-source-id: 1e58cbc0246a8c651d95e9ef6707bacb60066e2b
jorgep31415 authored and facebook-github-bot committed Mar 13, 2024
1 parent 21cbfd6 commit d0512b6
Showing 2 changed files with 48 additions and 34 deletions.
2 changes: 2 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import operator
 from typing import final, List, Optional
 
 import torch
@@ -30,6 +31,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.mul.Tensor,
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.pow.Tensor_Tensor,
+            operator.getitem,
         ]
         return supported
 
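
A hedged end-to-end sketch of where this partitioner change surfaces. The import path and the `VulkanPartitioner` entry point are assumptions based on the ExecuTorch tree at the time of this commit, not part of the diff:

    import torch
    from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
        VulkanPartitioner,
    )
    from executorch.exir import to_edge

    class AddMul(torch.nn.Module):
        def forward(self, x, y):
            return x + y * x

    inputs = (torch.randn(2), torch.randn(2))
    edge = to_edge(torch.export.export(AddMul(), inputs))
    # With operator.getitem marked supported, is_node_supported no longer
    # rejects the getitem nodes that unpack tuple outputs, so they can stay
    # inside the Vulkan-delegated subgraph.
    edge = edge.to_backend(VulkanPartitioner())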
80 changes: 46 additions & 34 deletions backends/vulkan/serialization/vulkan_graph_builder.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import operator
 from typing import cast, List, Optional, Union
 
 import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
@@ -16,7 +17,7 @@
 from torch.fx import Node
 
 _ScalarType = Union[bool, int, float]
-_Argument = Union[Node, List[Node], _ScalarType, List[_ScalarType], str]
+_Argument = Union[Node, List[Node], TensorSpec, _ScalarType, List[_ScalarType], str]
 
 
 class VkGraphBuilder:
@@ -34,6 +35,7 @@ def __init__(self, program: ExportedProgram) -> None:
 
     @staticmethod
     def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
+        # TODO(T182302927): Support more dtypes including float16, int(32|64).
         if torch_dtype == torch.float32:
             return vk_graph_schema.VkDataType.fp32
         else:
@@ -102,33 +104,20 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:
         return tensor
 
     def maybe_add_constant_tensor(self, node: Node) -> int:
-        const_buffer_idx = -1
+        constant_id = -1
         if self.is_param_node(node):
-            const_buffer_idx = len(self.const_tensors)
+            constant_id = len(self.const_tensors)
             self.const_tensors.append(self.get_param_tensor(node))
 
-        return const_buffer_idx
-
-    def create_single_tensor_value(self, node: Node) -> int:
-        constant_id = self.maybe_add_constant_tensor(node)
-
-        spec = node.meta.get("spec")
-        assert isinstance(spec, TensorSpec)
-        new_id = len(self.values)
-        if node not in self.node_to_value_ids:
-            self.node_to_value_ids[node] = new_id
-        else:
-            current_ids = self.node_to_value_ids[node]
-            if isinstance(current_ids, int):
-                current_ids = [current_ids, new_id]
-            else:
-                current_ids.append(new_id)
+        return constant_id
 
+    def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
         # Negative id indicates that this tensor will have its own dedicated memory.
         mem_obj_id = -1
         if spec.mem_obj_id is not None:
             mem_obj_id = spec.mem_obj_id
 
+        new_id = len(self.values)
         self.values.append(
             vk_graph_schema.VkValue(
                 value=vk_graph_schema.VkTensor(
@@ -141,16 +130,23 @@ def create_single_tensor_value(self, node: Node) -> int:
         )
         return new_id
 
-    def create_tensor_values(self, node: Node) -> int:
+    def create_node_value(self, node: Node) -> int:
         spec = node.meta.get("spec")
         if isinstance(spec, TensorSpec):
-            return self.create_single_tensor_value(node)
+            constant_id = self.maybe_add_constant_tensor(node)
+            new_id = self.create_tensor_value(spec, constant_id)
+            self.node_to_value_ids[node] = new_id
+            return new_id
+        elif isinstance(spec, tuple):
+            # Create a Value for each element in the tuple, wrap Values in a
+            # ValueList, and map the Node to the ValueList id.
+            new_id = self.create_value_list_value(spec)
+            self.node_to_value_ids[node] = new_id
+            return new_id
         else:
-            raise RuntimeError(
-                "Creating values for nodes with collection types is not supported yet."
-            )
+            raise RuntimeError(f"Cannot create value for spec of type {type(spec)}")
 
-    def create_value_list_value(self, arg: List[Node]) -> int:
+    def create_value_list_value(self, arg: List[Node] | tuple) -> int:
         self.values.append(
             vk_graph_schema.VkValue(
                 vk_graph_schema.ValueList(
@@ -201,14 +197,15 @@ def create_string_value(self, string: str) -> int:
 
     def get_or_create_value_for(self, arg: _Argument):
         if isinstance(arg, Node):
-            # If the value has already been created, return the existing id
+            # If the Node has already been processed, return the existing id.
             if arg in self.node_to_value_ids:
                 return self.node_to_value_ids[arg]
-            # Return id for a newly created value
-            return self.create_tensor_values(arg)
+            return self.create_node_value(arg)
         elif isinstance(arg, list) and isinstance(arg[0], Node):
             # pyre-ignore[6]
             return self.create_value_list_value(arg)
+        elif isinstance(arg, TensorSpec):
+            return self.create_tensor_value(arg)
         elif isinstance(arg, _ScalarType):
             return self.create_scalar_value(arg)
         elif isinstance(arg, list) and isinstance(arg[0], _ScalarType):
@@ -220,13 +217,25 @@ def get_or_create_value_for(self, arg: _Argument):
         raise RuntimeError(f"Cannot create value for arg of type {type(arg)}")
 
     def process_placeholder_node(self, node: Node) -> None:
-        ids = self.create_tensor_values(node)
+        ids = self.create_node_value(node)
         if not self.is_param_node(node):
             if isinstance(ids, int):
                 self.input_ids.append(ids)
             else:
                 self.input_ids += ids
 
+    def process_getitem_node(self, node: Node) -> None:
+        # Find ValueList id from the collection node.
+        collection_node = node.all_input_nodes[0]
+        list_id = self.node_to_value_ids[collection_node]
+
+        # Extract the target Value id from ValueList.
+        valuelist_id = node.args[1]
+        value_id = self.values[list_id].value.items[valuelist_id]
+
+        # Map Node to Value id.
+        self.node_to_value_ids[node] = value_id
+
     def process_call_function_node(self, node) -> None:
         operator_call_args = []
 
@@ -238,12 +247,12 @@ def process_call_function_node(self, node) -> None:
             else:
                 function_arg = schema_arg.default_value
 
-            # Create a value for each function argument. If the argument has been
-            # previously encountered, then use the existing value id.
+            # Create a Value for each function argument. If the argument has been
+            # previously encountered, then use the existing Value id.
             operator_call_args.append(self.get_or_create_value_for(function_arg))
 
         # Add output node
-        operator_call_args.append(self.create_tensor_values(node))
+        operator_call_args.append(self.create_node_value(node))
 
         self.chain.append(
             vk_graph_schema.OperatorCall(
@@ -253,7 +262,7 @@
         )
 
     def process_getattr_node(self, node: Node) -> None:
-        self.create_tensor_values(node)
+        self.create_node_value(node)
 
     def process_output_node(self, node: Node) -> None:
         for out_node in node.all_input_nodes:
@@ -269,7 +278,10 @@ def process_node(self, node: Node) -> None:
         if node.op == "placeholder":
             self.process_placeholder_node(node)
         elif node.op == "call_function":
-            self.process_call_function_node(node)
+            if node.target == operator.getitem:
+                self.process_getitem_node(node)
+            else:
+                self.process_call_function_node(node)
         elif node.op == "get_attr":
             self.process_getattr_node(node)
         elif node.op == "output":
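
To make the new ValueList bookkeeping concrete, here is a small self-contained sketch using stand-in classes (not the real `vk_graph_schema` types). It mirrors what create_node_value does for a tuple spec and what process_getitem_node does afterwards:

    from dataclasses import dataclass
    from typing import Dict, List

    @dataclass
    class ValueList:
        items: List[int]

    @dataclass
    class VkValue:
        value: object  # stand-in for the schema's tensor/list union

    values: List[VkValue] = []
    node_to_value_ids: Dict[str, int] = {}

    # Serializing a tuple-output node creates one Value per element (ids 0
    # and 1), wraps those ids in a ValueList (id 2), and maps the node to
    # the ValueList's id.
    values.append(VkValue("tensor_a"))
    values.append(VkValue("tensor_b"))
    values.append(VkValue(ValueList(items=[0, 1])))
    node_to_value_ids["collection_node"] = 2

    # A getitem node with args == (collection_node, 1) then resolves to
    # Value id 1 without allocating a new Value.
    list_id = node_to_value_ids["collection_node"]
    value_id = values[list_id].value.items[1]
    node_to_value_ids["getitem_node"] = value_id
    assert value_id == 1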
