Skip to content

Commit

Permalink
Improve schema names and comments (pytorch#2035)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2035

This change addresses a lot of nitpicks, which improves readability (at least for me):
- Improve schema comments to be sentences and limit line length.
- Use the suffix `_id` and `_ids` for an index and list of indices, respectively.
- Order of tables matches their usage in `table VkGraph`.
- Improve understanding of Python dict contents via name change: `node_vk_value_ids` -> `node_to_value_ids`.

Note we will remove `VkArithmeticOpType` soon, so we don't bother improving its readability.

Reviewed By: SS-JIA

Differential Revision: D53982444

fbshipit-source-id: 7bf345d93d1450bee89713a55470da8cb14a2155
  • Loading branch information
jorgep31415 authored and facebook-github-bot committed Feb 22, 2024
1 parent d98741c commit 99c70f9
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 95 deletions.
16 changes: 8 additions & 8 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ class VulkanBackend final : public PyTorchBackendInterface {
}

api::ScalarType get_scalar_type(
const vkgraph::VkDatatype& vk_datatype) const {
const vkgraph::VkDataType& vk_datatype) const {
switch (vk_datatype) {
case (vkgraph::VkDatatype::vk_datatype_fp32): {
case (vkgraph::VkDataType::fp32): {
return api::kFloat;
}
}
Expand All @@ -111,9 +111,9 @@ class VulkanBackend final : public PyTorchBackendInterface {
VkTensorPtr vk_tensor = vk_value->value();

ET_CHECK_MSG(
vk_tensor->constant_buffer_idx() >= 0,
"Only constant buffers are supported when adding tensors to compute graph (indicated by constant_buffer_idx == 0), but got constant_buffer_idx of %d",
vk_tensor->constant_buffer_idx());
vk_tensor->constant_id() >= 0,
"Only constant buffers are supported when adding tensors to compute graph (indicated by constant_id >= 0), but got constant_id of %d",
vk_tensor->constant_id());

const api::ScalarType& tensor_dtype =
get_scalar_type(vk_tensor->datatype());
Expand All @@ -123,7 +123,7 @@ class VulkanBackend final : public PyTorchBackendInterface {
tensor_dims_fb->cbegin(), tensor_dims_fb->cend());

const uint8_t* tensor_data = getConstantDataPtr(
flatbuffer_graph, vk_tensor->constant_buffer_idx(), constant_data);
flatbuffer_graph, vk_tensor->constant_id(), constant_data);

const ValueRef value_ref = compute_graph->add_tensorref(
tensor_dims_vector, tensor_dtype, tensor_data);
Expand Down Expand Up @@ -211,11 +211,11 @@ class VulkanBackend final : public PyTorchBackendInterface {
VkTensorPtr input_vk_tensor = input_vk_value->value();

ET_CHECK_MSG(
input_vk_tensor->constant_buffer_idx() < 0,
input_vk_tensor->constant_id() < 0,
"Expected constant buffer index for input at index %zu with id %d to be < 0 (since it is non-constant), but got: %d",
input_index,
input_id,
input_vk_tensor->constant_buffer_idx());
input_vk_tensor->constant_id());

const api::ScalarType& input_dtype =
get_scalar_type(input_vk_tensor->datatype());
Expand Down
66 changes: 32 additions & 34 deletions backends/vulkan/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,9 @@

namespace vkgraph;

// Update after any BC breaking changes
// Update after any BC breaking changes.
file_identifier "VK00";

enum VkDatatype : short {
/// IEEE754 single-precision floating-point.
vk_datatype_fp32 = 0,
}

// Abstraction to represent a region of bytes in a raw data buffer. Useful for referencing raw data serialized outside of the flatbuffer.
table VkBytes {
offset: ulong;
length: ulong;
}

table VkTensor {
// type of the tensor elements.
datatype:VkDatatype;
// Array of shape dimensions
dims:[uint];
// Index to the program's constant buffer table, negative value indicates non constant
constant_buffer_idx:int;
// Indicates which shared memory object this tensor uses; negative value indicates the tensor does not share memory
mem_obj_id: int;
}

table VkValue {
value:VkTensor;
}

enum VkArithmeticOpType : short {
vk_arithmetic_op_type_add = 0,
vk_arithmetic_op_type_sub = 1,
Expand All @@ -53,22 +27,46 @@ table VkNode {
debug_handle:uint;
}

enum VkDataType : short {
// IEEE754 single-precision floating-point.
fp32 = 0,
}

table VkTensor {
// Type of the tensor elements.
datatype:VkDataType;
// Shape dimensions.
dims:[uint];
// Index to the program's constant data. Negative indicates tensor is non-constant.
constant_id:int;
// Index to the shared memory object. Negative indicates the tensor doesn't share memory.
mem_obj_id:int;
}

table VkValue {
value:VkTensor;
}

// Abstraction to represent a region of bytes in a raw data buffer. Useful for referencing raw data
// serialized outside of the flatbuffer.
table VkBytes {
offset:ulong;
length:ulong;
}

table VkGraph {
// Schema version.
version:string;

// Objects
chain:[VkNode];
values:[VkValue];

// Ids of external inputs
// Indices
input_ids:[uint];

// Ids of external outputs
output_ids:[uint];

// Tables of constant data, used for constant Values (e.g.
// data field of weight tensors). Each constant is assigned an index into the table
// which are each individually aligned. 0 index is reserved to be pointed to by non-constant
// Tensors
// Raw Objects (e.g. weight tensors and custom shaders)
constants:[VkBytes];
shaders:[VkBytes];
}
Expand Down
57 changes: 29 additions & 28 deletions backends/vulkan/serialization/vulkan_graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,6 @@
from typing import List


class VkDatatype(IntEnum):
vk_datatype_fp32 = 0


@dataclass
class VkBytes:
offset: int
length: int


@dataclass
class VkTensor:
datatype: VkDatatype
dims: List[int]
constant_buffer_idx: int
mem_obj_id: int


@dataclass
class VkScalar:
pass


@dataclass
class VkValue:
value: VkTensor


class VkArithmeticOpType(IntEnum):
vk_arithmetic_op_type_add = 0
vk_arithmetic_op_type_sub = 1
Expand All @@ -67,9 +39,38 @@ class VkNode:
debug_handle: int


class VkDataType(IntEnum):
fp32 = 0


@dataclass
class VkTensor:
datatype: VkDataType
dims: List[int]
constant_id: int
mem_obj_id: int


@dataclass
class VkScalar:
pass


@dataclass
class VkValue:
value: VkTensor


@dataclass
class VkBytes:
offset: int
length: int


@dataclass
class VkGraph:
version: str

chain: List[VkNode]
values: List[VkValue]

Expand Down
52 changes: 27 additions & 25 deletions backends/vulkan/vulkan_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def get_vk_op_type(
)

@staticmethod
def get_vk_datatype(torch_dtype: dtype) -> vk_graph_schema.VkDatatype:
def get_vk_datatype(torch_dtype: dtype) -> vk_graph_schema.VkDataType:
if torch_dtype == float32:
return vk_graph_schema.VkDatatype.vk_datatype_fp32
return vk_graph_schema.VkDataType.fp32
else:
raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

Expand All @@ -75,23 +75,23 @@ def preprocess( # noqa: C901
program: ExportedProgram,
module_compile_spec: List[CompileSpec],
) -> PreprocessResult:
vk_nodes = []
vk_chain = []
vk_values = []
vk_input_ids = []
vk_output_ids = []
const_tensors = []

# Mapping from node in the executorch graph to corresponding VkValue id
node_vk_value_ids = {}
# Mapping from graph Node to the id (or list of ids) of its schema VkValue(s).
node_to_value_ids = {}

def create_single_vk_value(node: Node, buffer_idx: int = -1) -> int:
def create_single_vk_value(node: Node, constant_id: int = -1) -> int:
spec = node.meta.get("spec")
assert isinstance(spec, TensorSpec)
new_id = len(vk_values)
if node not in node_vk_value_ids:
node_vk_value_ids[node] = new_id
if node not in node_to_value_ids:
node_to_value_ids[node] = new_id
else:
current_ids = node_vk_value_ids[node]
current_ids = node_to_value_ids[node]
if isinstance(current_ids, int):
current_ids = [current_ids, new_id]
else:
Expand All @@ -107,22 +107,22 @@ def create_single_vk_value(node: Node, buffer_idx: int = -1) -> int:
value=vk_graph_schema.VkTensor(
datatype=VulkanBackend.get_vk_datatype(spec.dtype),
dims=spec.shape,
constant_buffer_idx=buffer_idx,
constant_id=constant_id,
mem_obj_id=mem_obj_id,
)
)
)
return new_id

def create_vk_values_for(node: Node, buffer_idx: int = -1):
def create_vk_values_for(node: Node, constant_id: int = -1):
spec = node.meta.get("spec")

if isinstance(spec, TensorSpec):
return create_single_vk_value(node, buffer_idx)
return create_single_vk_value(node, constant_id)
else:
ids = []
for _ in spec:
ids.append(create_single_vk_value(node, buffer_idx))
ids.append(create_single_vk_value(node, constant_id))
return ids

passes = [
Expand Down Expand Up @@ -154,17 +154,17 @@ def create_vk_values_for(node: Node, buffer_idx: int = -1):
elif node.op == "call_function":
# Op
if (
node.all_input_nodes[0] not in node_vk_value_ids
or node.all_input_nodes[1] not in node_vk_value_ids
node.all_input_nodes[0] not in node_to_value_ids
or node.all_input_nodes[1] not in node_to_value_ids
):
raise AssertionError(
"Cannot find input(s) for current node in node_vk_value_ids. This means this node is being serialized before its input(s) which is not allowed."
"Cannot find input(s) for current node in node_to_value_ids. This means this node is being serialized before its input(s) which is not allowed."
)
vk_nodes.append(
vk_chain.append(
vk_graph_schema.VkNode(
node=vk_graph_schema.VkArithmeticNode(
input1_id=node_vk_value_ids[node.all_input_nodes[0]],
input2_id=node_vk_value_ids[node.all_input_nodes[1]],
input1_id=node_to_value_ids[node.all_input_nodes[0]],
input2_id=node_to_value_ids[node.all_input_nodes[1]],
output_id=create_vk_values_for(node),
op_type=VulkanBackend.get_vk_op_type(
target_name=node.target.__name__, kwargs=node.kwargs
Expand All @@ -177,24 +177,26 @@ def create_vk_values_for(node: Node, buffer_idx: int = -1):
),
)
elif node.op == "get_attr":
buffer_idx = len(const_tensors)
constant_id = len(const_tensors)
const_tensors.append(
getattr(node.graph.owning_module, node.target).contiguous()
)

create_vk_values_for(node, buffer_idx)
create_vk_values_for(node, constant_id)

elif node.op == "output":
if node.all_input_nodes[0] not in node_vk_value_ids:
if node.all_input_nodes[0] not in node_to_value_ids:
raise AssertionError(
"Cannot find input to output node in node_vk_value_ids. This means the output node is being serialized before its corresponding internal node which is not allowed."
"Cannot find input to output node in node_to_value_ids. This means the output node is being serialized before its corresponding internal node which is not allowed."
)
vk_output_ids.append(node_vk_value_ids[node.all_input_nodes[0]])
vk_output_ids.append(node_to_value_ids[node.all_input_nodes[0]])
else:
raise RuntimeError(f"Unsupported op, {node.op}, in Vulkan Preprocess")

# Raw objects (constants and shaders) are populated in the next line's method.
vk_graph = vk_graph_schema.VkGraph(
version="0",
chain=vk_nodes,
chain=vk_chain,
values=vk_values,
input_ids=vk_input_ids,
output_ids=vk_output_ids,
Expand Down

0 comments on commit 99c70f9

Please sign in to comment.