Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] main from pytorch:main #11

Merged
merged 3 commits into main from pytorch:main
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ if(EXECUTORCH_BUILD_PYBIND)
flatcc
portable_ops_lib
util
torch
${PYBIND_LINK_COREML}
${PYBIND_LINK_MPS}
${PYBIND_LINK_XNNPACK}
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ define_common_targets()
runtime.python_library(
name = "vulkan_preprocess",
srcs = [
"serialization/vulkan_graph_builder.py",
"serialization/vulkan_graph_schema.py",
"serialization/vulkan_graph_serialize.py",
"vulkan_preprocess.py",
Expand Down
22 changes: 22 additions & 0 deletions backends/vulkan/partitioner/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

# Python library target for the Vulkan delegate's partitioner
# (vulkan_partitioner.py), which tags graph regions for Vulkan delegation.
runtime.python_library(
    name = "vulkan_partitioner",
    srcs = [
        "vulkan_partitioner.py",
    ],
    # Visible to the whole executorch tree and to external executorch clients.
    visibility = [
        "//executorch/...",
        "@EXECUTORCH_CLIENTS",
    ],
    deps = [
        "//executorch/backends/vulkan:vulkan_preprocess",
        "//executorch/exir:delegate",
        "//executorch/exir:lib",
        "//executorch/exir/backend:partitioner",
        "//executorch/exir/backend:utils",
        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
    ],
)
63 changes: 63 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import final, List, Optional

import torch
from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

from torch.fx.passes.operator_support import OperatorSupportBase


class VulkanSupportedOperators(OperatorSupportBase):
    """Operator-support filter used by the Vulkan partitioner.

    A node is supported when it is a ``call_function`` whose target is one
    of the ops listed in ``_SUPPORTED_TARGETS``.
    """

    # Edge-dialect ops the Vulkan backend currently handles.
    _SUPPORTED_TARGETS = (
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.div.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.pow.Tensor_Tensor,
        exir_ops.edge.aten.floor_divide.default,
    )

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # Only function-call nodes can be delegated.
        if node.op != "call_function":
            return False
        return node.target in self._SUPPORTED_TARGETS


@final
class VulkanPartitioner(Partitioner):
    """Partitioner that tags Vulkan-supported subgraphs for delegation.

    Uses torch.fx's CapabilityBasedPartitioner with VulkanSupportedOperators
    to find the largest delegatable regions, then marks every node in each
    region with a ``delegation_tag``.
    """

    def __init__(self, compile_spec: Optional[List[CompileSpec]] = None) -> None:
        # Default to an empty compile spec when none is provided.
        specs = [] if compile_spec is None else compile_spec
        self.delegation_spec = DelegationSpec(VulkanBackend.__name__, specs)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag nodes belonging to Vulkan-capable partitions.

        Returns a PartitionResult carrying the (mutated) exported program
        and a mapping from tag name to this partitioner's delegation spec.
        """
        partition_tags = {}

        # Propose the largest possible subgraphs of supported nodes
        # (single-node partitions are allowed).
        proposer = CapabilityBasedPartitioner(
            exported_program.graph_module,
            VulkanSupportedOperators(),
            allows_single_node_partition=True,
        )
        for proposed in proposer.propose_partitions():
            tag = f"tag{proposed.id}"
            for fx_node in proposed.nodes:
                fx_node.meta["delegation_tag"] = tag
            partition_tags[tag] = self.delegation_spec

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
214 changes: 214 additions & 0 deletions backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

import torch

from executorch.exir.tensor import TensorSpec
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
from torch.export import ExportedProgram
from torch.fx import Node


class VkGraphBuilder:
    """Translates an ExportedProgram's FX graph into a serializable VkGraph.

    Nodes are processed in the order they appear in the FX graph; each tensor
    value gets an entry in ``self.values`` whose list index serves as its
    VkValue id.
    """

    def __init__(self, program: ExportedProgram) -> None:
        self.program = program

        # Serialized operator calls, in processing order.
        self.chain = []
        # All VkValues created so far; a value's id is its index in this list.
        self.values = []
        # VkValue ids of the graph's (non-parameter) inputs.
        self.input_ids = []
        # VkValue ids of the graph's outputs.
        self.output_ids = []
        # Constant tensors (params/buffers/lifted constants); a tensor's
        # constant_id is its index in this list.
        self.const_tensors = []

        # Mapping from torch.fx.Node to VkValue id: an int for a node with a
        # single value, or a list of ids if several values were created for it.
        self.node_to_value_ids = {}

    @staticmethod
    def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
        """Map a torch dtype to its VkDataType; only fp32 is supported."""
        if torch_dtype == torch.float32:
            return vk_graph_schema.VkDataType.fp32
        else:
            raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

    def is_constant(self, node: torch.fx.Node) -> bool:
        """Return True if the node is a lifted tensor constant of the program."""
        return (
            node.name in self.program.graph_signature.inputs_to_lifted_tensor_constants
        )

    def is_get_attr_node(self, node: torch.fx.Node) -> bool:
        """
        Returns true if the given node is a get attr node for a tensor of the model
        """
        return isinstance(node, torch.fx.Node) and node.op == "get_attr"

    def is_param_node(self, node: torch.fx.Node) -> bool:
        """
        Check if the given node is a parameter within the exported program
        """
        return (
            self.is_get_attr_node(node)
            or is_param(self.program, node)
            or is_buffer(self.program, node)
            or self.is_constant(node)
        )

    def get_constant(self, node: torch.fx.Node) -> Optional[torch.Tensor]:
        """
        Returns the constant associated with the given node in the exported program.
        Returns None if the node is not a constant within the exported program
        """
        if self.is_constant(node):
            constant_name = (
                self.program.graph_signature.inputs_to_lifted_tensor_constants[
                    node.name
                ]
            )
            # The lifted name may be absent from program.constants; treat that
            # as "no constant" rather than raising.
            if constant_name in self.program.constants:
                return self.program.constants[constant_name]

        return None

    def get_param_tensor(self, node: torch.fx.Node) -> torch.Tensor:
        """Resolve the concrete tensor behind a parameter-like node.

        Raises:
            RuntimeError: if node is None or is not a param/buffer/constant/
                get_attr node.
        """
        tensor = None
        if node is None:
            raise RuntimeError("node is None")
        elif is_param(self.program, node):
            tensor = get_param(self.program, node)
        elif is_buffer(self.program, node):
            tensor = get_buffer(self.program, node)
        elif self.is_constant(node):
            tensor = self.get_constant(node)
        elif self.is_get_attr_node(node):
            # This is a hack to support both lifted and unlifted graph
            try:
                tensor = getattr(node.graph.owning_module, node.target)
            except AttributeError:
                tensor = getattr(self.program.graph_module, node.target)
        else:
            raise RuntimeError(f"unsupported param type, {node.op}.")

        assert tensor is not None
        return tensor

    def maybe_add_constant_tensor(self, node: Node) -> int:
        """Register the node's tensor in const_tensors if it is parameter-like.

        Returns the new constant's index, or -1 when the node is not constant.
        """
        const_buffer_idx = -1
        if self.is_param_node(node):
            const_buffer_idx = len(self.const_tensors)
            self.const_tensors.append(self.get_param_tensor(node))

        return const_buffer_idx

    def create_single_vk_value(self, node: Node) -> int:
        """Create one VkValue for the node's tensor spec and return its id."""
        constant_id = self.maybe_add_constant_tensor(node)

        spec = node.meta.get("spec")
        assert isinstance(spec, TensorSpec)
        new_id = len(self.values)
        if node not in self.node_to_value_ids:
            self.node_to_value_ids[node] = new_id
        else:
            current_ids = self.node_to_value_ids[node]
            if isinstance(current_ids, int):
                current_ids = [current_ids, new_id]
            else:
                current_ids.append(new_id)
            # Bug fix: store the updated id(s) back into the mapping. The
            # original code built the [old_id, new_id] list but never wrote it
            # back when the existing entry was an int, losing the new id.
            self.node_to_value_ids[node] = current_ids

        # Negative id indicates that this tensor will have its own dedicated memory.
        mem_obj_id = -1
        if spec.mem_obj_id is not None:
            mem_obj_id = spec.mem_obj_id

        self.values.append(
            vk_graph_schema.VkValue(
                value=vk_graph_schema.VkTensor(
                    datatype=self.get_vk_datatype(spec.dtype),
                    dims=spec.shape,
                    constant_id=constant_id,
                    mem_obj_id=mem_obj_id,
                )
            )
        )
        return new_id

    def create_vk_values_for(self, node: Node):
        """Create VkValue(s) for a node; only single-tensor nodes are supported."""
        spec = node.meta.get("spec")
        if isinstance(spec, TensorSpec):
            return self.create_single_vk_value(node)
        else:
            raise RuntimeError(
                "Creating values for nodes with collection types is not supported yet."
            )

    def process_placeholder_node(self, node: Node) -> None:
        """Create values for a placeholder and record graph-input ids.

        Placeholders that are lifted params/buffers/constants are not treated
        as graph inputs.
        """
        ids = self.create_vk_values_for(node)
        if not self.is_param_node(node):
            if isinstance(ids, int):
                self.input_ids.append(ids)
            else:
                self.input_ids += ids

    def process_call_function_node(self, node: Node) -> None:
        """Serialize an operator call: input value ids followed by output id(s)."""
        args = []
        # Add input nodes
        for inp_node in node.all_input_nodes:
            if inp_node not in self.node_to_value_ids:
                raise AssertionError(
                    "Cannot find input to current node in node_to_value_ids. This means "
                    "this node is being serialized before its input which is not allowed."
                )
            args.append(self.node_to_value_ids[inp_node])
        # Add output node
        args.append(self.create_vk_values_for(node))

        self.chain.append(
            vk_graph_schema.OperatorCall(
                name=node.target.__name__,
                args=args,
            ),
        )

    def process_getattr_node(self, node: Node) -> None:
        """Create values for a get_attr node (registers its tensor as constant)."""
        self.create_vk_values_for(node)

    def process_output_node(self, node: Node) -> None:
        """Record the value id(s) of the output node's first input."""
        if node.all_input_nodes[0] not in self.node_to_value_ids:
            raise AssertionError(
                "Cannot find input to output node in node_to_value_ids. This means the "
                "output node is being serialized before its corresponding internal node "
                "which is not allowed."
            )
        self.output_ids.append(self.node_to_value_ids[node.all_input_nodes[0]])

    def process_node(self, node: Node) -> None:
        """Dispatch a node to the handler matching its op kind."""
        if node.op == "placeholder":
            self.process_placeholder_node(node)
        elif node.op == "call_function":
            self.process_call_function_node(node)
        elif node.op == "get_attr":
            self.process_getattr_node(node)
        elif node.op == "output":
            self.process_output_node(node)
        else:
            raise AssertionError(f"Unsupported node op: {node.op}")

    def build_graph(self) -> vk_graph_schema.VkGraph:
        """Process every node of the program's graph and return the VkGraph."""
        for node in self.program.graph_module.graph.nodes:
            self.process_node(node)

        return vk_graph_schema.VkGraph(
            version="0",
            chain=self.chain,
            values=self.values,
            input_ids=self.input_ids,
            output_ids=self.output_ids,
            constants=[],
            shaders=[],
        )
2 changes: 1 addition & 1 deletion backends/vulkan/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ python_unittest(
deps = [
"//caffe2:torch",
"//executorch/backends/vulkan:vulkan_preprocess",
"//executorch/backends/vulkan/partitioner:vulkan_partitioner",
"//executorch/exir:lib",
"//executorch/exir/backend:backend_api",
"//executorch/extension/pybindings:portable_lib", # @manual
"//executorch/extension/pytree:pylib",
"//executorch/kernels/portable:custom_ops_generated_lib",
Expand Down
Loading
Loading