Skip to content

Commit

Permalink
Add tests for zero-dim tensors (pytorch#3644)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3644

Turns out zero dim tensors don't need anything special to be enabled. Therefore just add test cases for them.

Reviewed By: copyrightly

Differential Revision: D57463151

fbshipit-source-id: 368199e96970b3334d54af7d0a892898af0bf9f4
  • Loading branch information
SS-JIA authored and facebook-github-bot committed May 17, 2024
1 parent 6efe44c commit 4008600
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
13 changes: 13 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,19 @@ def forward(self, x, y, w):

self.lower_module_and_test_output(add_module, sample_inputs)

def test_vulkan_backend_zero_dim_tensor(self):
class ZeroDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.zero = torch.full([], 1.3, dtype=torch.float32)

def forward(self, x):
return x + self.zero

internal_data_module = ZeroDimModule()
sample_inputs = (torch.rand(size=(2, 3), dtype=torch.float32),)
self.lower_module_and_test_output(internal_data_module, sample_inputs)

def test_vulkan_backend_internal_data(self):
class InternalDataModule(torch.nn.Module):
def __init__(self):
Expand Down
45 changes: 45 additions & 0 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,51 @@ TEST(VulkanComputeGraphTest, test_values_string) {
EXPECT_TRUE(stored == "hello, world");
}

TEST(VulkanComputeGraphTest, test_zero_dim_tensor) {
GraphConfig config;
ComputeGraph graph(config);

std::vector<int64_t> size_big = {7, 3, 5};
std::vector<int64_t> size_small = {};

// Build graph

IOValueRef a = graph.add_input_tensor(size_big, api::kFloat);
IOValueRef b = graph.add_input_tensor(size_small, api::kFloat);

IOValueRef out = {};

out.value = graph.add_tensor(size_big, api::kFloat);

auto addFn = VK_GET_OP_FN("aten.add.Tensor");
addFn(graph, {a.value, b.value, kDummyValueRef, out.value});

out.staging = graph.set_output_tensor(out.value);

graph.prepare();
graph.encode_execute();

// Run graph

for (float i = 5.0f; i < 30.0f; i += 10.0f) {
float val_a = i + 2.0f;
float val_b = i + 1.5f;
float val_c = val_a + val_b;

fill_vtensor(graph, a, val_a);
fill_vtensor(graph, b, val_b);

graph.execute();

EXTRACT_TENSOR(out);

// Sanity check that the values are correct
for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) {
CHECK_VALUE(data_out, i, val_c);
}
}
}

TEST(VulkanComputeGraphTest, test_simple_graph) {
GraphConfig config;
ComputeGraph graph(config);
Expand Down

0 comments on commit 4008600

Please sign in to comment.