diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 84c9a132e2..12c8696c21 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import operator from typing import final, List, Optional import torch @@ -30,6 +31,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.pow.Tensor_Tensor, + operator.getitem, ] return supported diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 4facc01a06..a88f3029d1 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import operator from typing import cast, List, Optional, Union import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema @@ -16,7 +17,7 @@ from torch.fx import Node _ScalarType = Union[bool, int, float] -_Argument = Union[Node, List[Node], _ScalarType, List[_ScalarType], str] +_Argument = Union[Node, List[Node], TensorSpec, _ScalarType, List[_ScalarType], str] class VkGraphBuilder: @@ -34,6 +35,7 @@ def __init__(self, program: ExportedProgram) -> None: @staticmethod def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: + # TODO(T182302927): Support more dtypes including float16, int(32|64). if torch_dtype == torch.float32: return vk_graph_schema.VkDataType.fp32 else: @@ -102,33 +104,20 @@ def get_param_tensor(self, node: Node) -> torch.Tensor: return tensor def maybe_add_constant_tensor(self, node: Node) -> int: - const_buffer_idx = -1 + constant_id = -1 if self.is_param_node(node): - const_buffer_idx = len(self.const_tensors) + constant_id = len(self.const_tensors) self.const_tensors.append(self.get_param_tensor(node)) - return const_buffer_idx - - def create_single_tensor_value(self, node: Node) -> int: - constant_id = self.maybe_add_constant_tensor(node) - - spec = node.meta.get("spec") - assert isinstance(spec, TensorSpec) - new_id = len(self.values) - if node not in self.node_to_value_ids: - self.node_to_value_ids[node] = new_id - else: - current_ids = self.node_to_value_ids[node] - if isinstance(current_ids, int): - current_ids = [current_ids, new_id] - else: - current_ids.append(new_id) + return constant_id + def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: # Negative id indicates that this tensor will have its own dedicated memory. mem_obj_id = -1 if spec.mem_obj_id is not None: mem_obj_id = spec.mem_obj_id + new_id = len(self.values) self.values.append( vk_graph_schema.VkValue( value=vk_graph_schema.VkTensor( @@ -141,16 +130,23 @@ def create_single_tensor_value(self, node: Node) -> int: ) return new_id - def create_tensor_values(self, node: Node) -> int: + def create_node_value(self, node: Node) -> int: spec = node.meta.get("spec") if isinstance(spec, TensorSpec): - return self.create_single_tensor_value(node) + constant_id = self.maybe_add_constant_tensor(node) + new_id = self.create_tensor_value(spec, constant_id) + self.node_to_value_ids[node] = new_id + return new_id + elif isinstance(spec, tuple): + # Create a Value for each element in the tuple, wrap Values in a + # ValueList, and map the Node to the ValueList id. + new_id = self.create_value_list_value(spec) + self.node_to_value_ids[node] = new_id + return new_id else: - raise RuntimeError( - "Creating values for nodes with collection types is not supported yet." - ) + raise RuntimeError(f"Cannot create value for spec of type {type(spec)}") - def create_value_list_value(self, arg: List[Node]) -> int: + def create_value_list_value(self, arg: List[Node] | tuple) -> int: self.values.append( vk_graph_schema.VkValue( vk_graph_schema.ValueList( @@ -201,14 +197,15 @@ def create_string_value(self, string: str) -> int: def get_or_create_value_for(self, arg: _Argument): if isinstance(arg, Node): - # If the value has already been created, return the existing id + # If the Node has already been processed, return the existing id. if arg in self.node_to_value_ids: return self.node_to_value_ids[arg] - # Return id for a newly created value - return self.create_tensor_values(arg) + return self.create_node_value(arg) elif isinstance(arg, list) and isinstance(arg[0], Node): # pyre-ignore[6] return self.create_value_list_value(arg) + elif isinstance(arg, TensorSpec): + return self.create_tensor_value(arg) elif isinstance(arg, _ScalarType): return self.create_scalar_value(arg) elif isinstance(arg, list) and isinstance(arg[0], _ScalarType): @@ -220,13 +217,25 @@ def get_or_create_value_for(self, arg: _Argument): raise RuntimeError(f"Cannot create value for arg of type {type(arg)}") def process_placeholder_node(self, node: Node) -> None: - ids = self.create_tensor_values(node) + ids = self.create_node_value(node) if not self.is_param_node(node): if isinstance(ids, int): self.input_ids.append(ids) else: self.input_ids += ids + def process_getitem_node(self, node: Node) -> None: + # Find ValueList id from the collection node. + collection_node = node.all_input_nodes[0] + list_id = self.node_to_value_ids[collection_node] + + # Extract the target Value id from ValueList. + valuelist_id = node.args[1] + value_id = self.values[list_id].value.items[valuelist_id] + + # Map Node to Value id. + self.node_to_value_ids[node] = value_id + def process_call_function_node(self, node) -> None: operator_call_args = [] @@ -238,12 +247,12 @@ def process_call_function_node(self, node) -> None: else: function_arg = schema_arg.default_value - # Create a value for each function argument. If the argument has been - # previously encountered, then use the existing value id. + # Create a Value for each function argument. If the argument has been + # previously encountered, then use the existing Value id. operator_call_args.append(self.get_or_create_value_for(function_arg)) # Add output node - operator_call_args.append(self.create_tensor_values(node)) + operator_call_args.append(self.create_node_value(node)) self.chain.append( vk_graph_schema.OperatorCall( @@ -253,7 +262,7 @@ def process_call_function_node(self, node) -> None: ) def process_getattr_node(self, node: Node) -> None: - self.create_tensor_values(node) + self.create_node_value(node) def process_output_node(self, node: Node) -> None: for out_node in node.all_input_nodes: @@ -269,7 +278,10 @@ def process_node(self, node: Node) -> None: if node.op == "placeholder": self.process_placeholder_node(node) elif node.op == "call_function": - self.process_call_function_node(node) + if node.target == operator.getitem: + self.process_getitem_node(node) + else: + self.process_call_function_node(node) elif node.op == "get_attr": self.process_getattr_node(node) elif node.op == "output":