diff --git a/extension/tensor/targets.bzl b/extension/tensor/targets.bzl
index 2a8f919357..97654094af 100644
--- a/extension/tensor/targets.bzl
+++ b/extension/tensor/targets.bzl
@@ -18,6 +18,7 @@ def define_common_targets():
         ],
         exported_headers = [
             "tensor.h",
+            "tensor_accessor.h",
             "tensor_ptr.h",
             "tensor_ptr_maker.h",
         ],
diff --git a/extension/tensor/tensor_accessor.h b/extension/tensor/tensor_accessor.h
new file mode 100644
index 0000000000..362bdb8d72
--- /dev/null
+++ b/extension/tensor/tensor_accessor.h
@@ -0,0 +1,215 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/result.h>
+
+namespace executorch {
+namespace extension {
+namespace internal {
+
+/**
+ * Base class template storing the underlying data with size and stride
+ * helpers. Inherited by TensorAccessor<> which requires specialization on
+ * rank.
+ */
+template <typename T, ssize_t N>
+class TensorAccessorBase {
+ public:
+  /// Returns the size of the underlying tensor at the given dimension.
+  executorch::aten::SizesType size(ssize_t i) const {
+    ET_CHECK_MSG(
+        i < dim_ && i >= 0,
+        "Dimension outside of [0, %zd], got %zd",
+        dim_ - 1,
+        i);
+    return sizes_[i];
+  }
+
+  /// Returns the stride of the underlying tensor at the given dimension.
+  executorch::aten::StridesType stride(ssize_t i) const {
+    ET_CHECK_MSG(
+        i < dim_ && i >= 0,
+        "Dimension outside of [0, %zd], got %zd",
+        dim_ - 1,
+        i);
+    return strides_[i];
+  }
+
+ protected:
+  TensorAccessorBase(
+      T* data,
+      const executorch::aten::SizesType* sizes,
+      const executorch::aten::StridesType* strides,
+      ssize_t dim)
+      : data_(data), sizes_(sizes), strides_(strides), dim_(dim) {}
+
+  T* data_;
+  const executorch::aten::SizesType* sizes_;
+  const executorch::aten::StridesType* strides_;
+  ssize_t dim_;
+};
+
+} // namespace internal
+
+/**
+ * TensorAccessor template with data type and rank as template parameters. It
+ * has no public constructors and can only be created using
+ * make_tensor_accessor from a given executorch::aten::Tensor. Use operator[]
+ * to index and obtain a lower-rank accessor or the underlying scalar value.
+ */
+template <typename T, ssize_t N>
+class TensorAccessor : public internal::TensorAccessorBase<T, N> {
+ public:
+  /**
+   * Index into the outermost dimension.
+   *
+   * @param i Index.
+   * @return If N > 1, a TensorAccessor with N-1 dimensions. If N == 1, a
+   * reference to the underlying scalar. Refer to the TensorAccessor
+   * specialization.
+   */
+  TensorAccessor<T, N - 1> operator[](ssize_t i) {
+    return TensorAccessor<T, N - 1>(
+        this->data_ + this->strides_[0] * i,
+        this->sizes_ + 1,
+        this->strides_ + 1,
+        N - 1);
+  }
+
+  /**
+   * Index into the outermost dimension.
+   *
+   * @param i Index.
+   * @return If N > 1, a constant TensorAccessor with N-1 dimensions. If
+   * N == 1, a constant reference to the underlying scalar. Refer to the
+   * TensorAccessor specialization.
+   */
+  const TensorAccessor<T, N - 1> operator[](ssize_t i) const {
+    return TensorAccessor<T, N - 1>(
+        this->data_ + this->strides_[0] * i,
+        this->sizes_ + 1,
+        this->strides_ + 1,
+        N - 1);
+  }
+
+ private:
+  TensorAccessor(
+      T* data,
+      const executorch::aten::SizesType* sizes,
+      const executorch::aten::StridesType* strides,
+      ssize_t dim)
+      : internal::TensorAccessorBase<T, N>(data, sizes, strides, dim) {}
+
+  template <typename T2, ssize_t N2>
+  friend class TensorAccessor;
+
+  template <typename T2, ssize_t N2>
+  friend executorch::runtime::Result<TensorAccessor<T2, N2>>
+  make_tensor_accessor(const executorch::aten::Tensor& t);
+};
+
+/**
+ * TensorAccessor specialization for N == 1, where operator[] returns a
+ * reference to the underlying scalar.
+ */
+template <typename T>
+class TensorAccessor<T, 1> : public internal::TensorAccessorBase<T, 1> {
+ public:
+  /**
+   * Index into the outermost dimension.
+   *
+   * @param i Index.
+   * @return Reference to the underlying scalar.
+   */
+  T& operator[](ssize_t i) {
+    return this->data_[this->strides_[0] * i];
+  }
+
+  /**
+   * Index into the outermost dimension.
+   *
+   * @param i Index.
+   * @return Constant reference to the underlying scalar.
+   */
+  const T& operator[](ssize_t i) const {
+    return this->data_[this->strides_[0] * i];
+  }
+
+ private:
+  TensorAccessor(
+      T* data,
+      const executorch::aten::SizesType* sizes,
+      const executorch::aten::StridesType* strides,
+      ssize_t dim)
+      : internal::TensorAccessorBase<T, 1>(data, sizes, strides, dim) {}
+
+  template <typename T2, ssize_t N2>
+  friend class TensorAccessor;
+
+  template <typename T2, ssize_t N2>
+  friend executorch::runtime::Result<TensorAccessor<T2, N2>>
+  make_tensor_accessor(const executorch::aten::Tensor& t);
+};
+
+/**
+ * Creates a TensorAccessor from the given tensor. The number of dimensions N
+ * and the size of the data type T must match those of the input tensor. For
+ * ExecuTorch tensors, non-trivial dimension order is not supported.
+ *
+ * @param tensor Origin tensor. The TensorImpl inside must outlive the returned
+ * TensorAccessor.
+ * @return TensorAccessor of the input tensor.
+ * @retval Error::InvalidArgument Mismatch on data type or number of dimensions.
+ * @retval Error::NotSupported Input tensor has a non-trivial dimension order.
+ */
+template <typename T, ssize_t N>
+executorch::runtime::Result<TensorAccessor<T, N>> make_tensor_accessor(
+    const executorch::aten::Tensor& tensor) {
+  static_assert(
+      N > 0,
+      "TensorAccessor is used for indexing tensors, for scalar use *_data_ptr()");
+
+  if (N != tensor.dim()) {
+    ET_LOG(
+        Error, "Expecting %zd dimensions but tensor has %zd.", N, tensor.dim());
+    return executorch::runtime::Error::InvalidArgument;
+  }
+
+  if (sizeof(T) != tensor.element_size()) {
+    ET_LOG(
+        Error,
+        "Size of data type template argument (%zd) not equal to tensor element size (%zd)",
+        sizeof(T),
+        tensor.element_size());
+    return executorch::runtime::Error::InvalidArgument;
+  }
+
+#ifndef USE_ATEN_LIB
+  auto dim_order = tensor.dim_order();
+  for (ssize_t i = 0; i < dim_order.size(); i++) {
+    if (dim_order[i] != i) {
+      ET_LOG(Error, "Non-trivial dim_order not supported.");
+      return executorch::runtime::Error::NotSupported;
+    }
+  }
+#endif
+
+  T* ptr = nullptr;
+  if constexpr (std::is_const_v<T>) {
+    ptr = tensor.const_data_ptr<std::remove_const_t<T>>();
+  } else {
+    ptr = tensor.mutable_data_ptr<T>();
+  }
+  return TensorAccessor<T, N>(
+      ptr, tensor.sizes().data(), tensor.strides().data(), N);
+}
+
+} // namespace extension
+} // namespace executorch
diff --git a/extension/tensor/test/targets.bzl b/extension/tensor/test/targets.bzl
index 3c81ac8def..29c8bff84b 100644
--- a/extension/tensor/test/targets.bzl
+++ b/extension/tensor/test/targets.bzl
@@ -13,6 +13,7 @@ def define_common_targets():
         runtime.cxx_test(
             name = "test" + aten_suffix,
             srcs = [
+                "tensor_accessor_test.cpp",
                 "tensor_ptr_maker_test.cpp",
                 "tensor_ptr_test.cpp",
             ],
diff --git a/extension/tensor/test/tensor_accessor_test.cpp b/extension/tensor/test/tensor_accessor_test.cpp
new file mode 100644
index 0000000000..67c4d2df51
--- /dev/null
+++ b/extension/tensor/test/tensor_accessor_test.cpp
@@ -0,0 +1,155 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/tensor/tensor_accessor.h>
+
+#include <gtest/gtest.h>
+#include <vector>
+
+#include <executorch/extension/tensor/tensor_ptr.h>
+#include <executorch/runtime/platform/runtime.h>
+
+using executorch::extension::make_tensor_accessor;
+using executorch::extension::make_tensor_ptr;
+using executorch::extension::TensorAccessor;
+
+class TensorAccessorTest : public ::testing::Test {
+ protected:
+  static void SetUpTestSuite() {
+    executorch::runtime::runtime_init();
+  }
+};
+
+TEST_F(TensorAccessorTest, From1DTensor) {
+  constexpr int32_t kN = 16;
+  std::vector<uint8_t> data(kN, 0);
+  for (int32_t i = 0; i < kN; i++) {
+    data[i] = i;
+  }
+
+  auto tensor =
+      make_tensor_ptr({kN}, data.data(), executorch::aten::ScalarType::Byte);
+  auto tensor_accessor = make_tensor_accessor<uint8_t, 1>(*tensor.get());
+  EXPECT_TRUE(tensor_accessor.ok());
+  for (int32_t i = 0; i < kN; i++) {
+    EXPECT_EQ(tensor_accessor.get()[i], i);
+  }
+}
+
+int32_t
+value_at_pos_in_4d_int_tensor(int32_t n, int32_t c, int32_t h, int32_t w) {
+  // Just encode the position into the value, assuming dimensions fit in 8 bits.
+  return (n << 24) | (c << 16) | (h << 8) | w;
+}
+
+void check_4d_int_tensor_accessor(
+    TensorAccessor<int32_t, 4> accessor,
+    int32_t N,
+    int32_t C,
+    int32_t H,
+    int32_t W) {
+  for (int32_t n = 0; n < N; n++) {
+    for (int32_t c = 0; c < C; c++) {
+      for (int32_t h = 0; h < H; h++) {
+        for (int32_t w = 0; w < W; w++) {
+          EXPECT_EQ(
+              accessor[n][c][h][w], value_at_pos_in_4d_int_tensor(n, c, h, w));
+        }
+      }
+    }
+  }
+}
+
+TEST_F(TensorAccessorTest, From4DTensor) {
+  constexpr int32_t kN = 2;
+  constexpr int32_t kC = 8;
+  constexpr int32_t kH = 4;
+  constexpr int32_t kW = 6;
+  std::vector<int32_t> data(kN * kC * kH * kW, 0);
+  size_t idx = 0;
+  for (int32_t n = 0; n < kN; n++) {
+    for (int32_t c = 0; c < kC; c++) {
+      for (int32_t h = 0; h < kH; h++) {
+        for (int32_t w = 0; w < kW; w++) {
+          data[idx++] = value_at_pos_in_4d_int_tensor(n, c, h, w);
+        }
+      }
+    }
+  }
+
+  auto tensor = make_tensor_ptr(
+      {kN, kC, kH, kW}, data.data(), executorch::aten::ScalarType::Int);
+  auto accessor = make_tensor_accessor<int32_t, 4>(*tensor.get());
+  EXPECT_TRUE(accessor.ok());
+  check_4d_int_tensor_accessor(accessor.get(), kN, kC, kH, kW);
+}
+
+#ifdef USE_ATEN_LIB // Non-contiguous tensor is only allowed in ATen mode.
+TEST_F(TensorAccessorTest, FromNonContiguousTensor) {
+  constexpr int32_t kN = 2;
+  constexpr int32_t kC = 8;
+  constexpr int32_t kH = 4;
+  constexpr int32_t kW = 6;
+  constexpr int32_t kW_padded = 8;
+  std::vector<int32_t> data(kN * kC * kH * kW_padded, 0);
+  std::array<int64_t, 4> sizes = {kN, kC, kH, kW};
+  std::array<int64_t, 4> strides = {
+      kC * kH * kW_padded,
+      1, // channel last
+      kC * kW_padded, // width is padded
+      kC};
+
+  size_t idx = 0;
+  for (int32_t n = 0; n < kN; n++) {
+    for (int32_t h = 0; h < kH; h++) {
+      for (int32_t w = 0; w < kW_padded; w++) {
+        for (int32_t c = 0; c < kC; c++) {
+          data[idx++] = value_at_pos_in_4d_int_tensor(n, c, h, w);
+        }
+      }
+    }
+  }
+
+  auto tensor = at::from_blob(
+      data.data(), sizes, strides, at::TensorOptions().dtype(at::kInt));
+  auto accessor = make_tensor_accessor<int32_t, 4>(tensor);
+  EXPECT_TRUE(accessor.ok());
+  check_4d_int_tensor_accessor(accessor.get(), kN, kC, kH, kW);
+}
+#endif // ifdef USE_ATEN_LIB
+
+TEST_F(TensorAccessorTest, FailOnIncorrectDtypeOrRank) {
+  constexpr int32_t kN = 16;
+  std::vector<float> data(kN, 0);
+  auto tensor = make_tensor_ptr({kN}, data.data());
+
+  // Tensor has rank 1 but creating accessor with rank 2.
+  auto fail1 = make_tensor_accessor<float, 2>(*tensor.get());
+  EXPECT_FALSE(fail1.ok());
+
+  // Tensor has dtype float but creating accessor with dtype uint8_t.
+  auto fail2 = make_tensor_accessor<uint8_t, 1>(*tensor.get());
+  EXPECT_FALSE(fail2.ok());
+}
+
+#ifndef USE_ATEN_LIB // Dim order is only defined for portable Tensor.
+TEST_F(TensorAccessorTest, FailOnNonTrivialDimOrder) {
+  constexpr int32_t kN = 8;
+  constexpr int32_t kM = 16;
+  std::vector<float> data(kN * kM, 0);
+  auto tensor = make_tensor_ptr(
+      {kN, kM},
+      data.data(),
+      /*dim_order=*/{1, 0},
+      /*strides=*/{1, kN});
+
+  // Non-trivial dim order is not supported.
+  auto fail = make_tensor_accessor<float, 2>(*tensor.get());
+  EXPECT_FALSE(fail.ok());
+}
+#endif // ifndef USE_ATEN_LIB
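
For reference, a minimal usage sketch of the accessor added above, outside the diff itself. The helper name sum_2d and the calling scenario are hypothetical; it assumes the caller already has a 2-D float tensor and shows the intended pattern: request a typed, ranked accessor via make_tensor_accessor, check the Result, then index with operator[] instead of doing manual stride arithmetic.

#include <executorch/extension/tensor/tensor_accessor.h>

// Hypothetical example, not part of the change: sums a 2-D float tensor by
// walking a rank-2 accessor. Returns 0 if the dtype/rank does not match or
// the dim order is unsupported.
float sum_2d(const executorch::aten::Tensor& tensor) {
  auto accessor =
      executorch::extension::make_tensor_accessor<float, 2>(tensor);
  if (!accessor.ok()) {
    return 0.0f;
  }
  float sum = 0.0f;
  for (ssize_t i = 0; i < accessor.get().size(0); i++) {
    for (ssize_t j = 0; j < accessor.get().size(1); j++) {
      sum += accessor.get()[i][j]; // strided indexing handled by the accessor
    }
  }
  return sum;
}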