diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1e2357adeb..00f6508ad3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -431,6 +431,7 @@ if(EXECUTORCH_BUILD_PYBIND)
     flatcc
     portable_ops_lib
     util
+    torch
    ${PYBIND_LINK_COREML}
    ${PYBIND_LINK_MPS}
    ${PYBIND_LINK_XNNPACK}
diff --git a/backends/vulkan/TARGETS b/backends/vulkan/TARGETS
index 8f77210b89..86733510a3 100644
--- a/backends/vulkan/TARGETS
+++ b/backends/vulkan/TARGETS
@@ -8,6 +8,7 @@ define_common_targets()
 runtime.python_library(
     name = "vulkan_preprocess",
     srcs = [
+        "serialization/vulkan_graph_builder.py",
         "serialization/vulkan_graph_schema.py",
         "serialization/vulkan_graph_serialize.py",
         "vulkan_preprocess.py",
diff --git a/backends/vulkan/partitioner/TARGETS b/backends/vulkan/partitioner/TARGETS
new file mode 100644
index 0000000000..afd6f21632
--- /dev/null
+++ b/backends/vulkan/partitioner/TARGETS
@@ -0,0 +1,22 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+oncall("executorch")
+
+runtime.python_library(
+    name = "vulkan_partitioner",
+    srcs = [
+        "vulkan_partitioner.py",
+    ],
+    visibility = [
+        "//executorch/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        "//executorch/backends/vulkan:vulkan_preprocess",
+        "//executorch/exir:delegate",
+        "//executorch/exir:lib",
+        "//executorch/exir/backend:partitioner",
+        "//executorch/exir/backend:utils",
+        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
+    ],
+)
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
new file mode 100644
index 0000000000..4eb0ff47e2
--- /dev/null
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import final, List, Optional
+
+import torch
+from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.backend.partitioner import (
+    DelegationSpec,
+    Partitioner,
+    PartitionResult,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export.exported_program import ExportedProgram
+from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
+
+from torch.fx.passes.operator_support import OperatorSupportBase
+
+
+class VulkanSupportedOperators(OperatorSupportBase):
+    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        supported = node.op == "call_function" and node.target in [
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.div.Tensor,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.sub.Tensor,
+            exir_ops.edge.aten.pow.Tensor_Tensor,
+            exir_ops.edge.aten.floor_divide.default,
+        ]
+        return supported
+
+
+@final
+class VulkanPartitioner(Partitioner):
+    def __init__(self, compile_spec: Optional[List[CompileSpec]] = None) -> None:
+        if compile_spec is None:
+            compile_spec = []
+        self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)
+
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        # Run the CapabilityBasedPartitioner to find the largest possible
+        # partitions of supported nodes, then tag each partition's nodes.
+        partition_tags = {}
+
+        capability_partitioner = CapabilityBasedPartitioner(
+            exported_program.graph_module,
+            VulkanSupportedOperators(),
+            allows_single_node_partition=True,
+        )
+        partition_list = capability_partitioner.propose_partitions()
+        for partition in partition_list:
+            for node in partition.nodes:
+                tag = f"tag{partition.id}"
+                node.meta["delegation_tag"] = tag
+                partition_tags[tag] = self.delegation_spec
+
+        return PartitionResult(
+            tagged_exported_program=exported_program, partition_tags=partition_tags
+        )
diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py
new file mode 100644
index 0000000000..68e54c2bc3
--- /dev/null
+++ b/backends/vulkan/serialization/vulkan_graph_builder.py
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
+
+import torch
+
+from executorch.exir.tensor import TensorSpec
+from torch._export.utils import get_buffer, get_param, is_buffer, is_param
+from torch.export import ExportedProgram
+from torch.fx import Node
+
+
+class VkGraphBuilder:
+    def __init__(self, program: ExportedProgram) -> None:
+        self.program = program
+
+        self.chain = []
+        self.values = []
+        self.input_ids = []
+        self.output_ids = []
+        self.const_tensors = []
+
+        # Mapping from torch.fx.Node to VkValue id
+        self.node_to_value_ids = {}
+
+    @staticmethod
+    def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
+        if torch_dtype == torch.float32:
+            return vk_graph_schema.VkDataType.fp32
+        else:
+            raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
+
+    def is_constant(self, node: torch.fx.Node):
+        return (
+            node.name in self.program.graph_signature.inputs_to_lifted_tensor_constants
+        )
+
+    def is_get_attr_node(self, node: torch.fx.Node) -> bool:
+        """
+        Returns true if the given node is a get_attr node for a tensor of the model.
+        """
+        return isinstance(node, torch.fx.Node) and node.op == "get_attr"
+
+    def is_param_node(self, node: torch.fx.Node) -> bool:
+        """
+        Check whether the given node is a parameter within the exported program.
+        """
+        return (
+            self.is_get_attr_node(node)
+            or is_param(self.program, node)
+            or is_buffer(self.program, node)
+            or self.is_constant(node)
+        )
+
+    def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]:
+        """
+        Returns the constant associated with the given node in the exported program.
+        Returns None if the node is not a constant within the exported program.
+        """
+        if self.is_constant(node):
+            constant_name = (
+                self.program.graph_signature.inputs_to_lifted_tensor_constants[
+                    node.name
+                ]
+            )
+            if constant_name in self.program.constants:
+                return self.program.constants[constant_name]
+            else:
+                return None
+
+        return None
+
+    def get_param_tensor(self, node: torch.fx.Node) -> torch.Tensor:
+        tensor = None
+        if node is None:
+            raise RuntimeError("node is None")
+        elif is_param(self.program, node):
+            tensor = get_param(self.program, node)
+        elif is_buffer(self.program, node):
+            tensor = get_buffer(self.program, node)
+        elif self.is_constant(node):
+            tensor = self.get_constant(node)
+        elif self.is_get_attr_node(node):
+            # This is a hack to support both lifted and unlifted graphs
+            try:
+                tensor = getattr(node.graph.owning_module, node.target)
+            except AttributeError:
+                tensor = getattr(self.program.graph_module, node.target)
+        else:
+            raise RuntimeError(f"unsupported param type, {node.op}.")
+
+        assert tensor is not None
+        return tensor
+
+    def maybe_add_constant_tensor(self, node: Node) -> int:
+        const_buffer_idx = -1
+        if self.is_param_node(node):
+            const_buffer_idx = len(self.const_tensors)
+            self.const_tensors.append(self.get_param_tensor(node))
+
+        return const_buffer_idx
+
+    def create_single_vk_value(self, node: Node) -> int:
+        constant_id = self.maybe_add_constant_tensor(node)
+
+        spec = node.meta.get("spec")
+        assert isinstance(spec, TensorSpec)
+        new_id = len(self.values)
+        if node not in self.node_to_value_ids:
+            self.node_to_value_ids[node] = new_id
+        else:
+            current_ids = self.node_to_value_ids[node]
+            if isinstance(current_ids, int):
+                current_ids = [current_ids, new_id]
+            else:
+                current_ids.append(new_id)
+
+        # Negative id indicates that this tensor will have its own dedicated memory.
+        mem_obj_id = -1
+        if spec.mem_obj_id is not None:
+            mem_obj_id = spec.mem_obj_id
+
+        self.values.append(
+            vk_graph_schema.VkValue(
+                value=vk_graph_schema.VkTensor(
+                    datatype=self.get_vk_datatype(spec.dtype),
+                    dims=spec.shape,
+                    constant_id=constant_id,
+                    mem_obj_id=mem_obj_id,
+                )
+            )
+        )
+        return new_id
+
+    def create_vk_values_for(self, node: Node):
+        spec = node.meta.get("spec")
+        if isinstance(spec, TensorSpec):
+            return self.create_single_vk_value(node)
+        else:
+            raise RuntimeError(
+                "Creating values for nodes with collection types is not supported yet."
+            )
+
+    def process_placeholder_node(self, node: Node) -> None:
+        ids = self.create_vk_values_for(node)
+        if not self.is_param_node(node):
+            if isinstance(ids, int):
+                self.input_ids.append(ids)
+            else:
+                self.input_ids += ids
+
+    def process_call_function_node(self, node) -> None:
+        args = []
+        # Add input nodes
+        for inp_node in node.all_input_nodes:
+            if inp_node not in self.node_to_value_ids:
+                raise AssertionError(
+                    "Cannot find input to current node in node_to_value_ids. This means "
+                    "this node is being serialized before its input which is not allowed."
+                )
+            args.append(self.node_to_value_ids[inp_node])
+        # Add output node
+        args.append(self.create_vk_values_for(node))
+
+        self.chain.append(
+            vk_graph_schema.OperatorCall(
+                name=node.target.__name__,
+                args=args,
+            ),
+        )
+
+    def process_getattr_node(self, node: Node) -> None:
+        self.create_vk_values_for(node)
+
+    def process_output_node(self, node: Node) -> None:
+        if node.all_input_nodes[0] not in self.node_to_value_ids:
+            raise AssertionError(
+                "Cannot find input to output node in node_to_value_ids. This means the "
+                "output node is being serialized before its corresponding internal node "
+                "which is not allowed."
+            )
+        self.output_ids.append(self.node_to_value_ids[node.all_input_nodes[0]])
+
+    def process_node(self, node: Node) -> None:
+        if node.op == "placeholder":
+            self.process_placeholder_node(node)
+        elif node.op == "call_function":
+            self.process_call_function_node(node)
+        elif node.op == "get_attr":
+            self.process_getattr_node(node)
+        elif node.op == "output":
+            self.process_output_node(node)
+        else:
+            raise AssertionError(f"Unsupported node op: {node.op}")
+
+    def build_graph(self) -> vk_graph_schema.VkGraph:
+        for node in self.program.graph_module.graph.nodes:
+            self.process_node(node)
+
+        return vk_graph_schema.VkGraph(
+            version="0",
+            chain=self.chain,
+            values=self.values,
+            input_ids=self.input_ids,
+            output_ids=self.output_ids,
+            constants=[],
+            shaders=[],
+        )
diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS
index e71b6a0ec9..608b3137c3 100644
--- a/backends/vulkan/test/TARGETS
+++ b/backends/vulkan/test/TARGETS
@@ -15,8 +15,8 @@ python_unittest(
     deps = [
         "//caffe2:torch",
         "//executorch/backends/vulkan:vulkan_preprocess",
+        "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
         "//executorch/exir:lib",
-        "//executorch/exir/backend:backend_api",
         "//executorch/extension/pybindings:portable_lib",  # @manual
         "//executorch/extension/pytree:pylib",
         "//executorch/kernels/portable:custom_ops_generated_lib",
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 42bd5ec5b8..6f9b4ee16c 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -8,14 +8,13 @@ import unittest
-
 from typing import Tuple
 
-import executorch.exir as exir
 import torch
 
-# import the vulkan backend implementation
+from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
 from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
 
-from executorch.exir import ExecutorchProgram
-from executorch.exir.backend.backend_api import to_backend
+from executorch.exir import EdgeProgramManager, to_edge
+from torch.export import export, ExportedProgram
 
 ctypes.CDLL("libvulkan.so.1")
 
@@ -51,7 +50,7 @@ def assert_outputs_equal(self, model_output, ref_output, atol=1e-03, rtol=1e-03)
 
     def lower_module_and_test_output(
         self,
-        module: torch.nn.Module,
+        model: torch.nn.Module,
         sample_inputs: Tuple[torch.Tensor],
         atol=1e-03,
         rtol=1e-01,
@@ -61,36 +60,23 @@ def lower_module_and_test_output(
         the given sample inputs. It then runs the lowered module and compares its
         outputs with the outputs of the eager module.
         """
-        edgeir_m = exir.capture(module, sample_inputs, exir.CaptureConfig()).to_edge()
-        lowered_module = to_backend("VulkanBackend", edgeir_m.exported_program, [])
+        program: ExportedProgram = export(model, sample_inputs)
+        edge_program: EdgeProgramManager = to_edge(program)
+        edge_program = edge_program.to_backend(VulkanPartitioner())
 
-        class WrappedModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.one_module = lowered_module
-
-            def forward(self, *args):
-                return self.one_module(*args)
+        executorch_program = edge_program.to_executorch()
 
-        executorch_program: ExecutorchProgram = (
-            exir.capture(WrappedModule(), sample_inputs, exir.CaptureConfig())
-            .to_edge()
-            .to_executorch()
-        )
-
-        # Assert the backend name is vulkan
         self.assertEqual(
-            executorch_program.program.execution_plan[0].delegates[0].id,
+            executorch_program.executorch_program.execution_plan[0].delegates[0].id,
             VulkanBackend.__name__,
         )
 
-        # Test the model with executor
         executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer)
         # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
         inputs_flattened, _ = tree_flatten(sample_inputs)
 
         model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
-        ref_output = module(*sample_inputs)
+        ref_output = model(*sample_inputs)
 
         self.assert_outputs_equal(model_output, ref_output, atol=atol, rtol=rtol)
 
@@ -192,26 +178,6 @@ def forward(self, x, y):
 
         self.lower_module_and_test_output(div_module, model_inputs)
 
-    def test_vulkan_backend_floor_div(self):
-        class FloorDivModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x, y):
-                z = x // y
-                return z
-
-        floor_div_module = FloorDivModule()
-        model_inputs = (
-            torch.rand(size=(2, 3), dtype=torch.float32) * 10.0,
-            torch.rand(size=(2, 3), dtype=torch.float32) + 1.0,
-        )
-
-        # absolute tolerance is 1 because of flooring
-        self.lower_module_and_test_output(
-            floor_div_module, model_inputs, atol=1.0 + 1e-03
-        )
-
     def test_vulkan_backend_arithmetic(self):
         class ArithmeticModule(torch.nn.Module):
             def __init__(self):
@@ -249,3 +215,23 @@ def forward(self, x, y):
         )
 
         self.lower_module_and_test_output(pow_module, model_inputs)
+
+    def test_vulkan_backend_partial(self):
+        class SimpleModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 10)
+                self.offset_1 = torch.rand(
+                    size=(2, 10), dtype=torch.float32
+                )
+                self.offset_2 = torch.rand(
+                    size=(2, 10), dtype=torch.float32
+                )
+
+            def forward(self, x):
+                return self.linear(x + self.offset_1) - self.offset_2
+
+        model = SimpleModel()
+        model_inputs = (torch.rand(size=(2, 10), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(model, model_inputs)
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index e3667e9afe..293d114e8d 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -7,6 +7,8 @@
 from typing import final, List
 
 import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
+
+from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
 from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
     serialize_vulkan_graph,
 )
@@ -21,9 +23,7 @@
 from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
 from executorch.exir.program._program import _copy_module
 
-from executorch.exir.tensor import TensorSpec
 from torch import dtype, float32
-from torch.fx import Node
 
 DEFAULT_DEBUG_HANDLE = 65535
 
@@ -44,62 +44,13 @@ def preprocess(  # noqa: C901
         program: ExportedProgram,
         module_compile_spec: List[CompileSpec],
     ) -> PreprocessResult:
-        vk_chain = []
-        vk_values = []
-        vk_input_ids = []
-        vk_output_ids = []
-        const_tensors = []
-
-        # Mapping from graph Node to schema VkValue.
-        node_to_value_ids = {}
-
-        def create_single_vk_value(node: Node, constant_id: int = -1) -> int:
-            spec = node.meta.get("spec")
-            assert isinstance(spec, TensorSpec)
-            new_id = len(vk_values)
-            if node not in node_to_value_ids:
-                node_to_value_ids[node] = new_id
-            else:
-                current_ids = node_to_value_ids[node]
-                if isinstance(current_ids, int):
-                    current_ids = [current_ids, new_id]
-                else:
-                    current_ids.append(new_id)
-
-            # Negative id indicates that this tensor will have its own dedicated memory.
-            mem_obj_id = -1
-            if spec.mem_obj_id is not None:
-                mem_obj_id = spec.mem_obj_id
-
-            vk_values.append(
-                vk_graph_schema.VkValue(
-                    value=vk_graph_schema.VkTensor(
-                        datatype=VulkanBackend.get_vk_datatype(spec.dtype),
-                        dims=spec.shape,
-                        constant_id=constant_id,
-                        mem_obj_id=mem_obj_id,
-                    )
-                )
-            )
-            return new_id
-
-        def create_vk_values_for(node: Node, constant_id: int = -1):
-            spec = node.meta.get("spec")
-
-            if isinstance(spec, TensorSpec):
-                return create_single_vk_value(node, constant_id)
-            else:
-                ids = []
-                for _ in spec:
-                    ids.append(create_single_vk_value(node, constant_id))
-                return ids
-
         passes = [
             SpecPropPass(),
             MemoryPlanningPass("greedy"),
         ]
 
         new_gm = program.graph_module
+
         for p in passes:
             # This is a workaround to allow the memory planning pass to work without
             # having to first apply ToOutVarPass(). See the `greedy()` function in
@@ -110,62 +61,14 @@ def create_vk_values_for(node: Node, constant_id: int = -1):
             new_gm_res = p(new_gm)
             assert new_gm_res is not None
             new_gm = new_gm_res.graph_module
-            _copy_module(program.graph_module, new_gm)
 
-        for node in program.graph_module.graph.nodes:
-            if node.op == "placeholder":
-                # Input
-                ids = create_vk_values_for(node)
-                if isinstance(ids, int):
-                    vk_input_ids.append(ids)
-                else:
-                    vk_input_ids += ids
-            elif node.op == "call_function":
-                # Op
-                if (
-                    node.all_input_nodes[0] not in node_to_value_ids
-                    or node.all_input_nodes[1] not in node_to_value_ids
-                ):
-                    raise AssertionError(
-                        "Cannot find input(s) for current node in node_to_value_ids. This means this node is being serialized before its input(s) which is not allowed."
-                    )
-                vk_chain.append(
-                    vk_graph_schema.OperatorCall(
-                        name=node.target.__name__,
-                        args=[
-                            node_to_value_ids[node.all_input_nodes[0]],
-                            node_to_value_ids[node.all_input_nodes[1]],
-                            create_vk_values_for(node),
-                        ],
-                    ),
-                )
-            elif node.op == "get_attr":
-                constant_id = len(const_tensors)
-                const_tensors.append(
-                    getattr(node.graph.owning_module, node.target).contiguous()
-                )
-
-                create_vk_values_for(node, constant_id)
+        _copy_module(program.graph_module, new_gm)
 
-            elif node.op == "output":
-                if node.all_input_nodes[0] not in node_to_value_ids:
-                    raise AssertionError(
-                        "Cannot find input to output node in node_to_value_ids. This means the output node is being serialized before its corresponding internal node which is not allowed."
-                    )
-                vk_output_ids.append(node_to_value_ids[node.all_input_nodes[0]])
-            else:
-                raise RuntimeError(f"Unsupported op, {node.op}, in Vulkan Preprocess")
+        graph_builder = VkGraphBuilder(program)
+        vk_graph = graph_builder.build_graph()
 
-        # Raw objects (constants and shaders) are populated in the next line's method.
-        vk_graph = vk_graph_schema.VkGraph(
-            version="0",
-            chain=vk_chain,
-            values=vk_values,
-            input_ids=vk_input_ids,
-            output_ids=vk_output_ids,
-            constants=[],
-            shaders=[],
-        )
         return PreprocessResult(
-            processed_bytes=serialize_vulkan_graph(vk_graph, const_tensors, []),
+            processed_bytes=serialize_vulkan_graph(
+                vk_graph, graph_builder.const_tensors, []
+            ),
         )
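
A minimal end-to-end sketch of the partitioner-based lowering flow this patch
enables, mirroring the updated lower_module_and_test_output helper (the
AddModule module and input shapes below are hypothetical):

    import torch
    from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
        VulkanPartitioner,
    )
    from executorch.exir import to_edge
    from torch.export import export

    class AddModule(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    sample_inputs = (torch.rand(2, 3), torch.rand(2, 3))

    # Export the eager module, convert it to the Edge dialect, then let the
    # partitioner tag and delegate the supported ops (add, sub, mul, div,
    # pow, floor_divide) to the Vulkan backend.
    program = export(AddModule(), sample_inputs)
    edge_program = to_edge(program)
    edge_program = edge_program.to_backend(VulkanPartitioner())

    # Serialize to an ExecuTorch program; the flatbuffer is exposed as .buffer.
    executorch_program = edge_program.to_executorch()
    buffer = executorch_program.buffer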