diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index ad49687369..6c056cc9d9 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -482,3 +482,9 @@ void execute_graph_and_check_output( } } } + +bool check_close(float a, float b, float atol, float rtol) { + float max = std::max(std::abs(a), std::abs(b)); + float diff = std::abs(a - b); + return diff <= (atol + rtol * max); +} diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h index f9969eddbf..bf54944617 100644 --- a/backends/vulkan/test/utils/test_utils.h +++ b/backends/vulkan/test/utils/test_utils.h @@ -242,3 +242,9 @@ void print_vector( } std::cout << std::endl; } + +// +// Misc. Utilities +// + +bool check_close(float a, float b, float atol = 1e-4, float rtol = 1e-5); diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 307593d8fd..ee2d119b6b 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -601,7 +601,7 @@ TEST_F(VulkanComputeAPITest, tensor_no_copy_transpose_test) { EXPECT_TRUE(data_out.size() == ref_out.size()); for (size_t i = 0; i < data_out.size(); ++i) { - EXPECT_TRUE(data_out[i] == ref_out[i]); + EXPECT_TRUE(check_close(data_out[i], ref_out[i])); } }