From 5f4a81154ce660b7822046f0a41b64e3a67014c2 Mon Sep 17 00:00:00 2001 From: Jorge Pineda <32918197+jorgep31415@users.noreply.github.com> Date: Thu, 5 Sep 2024 16:55:37 -0700 Subject: [PATCH] [ET-VK] Fix negative dim in `normalize_to_dim_index` Differential Revision: D62270925 Pull Request resolved: https://github.com/pytorch/executorch/pull/5118 --- backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h index 45dfceb3f0..4bd8e9b900 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h @@ -32,7 +32,8 @@ constexpr DimIndex kChannel4D = DimIndex::DIM_3RD_LAST; constexpr DimIndex kBatch4D = DimIndex::DIM_4TH_LAST; inline DimIndex normalize_to_dim_index(const api::vTensor& v_in, int32_t dim) { - return static_cast(dim - v_in.dim()); + return dim < 0 ? static_cast(dim) + : static_cast(dim - v_in.dim()); } /*