diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index cf50f170cf..ed3d847933 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -16,6 +16,20 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "int4_weight_only_quantizer",
+    srcs = [
+        "int4_weight_only_quantizer.py",
+    ],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//executorch/backends/vulkan:custom_ops_lib",
+        "//pytorch/ao:torchao",
+    ]
+)
+
 runtime.python_library(
     name = "remove_local_scalar_dense",
     srcs = ["remove_local_scalar_dense_ops.py"],
@@ -30,17 +44,18 @@ runtime.python_library(
 )
 
 runtime.python_library(
-    name = "int4_weight_only_quantizer",
-    srcs = [
-        "int4_weight_only_quantizer.py",
-    ],
+    name = "tag_memory_meta_pass",
+    srcs = ["tag_memory_meta_pass.py"],
     visibility = [
         "//executorch/backends/...",
     ],
     deps = [
-        "//executorch/backends/vulkan:custom_ops_lib",
-        "//pytorch/ao:torchao",
-    ]
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+        "//executorch/backends/vulkan:utils_lib",
+        "//executorch/backends/vulkan/serialization:lib",
+    ],
 )
 
 runtime.python_library(
@@ -56,5 +71,6 @@ runtime.python_library(
         ":insert_prepack_nodes",
         ":int4_weight_only_quantizer",
         ":remove_local_scalar_dense",
+        ":tag_memory_meta_pass"
     ]
 )
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index cfdb7c6eee..8823553ab1 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -5,9 +5,11 @@ from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
     RemoveLocalScalarDenseOpsTransform,
 )
+from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 __all__ = [
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
     "RemoveLocalScalarDenseOpsTransform",
+    "TagMemoryMetaPass",
 ]
diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
new file mode 100644
index 0000000000..fd0bd3648e
--- /dev/null
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -0,0 +1,236 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from copy import deepcopy
+from typing import Set
+
+import executorch.backends.vulkan.utils as utils
+
+import torch
+
+from executorch.backends.vulkan.op_registry import get_op_features, has_impl
+
+from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
+    VkMemoryLayout,
+    VkStorageType,
+)
+
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from torch._subclasses.fake_tensor import FakeTensor
+
+from torch.fx.passes.tools_common import NodeList
+from torch.fx.passes.utils.fuser_utils import topo_sort
+
+logger: logging.Logger = logging.getLogger("")
+logger.setLevel(logging.INFO)
+
+
+def set_memory_metadata(
+    node: torch.fx.Node, storage: VkStorageType, layout: VkMemoryLayout
+) -> None:
+    utils.set_node_spec_attr(node, "vk_storage_type", storage)
+    utils.set_node_spec_attr(node, "vk_memory_layout", layout)
+
+
+class TagMemoryMetaPass(ExportPass):
+    """
+    There are a variety of ways that tensors can be represented in Vulkan. The two main
+    descriptors for how a tensor is laid out in memory are:
+
+    1. Storage Type (buffer or texture)
+    2. Memory Layout (which dim is packed along a texel / has a stride of 1, etc.)
+
+    Due to the differences between buffers and textures, and the differences between
+    different memory layouts, an implementation for an operator may only support a
+    specific set of (storage type, memory layout) combinations.
+
+    Furthermore, if an operator implementation supports multiple (storage type, memory
+    layout) combinations, there may be a "preferred" setting which results in optimal
+    performance.
+
+    This pass is responsible for ensuring that all tensors participating in an operator
+    call have a valid/optimal (storage type, memory layout) setting, and for inserting
+    transition operators to transfer input tensors to the correct memory settings when
+    necessary.
+    """
+
+    def __init__(
+        self,
+        texture_limits: utils.ImageExtents,
+        default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D,
+        default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED,
+    ):
+        super().__init__()
+        self.default_storage: VkStorageType = default_storage_type
+        self.default_layout: VkMemoryLayout = default_memory_layout
+        self.texture_limits = texture_limits
+
+    def propose_node_storage(
+        self,
+        node: torch.fx.Node,
+    ) -> VkStorageType:
+        """
+        Uses the operator registry to determine the storage type that should be used
+        for a given node. The storage type is determined with the following priorities:
+        1. In some cases, a tensor involved in the computation may be too large to be
+           represented as a texture. If this is the case, the node is "opinionated" and
+           buffer representation must be used.
+        2. If the operator called by the node indicates an optimal storage type, or
+           only supports a single storage type, use that storage type. If either is
+           true, then the node is considered to be opinionated as well. If multiple
+           storage types are supported and no preferred storage type is indicated, then
+           the node is not opinionated; go to the next step.
+        3. If the node's arguments already have memory metadata annotations, then
+           preserve the settings of the first argument. Otherwise, proceed to the next
+           step.
+        4. Recursively search the node's users to see if any subsequent uses are
+           opinionated; inherit the settings of the first opinionated node. If no
+           opinionated user can be found, then proceed to the last step.
+        5. Use the default storage type setting.
+        """
+        # The node may have an input/output tensor that is too big to be stored in a
+        # texture. In this case, buffer storage must be used. Note that the partitioner
+        # has already checked for the fact that buffer storage is supported by the
+        # operator.
+        if len(utils.possible_node_memory_layouts(node, self.texture_limits)) == 0:
+            return VkStorageType.BUFFER
+
+        valid_storage_types: Set[VkStorageType] = utils.all_storage_types
+
+        # pyre-ignore
+        if has_impl(node.target):
+            # pyre-ignore
+            features = get_op_features(node.target)
+            valid_storage_types = features.supported_storage_types()
+            storage = features.propose_storage_type()
+            if storage is not None:
+                return storage
+
+        for arg in node.args:
+            if isinstance(arg, torch.fx.Node) and isinstance(
+                arg.meta["val"], FakeTensor
+            ):
+                storage = utils.get_node_storage_type(arg)
+                if storage is not None and storage in valid_storage_types:
+                    return storage
+
+        # If no storage type has been resolved yet, assume the optimal storage type of
+        # the first opinionated user. This search is recursive.
+        for user in node.users:
+            optimal_storage = self.propose_node_storage(user)
+            if optimal_storage is not None:
+                return optimal_storage
+
+        if self.default_storage in valid_storage_types:
+            return self.default_storage
+        else:
+            return next(iter(valid_storage_types))
+
+    def propose_node_layout(
+        self,
+        node: torch.fx.Node,
+        storage: VkStorageType,
+    ) -> VkMemoryLayout:
+        """
+        Performs the same steps as propose_node_storage, but detects the memory layout
+        that should be used for the specific storage type. The same prioritization
+        logic is applied.
+        """
+        valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts
+        # pyre-ignore
+        if has_impl(node.target):
+            # pyre-ignore
+            features = get_op_features(node.target)
+            valid_layouts = features.supported_memory_layouts(storage)
+            layout = features.propose_memory_layout(storage)
+            if layout is not None:
+                return layout
+
+        for arg in node.args:
+            if isinstance(arg, torch.fx.Node) and isinstance(
+                arg.meta["val"], FakeTensor
+            ):
+                layout = utils.get_node_memory_layout(arg)
+                if layout is not None and layout in valid_layouts:
+                    return layout
+
+        # If no memory layout has been resolved yet, assume the optimal layout of the
+        # first opinionated user. This search is recursive.
+        for user in node.users:
+            optimal_layout = self.propose_node_layout(user, storage)
+            if optimal_layout is not None:
+                return optimal_layout
+
+        # As a last resort, return the default memory layout that should be used.
+        if self.default_layout in valid_layouts:
+            return self.default_layout
+        else:
+            return next(iter(valid_layouts))
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))
+
+        for node in sorted_nodes:
+            if not isinstance(node.meta["val"], FakeTensor):
+                continue
+
+            if node.target == exir_ops.edge.et_vk.prepack.default:
+                continue
+
+            storage = self.propose_node_storage(node)
+            layout = self.propose_node_layout(node, storage)
+
+            set_memory_metadata(node, storage, layout)
+
+            inserting_transitions_for_node = False
+            for i, arg in enumerate(node.args):
+                if not isinstance(arg, torch.fx.Node):
+                    continue
+                if not isinstance(arg.meta["val"], FakeTensor):
+                    continue
+
+                arg_storage = utils.get_node_storage_type(arg)
+                arg_layout = utils.get_node_memory_layout(arg)
+
+                if arg_storage is None:
+                    utils.set_node_spec_attr(arg, "vk_storage_type", storage)
+                    arg_storage = storage
+                if arg_layout is None:
+                    utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
+                    arg_layout = layout
+
+                if arg_storage == storage and arg_layout == layout:
+                    continue
+
+                if not inserting_transitions_for_node:
+                    inserting_transitions_for_node = True
+                    logger.info(
+                        f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
+                    )
+
+                logger.info(
+                    f"    arg {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
+                )
+
+                # Insert a clone node to copy the original tensor to a tensor with the
+                # desired storage type and memory layout.
+                with graph_module.graph.inserting_before(node):
+                    clone_node = graph_module.graph.create_node(
+                        "call_function",
+                        exir_ops.edge.aten.clone.default,
+                        (arg,),
+                    )
+                    clone_node.meta["val"] = arg.meta["val"]
+                    clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
+                    clone_node.meta["spec"].const = False
+                    set_memory_metadata(clone_node, storage, layout)
+                    arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
+
+        return PassResult(graph_module, True)
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index c851eeb4da..f1fd47fb2b 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -94,9 +94,11 @@ def op_node_is_compatible(
     # If there are no valid texture memory layouts, then buffer storage must be
     # supported by the operator implementation.
     if len(valid_texture_layouts) == 0:
-        # TODO: once memory metadata tagging pass is implemented, check that the
-        # op impl supports buffers instead
-        return False, "requires buffer representation"
+        compatible = VkStorageType.BUFFER in features.supported_storage_types()
+        reason = "op is compatible"
+        if not compatible:
+            reason = "op requires buffer storage, which is not supported by the op impl"
+        return compatible, reason
 
     op_available_layouts = features.supported_memory_layouts(
         VkStorageType.TEXTURE_3D
diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py
index bc77bc40cf..8144747212 100644
--- a/backends/vulkan/serialization/vulkan_graph_builder.py
+++ b/backends/vulkan/serialization/vulkan_graph_builder.py
@@ -12,6 +12,11 @@
 import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
 
 import torch
+
+from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
+    VkMemoryLayout,
+    VkStorageType,
+)
 from executorch.backends.vulkan.utils import (
     is_constant,
     is_get_attr_node,
@@ -169,6 +174,15 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
         if spec.mem_obj_id is not None:
             mem_obj_id = spec.mem_obj_id
 
+        storage_type = VkStorageType.DEFAULT_STORAGE
+        memory_layout = VkMemoryLayout.DEFAULT_LAYOUT
+        if hasattr(spec, "vk_storage_type"):
+            # pyre-ignore[16]
+            storage_type = spec.vk_storage_type
+        if hasattr(spec, "vk_memory_layout"):
+            # pyre-ignore[16]
+            memory_layout = spec.vk_memory_layout
+
         new_id = len(self.values)
         self.values.append(
             vk_graph_schema.VkValue(
@@ -177,6 +191,8 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
                     dims=spec.shape,
                     constant_id=constant_id,
                     mem_obj_id=mem_obj_id,
+                    storage_type=storage_type,
+                    memory_layout=memory_layout,
                 )
             )
         )
diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py
index 8197f705b5..35113bc623 100644
--- a/backends/vulkan/serialization/vulkan_graph_schema.py
+++ b/backends/vulkan/serialization/vulkan_graph_schema.py
@@ -37,6 +37,9 @@ class VkStorageType(IntEnum):
     TEXTURE_2D = 2
     DEFAULT_STORAGE = 255
 
+    def __str__(self) -> str:
+        return self.name
+
 
 class VkMemoryLayout(IntEnum):
     TENSOR_WIDTH_PACKED = 0
@@ -44,6 +47,9 @@ class VkMemoryLayout(IntEnum):
     TENSOR_CHANNELS_PACKED = 2
     DEFAULT_LAYOUT = 255
 
+    def __str__(self) -> str:
+        return self.name
+
 
 @dataclass
 class VkTensor:
diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl
index 9785b34951..9521bcacdb 100644
--- a/backends/vulkan/targets.bzl
+++ b/backends/vulkan/targets.bzl
@@ -223,6 +223,8 @@ def define_common_targets(is_fbcode = False):
         ],
         deps = [
             "//caffe2:torch",
+            "//executorch/exir:tensor",
+            "//executorch/backends/vulkan/serialization:lib",
         ]
     )
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 4264e94271..2e9fbba01c 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from enum import IntEnum
-from typing import Set, Tuple
+from typing import Optional, Set, Tuple
 
 import torch
@@ -13,6 +13,9 @@
     VkMemoryLayout,
     VkStorageType,
 )
+
+from executorch.exir.tensor import TensorSpec
+
 from torch._export.utils import is_buffer, is_param
 from torch._subclasses.fake_tensor import FakeTensor
@@ -170,3 +173,43 @@ def possible_node_memory_layouts(
     )
 
     return valid_layouts
+
+
+##
+## TensorSpec Utils
+##
+
+
+def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
+    assert "spec" in node.meta
+    spec = node.meta["spec"]
+    if isinstance(spec, TensorSpec):
+        setattr(spec, attr, value)
+    elif isinstance(spec, (list, tuple)):
+        for s in spec:
+            assert isinstance(s, TensorSpec)
+            setattr(s, attr, value)
+    else:
+        raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")
+
+
+def get_node_spec_attr(node: torch.fx.Node, attr: str, return_first: bool = True):
+    assert "spec" in node.meta
+    spec = node.meta["spec"]
+    if isinstance(spec, TensorSpec):
+        return getattr(spec, attr) if hasattr(spec, attr) else None
+    elif isinstance(spec, (list, tuple)):
+        if return_first:
+            return getattr(spec[0], attr) if hasattr(spec[0], attr) else None
+        else:
+            return [getattr(s, attr) if hasattr(s, attr) else None for s in spec]
+    else:
+        raise RuntimeError(f"Cannot get attr for spec of type {type(spec)}")
+
+
+def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]:
+    return get_node_spec_attr(node, "vk_storage_type")
+
+
+def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]:
+    return get_node_spec_attr(node, "vk_memory_layout")
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 96eee198f4..f0a5fd6725 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -6,7 +6,9 @@
 
 # pyre-strict
 
-from typing import final, List
+from typing import Any, Dict, final, List
+
+import executorch.backends.vulkan.utils as utils
 
 from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
 from executorch.backends.transforms.fuse_batch_norm_with_conv import (
@@ -20,9 +22,14 @@
 from executorch.backends.vulkan._passes import (
     insert_prepack_nodes,
     RemoveLocalScalarDenseOpsTransform,
+    TagMemoryMetaPass,
 )
 
 from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
+from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
+    VkMemoryLayout,
+    VkStorageType,
+)
 from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
     serialize_vulkan_graph,
 )
@@ -78,6 +85,24 @@ def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
     return program
 
 
+def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]:
+    options = {}
+    for spec in compile_specs:
+        if spec.key == "storage_type_override":
+            options[spec.key] = VkStorageType(
+                int.from_bytes(spec.value, byteorder="little")
+            )
+        if spec.key == "memory_layout_override":
+            options[spec.key] = VkMemoryLayout(
+                int.from_bytes(spec.value, byteorder="little")
+            )
{"texture_limits_x", "texture_limits_y", "texture_limits_z"}: + options[spec.key] = int.from_bytes(spec.value, byteorder="little") + # Unhandled options are ignored + + return options + + @final class VulkanBackend(BackendDetails): @classmethod @@ -87,6 +112,25 @@ def preprocess( # noqa: C901 program: ExportedProgram, module_compile_spec: List[CompileSpec], ) -> PreprocessResult: + compile_options = parse_compile_spec(module_compile_spec) + limits_x = compile_options.get( + "texture_limits_x", utils.DEFAULT_TEXTURE_LIMITS[0] + ) + limits_y = compile_options.get( + "texture_limits_y", utils.DEFAULT_TEXTURE_LIMITS[1] + ) + limits_z = compile_options.get( + "texture_limits_z", utils.DEFAULT_TEXTURE_LIMITS[2] + ) + texture_limits = (limits_x, limits_y, limits_z) + + default_storage_type = compile_options.get( + "storage_type_override", VkStorageType.TEXTURE_3D + ) + default_memory_layout = compile_options.get( + "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED + ) + program = unsafe_remove_auto_functionalized_pass(program) # First, apply passes that fuse/remove operators to consolidate the graph @@ -122,10 +166,31 @@ def preprocess( # noqa: C901 ], ) + # Optionally apply the memory metadata tagging pass, which will insert storage + # type and memory layout transition nodes to ensure that all tensor arguments + # to an operator is in a supported or optimal configuration. If this pass is not + # applied, there will be a risk that some operators recieve arguments with + # memory settings that are not supported by the implementation. + if not compile_options.get("skip_tag_memory_metadata", False): + program = apply_passes( + program, + [ + TagMemoryMetaPass( + texture_limits, + default_storage_type=default_storage_type, + default_memory_layout=default_memory_layout, + ), + ], + ) + # Finally, apply dynamic shape passes and memory planning pass. These passes # must be applied only when the graph structure is finalized. program = apply_passes( - program, [ConstraintBasedSymShapeEvalPass(), MemoryPlanningPass()] + program, + [ + ConstraintBasedSymShapeEvalPass(), + MemoryPlanningPass(), + ], ) graph_builder = VkGraphBuilder( diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index f3822b6866..23b3589c2a 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -622,7 +622,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 partitioners.append( get_vulkan_partitioner( args.dtype_override, - args.quantization_mode, + args.enable_dynamic_shape, ) ) modelname = f"vulkan_{modelname}" diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index d966de9a25..6f4b95e3d0 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -32,7 +32,7 @@ def get_xnnpack_partitioner(dynamic_quant_only_partitioner: bool = True): def get_vulkan_partitioner( - dtype_override: Optional[str] = None, quantization_mode: Optional[str] = None + dtype_override: Optional[str] = None, enable_dynamic_shape: bool = False ): assert ( dtype_override == "fp32" or dtype_override is None @@ -41,7 +41,7 @@ def get_vulkan_partitioner( VulkanPartitioner, ) - return VulkanPartitioner({"require_dynamic_shapes": True}) + return VulkanPartitioner({"require_dynamic_shapes": enable_dynamic_shape}) def get_mps_partitioner(use_kv_cache: bool = False):