
Commit

Creation ops (pytorch#3877)
Summary:
Pull Request resolved: pytorch#3877

Add more creation ops (ones/ones_like/zeros/zeros_like); some are needed in the OCR full model. Reuse full's implementation and register the new ops in Full.cpp.
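For context, these ops surface when an exported model materializes constant tensors. A minimal sketch (the module, shapes, and printout are illustrative, not part of this diff) of how `torch.zeros`/`torch.ones` lower to the edge ops handled below:

```python
import torch

# Illustrative module (not from this PR): calling torch.zeros / torch.ones
# on a traced shape produces aten.zeros.default / aten.ones.default nodes,
# which are the ops this diff teaches the Vulkan backend to handle.
class CreationOps(torch.nn.Module):
    def forward(self, x):
        return torch.zeros(x.shape) + torch.ones(x.shape)

ep = torch.export.export(CreationOps(), (torch.randn(2, 3),))
print(ep.graph)  # expect aten.zeros.default and aten.ones.default calls
```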

Reviewed By: copyrightly, SS-JIA

Differential Revision: D58247380

fbshipit-source-id: a31249396850a3c8426cda74bd6b6c9595bd2484
Yujie Hui authored and facebook-github-bot committed Jun 6, 2024
1 parent 98f0d82 commit b1e5ba8
Showing 4 changed files with 89 additions and 4 deletions.
4 changes: 4 additions & 0 deletions backends/vulkan/partitioner/supported_ops.py
@@ -118,6 +118,10 @@ def __contains__(self, op):
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.full.default,
     exir_ops.edge.aten.full_like.default,
+    exir_ops.edge.aten.ones.default,
+    exir_ops.edge.aten.ones_like.default,
+    exir_ops.edge.aten.zeros.default,
+    exir_ops.edge.aten.zeros_like.default,
 ]
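For context on this file: the ops list above feeds the `__contains__` check visible in the hunk header, which the Vulkan partitioner uses to decide whether a node can be delegated. A minimal sketch of that membership test (the class below is an illustrative stand-in, not the real implementation in supported_ops.py):

```python
from executorch.exir.dialects._ops import ops as exir_ops

# Illustrative stand-in for the container in supported_ops.py; the real
# class carries more logic, but the membership test is the core idea.
class SupportedOps:
    def __init__(self, ops):
        self._ops = list(ops)

    def __contains__(self, op):
        return op in self._ops

SUPPORTED = SupportedOps(
    [
        exir_ops.edge.aten.zeros.default,
        exir_ops.edge.aten.zeros_like.default,
        exir_ops.edge.aten.ones.default,
        exir_ops.edge.aten.ones_like.default,
    ]
)

assert exir_ops.edge.aten.ones.default in SUPPORTED
```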
18 changes: 14 additions & 4 deletions backends/vulkan/runtime/graph/ops/impl/Full.cpp
@@ -64,16 +64,26 @@ void add_full_node(
 }
 
 void full(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  return add_full_node(graph, args[0], args[1], args[6]);
+  return add_full_node(graph, args[0], args[1], args[args.size() - 1]);
 }
 
-void full_like(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  return add_full_node(graph, args[0], args[1], args[7]);
+void zeros(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_full_node(
+      graph, args[0], graph.add_scalar<int64_t>(0), args[args.size() - 1]);
 }
 
+void ones(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_full_node(
+      graph, args[0], graph.add_scalar<int64_t>(1), args[args.size() - 1]);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.full.default, full);
-  VK_REGISTER_OP(aten.full_like.default, full_like);
+  VK_REGISTER_OP(aten.full_like.default, full);
+  VK_REGISTER_OP(aten.zeros.default, zeros);
+  VK_REGISTER_OP(aten.zeros_like.default, zeros);
+  VK_REGISTER_OP(aten.ones.default, ones);
+  VK_REGISTER_OP(aten.ones_like.default, ones);
 }
 
 } // namespace vkcompute
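Why `args[args.size() - 1]` replaces the hard-coded `args[6]` / `args[7]`: `aten.full` and `aten.full_like` have different arities, so the trailing `ValueRef` (the output, judging by the original indices) lands at a different position for each op; taking the last element lets one `add_full_node` path serve all six registrations. The arity difference is visible with standard PyTorch schema introspection (this snippet is illustrative, not part of the diff):

```python
import torch

# aten.full: (size, fill_value, *, dtype, layout, device, pin_memory)
# -> 6 schema arguments, so the output ref was previously args[6].
print(torch.ops.aten.full.default._schema)

# aten.full_like: (self, fill_value, *, dtype, layout, device, pin_memory,
# memory_format) -> 7 schema arguments, hence the old args[7].
print(torch.ops.aten.full_like.default._schema)
```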
19 changes: 19 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
@@ -331,6 +331,25 @@ def get_full_inputs():
     return test_suite
 
 
+@register_test_suite(
+    [
+        "aten.zeros.default",
+        "aten.zeros_like.default",
+        "aten.ones.default",
+        "aten.ones_like.default",
+    ]
+)
+def get_ones_inputs():
+    test_suite = VkTestSuite(
+        [
+            ([S1, S2]),
+            ([M, M1, M2]),
+            ([L, M, M1, M2]),
+        ]
+    )
+    return test_suite
+
+
 @register_test_suite(["aten.select.int", "aten.select_copy.int"])
 def get_select_int_inputs():
     test_suite = VkTestSuite(
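Since zeros/ones and their _like variants take no fill value, one shape-only suite covers all four ops; each entry is just the `size` argument. Roughly, the cases correspond to calls like these (the concrete numbers behind `S1`, `M`, etc. live elsewhere in cases.py; the values below are stand-ins):

```python
import torch

# Stand-in values; the real S1, S2, M, M1, M2, L are defined in cases.py.
S1, S2, M, M1, M2, L = 7, 11, 64, 32, 16, 128

torch.zeros([S1, S2])        # 2-D case
torch.ones([M, M1, M2])      # 3-D case
torch.zeros([L, M, M1, M2])  # 4-D case
```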
52 changes: 52 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -963,6 +963,20 @@ def __init__(self):
             def forward(self, x):
                 return torch.full(x.shape, 42.0)
 
+        class ZerosModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.zeros(x.shape)
+
+        class OnesModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.ones(x.shape)
+
         sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)
 
         self.lower_module_and_test_output(
@@ -971,6 +985,18 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
 
+        self.lower_module_and_test_output(
+            ZerosModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+        self.lower_module_and_test_output(
+            OnesModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_full_like(self):
         class FullLikeModule(torch.nn.Module):
             def __init__(self):
@@ -979,6 +1005,20 @@ def __init__(self):
             def forward(self, x):
                 return torch.full_like(x, 42.0)
 
+        class ZerosLikeModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.zeros_like(x)
+
+        class OnesLikeModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.ones_like(x)
+
         sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)
 
         self.lower_module_and_test_output(
@@ -987,6 +1027,18 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
 
+        self.lower_module_and_test_output(
+            ZerosLikeModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+        self.lower_module_and_test_output(
+            OnesLikeModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_reshape(self):
         class ReshapeModule(torch.nn.Module):
             def __init__(self):
