Skip to content

Commit

Permalink
Serialize list types from function args (pytorch#2404)
Browse files Browse the repository at this point in the history
Summary:
bypass-github-export-checks

Pull Request resolved: pytorch#2404

In pytorch#2271, we already added
- IntList
- DoubleList
- BoolList
- ValueList

to the schema and the runtime's Value class. Their serialization was incomplete missing two components:
1. Receiving a list in `torch.fx.Node.args`.
2. Receiving a non-tensor in `torch.fx.Node`.

This change completes #1.

Also, this change fixes a bug where values type `bool` matches both types `bool` and `int` and hence were being added twice.

If our type support grows more complex, we can consider using our own types similar to the core Executorch runtime: https://github.com/pytorch/executorch/blob/689796499024fc4a133318d707f4c10db73da967/exir/emit/_emitter.py#L158-L166
ghstack-source-id: 218539049
exported-using-ghexport

Reviewed By: SS-JIA

Differential Revision: D54708353

fbshipit-source-id: 8641647b515e201ea63db67115c01c1532ad6566
  • Loading branch information
jorgep31415 authored and facebook-github-bot committed Mar 13, 2024
1 parent 16a3156 commit 21cbfd6
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 14 deletions.
42 changes: 42 additions & 0 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ class GraphBuilder {
ref_mapping_[fb_id] = ref;
}

template <typename T>
typename std::enable_if<is_valid_scalar_type<T>::value, void>::type
add_scalar_list_to_graph(const uint32_t fb_id, std::vector<T>&& value) {
ValueRef ref = compute_graph_->add_scalar_list(std::move(value));
ref_mapping_[fb_id] = ref;
}

void add_value_list_to_graph(
const uint32_t fb_id,
std::vector<ValueRef>&& value) {
ValueRef ref = compute_graph_->add_value_list(std::move(value));
ref_mapping_[fb_id] = ref;
}

void add_string_to_graph(const uint32_t fb_id, VkValuePtr value) {
const auto fb_str = value->value_as_String()->string_val();
std::string string(fb_str->cbegin(), fb_str->cend());
Expand All @@ -150,6 +164,34 @@ class GraphBuilder {
case vkgraph::GraphTypes::VkTensor:
add_tensor_to_graph(fb_id, value->value_as_VkTensor());
break;
case vkgraph::GraphTypes::IntList:
add_scalar_list_to_graph(
fb_id,
std::vector<int64_t>(
value->value_as_IntList()->items()->cbegin(),
value->value_as_IntList()->items()->cend()));
break;
case vkgraph::GraphTypes::DoubleList:
add_scalar_list_to_graph(
fb_id,
std::vector<double>(
value->value_as_DoubleList()->items()->cbegin(),
value->value_as_DoubleList()->items()->cend()));
break;
case vkgraph::GraphTypes::BoolList:
add_scalar_list_to_graph(
fb_id,
std::vector<bool>(
value->value_as_BoolList()->items()->cbegin(),
value->value_as_BoolList()->items()->cend()));
break;
case vkgraph::GraphTypes::ValueList:
add_value_list_to_graph(
fb_id,
std::vector<ValueRef>(
value->value_as_ValueList()->items()->cbegin(),
value->value_as_ValueList()->items()->cend()));
break;
case vkgraph::GraphTypes::String:
add_string_to_graph(fb_id, value);
break;
Expand Down
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ ValueRef ComputeGraph::add_staging(
return idx;
}

ValueRef ComputeGraph::add_value_list(std::vector<ValueRef>&& value) {
ValueRef idx(static_cast<int>(values_.size()));
values_.emplace_back(std::move(value));
return idx;
}

ValueRef ComputeGraph::add_string(std::string&& str) {
ValueRef idx(static_cast<int>(values_.size()));
values_.emplace_back(std::move(str));
Expand Down
14 changes: 8 additions & 6 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,13 @@ class ComputeGraph final {

template <typename T>
typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
add_scalar_list(std::vector<T>&& values);
add_scalar(T value);

template <typename T>
typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
add_scalar(T value);
add_scalar_list(std::vector<T>&& value);

ValueRef add_value_list(std::vector<ValueRef>&& value);

ValueRef add_string(std::string&& str);

Expand Down Expand Up @@ -212,17 +214,17 @@ class ComputeGraph final {

template <typename T>
inline typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
ComputeGraph::add_scalar_list(std::vector<T>&& values) {
ComputeGraph::add_scalar(T value) {
ValueRef idx(static_cast<int>(values_.size()));
values_.emplace_back(std::move(values));
values_.emplace_back(value);
return idx;
}

template <typename T>
inline typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
ComputeGraph::add_scalar(T value) {
ComputeGraph::add_scalar_list(std::vector<T>&& value) {
ValueRef idx(static_cast<int>(values_.size()));
values_.emplace_back(value);
values_.emplace_back(std::move(value));
return idx;
}

Expand Down
54 changes: 46 additions & 8 deletions backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Union
from typing import cast, List, Optional, Union

import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

Expand All @@ -15,8 +15,8 @@
from torch.export import ExportedProgram
from torch.fx import Node

_ScalarType = Union[int, bool, float]
_Argument = Union[Node, int, bool, float, str]
_ScalarType = Union[bool, int, float]
_Argument = Union[Node, List[Node], _ScalarType, List[_ScalarType], str]


class VkGraphBuilder:
Expand Down Expand Up @@ -150,14 +150,46 @@ def create_tensor_values(self, node: Node) -> int:
"Creating values for nodes with collection types is not supported yet."
)

def create_value_list_value(self, arg: List[Node]) -> int:
self.values.append(
vk_graph_schema.VkValue(
vk_graph_schema.ValueList(
items=[self.get_or_create_value_for(e) for e in arg]
)
)
)
return len(self.values) - 1

def create_scalar_value(self, scalar: _ScalarType) -> int:
new_id = len(self.values)
if isinstance(scalar, int):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar)))
if isinstance(scalar, float):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
if isinstance(scalar, bool):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Bool(scalar)))
elif isinstance(scalar, int):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar)))
elif isinstance(scalar, float):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
return new_id

def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
new_id = len(self.values)
if isinstance(arg[0], bool):
self.values.append(
vk_graph_schema.VkValue(
vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg])
)
)
elif isinstance(arg[0], int):
self.values.append(
vk_graph_schema.VkValue(
vk_graph_schema.IntList(items=[cast(int, e) for e in arg])
)
)
elif isinstance(arg[0], float):
self.values.append(
vk_graph_schema.VkValue(
vk_graph_schema.DoubleList(items=[cast(float, e) for e in arg])
)
)
return new_id

def create_string_value(self, string: str) -> int:
Expand All @@ -174,8 +206,14 @@ def get_or_create_value_for(self, arg: _Argument):
return self.node_to_value_ids[arg]
# Return id for a newly created value
return self.create_tensor_values(arg)
elif isinstance(arg, (int, float, bool)):
elif isinstance(arg, list) and isinstance(arg[0], Node):
# pyre-ignore[6]
return self.create_value_list_value(arg)
elif isinstance(arg, _ScalarType):
return self.create_scalar_value(arg)
elif isinstance(arg, list) and isinstance(arg[0], _ScalarType):
# pyre-ignore[6]
return self.create_scalar_list_value(arg)
elif isinstance(arg, str):
return self.create_string_value(arg)
else:
Expand Down

0 comments on commit 21cbfd6

Please sign in to comment.