Skip to content

Commit

Permalink
Respect ReadOnly Property in Nanobind Adapter (#1809)
Browse files Browse the repository at this point in the history
Create a SID with const element type if a Nanobind array is defined
read-only. Requires Nanobind 2.1.0 (with merged
wjakob/nanobind#491).

---------

Co-authored-by: Ioannis Magkanaris <[email protected]>
  • Loading branch information
2 people authored and havogt committed Oct 30, 2024
1 parent 88af55b commit 6c600e4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
12 changes: 6 additions & 6 deletions include/gridtools/storage/adapter/nanobind_adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,21 @@ namespace gridtools {
class Strides = fully_dynamic_strides<sizeof...(Sizes)>,
class StridesKind = sid::unknown_kind>
auto as_sid(nanobind::ndarray<T, nanobind::shape<Sizes...>, Args...> ndarray,
Strides stride_spec_ = {},
Strides stride_spec = {},
StridesKind = {}) {
using sid::property;
const auto ptr = ndarray.data();
constexpr auto ndim = sizeof...(Sizes);
assert(ndim == ndarray.ndim());
gridtools::array<std::size_t, ndim> shape;
std::copy_n(ndarray.shape_ptr(), ndim, shape.begin());
gridtools::array<std::size_t, ndim> strides_;
std::copy_n(ndarray.stride_ptr(), ndim, strides_.begin());
const auto strides = select_static_strides(stride_spec_, strides_.data());
gridtools::array<std::size_t, ndim> strides;
std::copy_n(ndarray.stride_ptr(), ndim, strides.begin());
const auto static_strides = select_static_strides(stride_spec, strides.data());

return sid::synthetic()
.template set<property::origin>(sid::host_device::simple_ptr_holder<T *>{ptr})
.template set<property::strides>(strides)
.template set<property::origin>(sid::host_device::simple_ptr_holder{ptr})
.template set<property::strides>(static_strides)
.template set<property::strides_kind, StridesKind>()
.template set<property::lower_bounds>(gridtools::array<integral_constant<std::size_t, 0>, ndim>())
.template set<property::upper_bounds>(shape);
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/storage/adapter/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ if (${GT_TESTS_ENABLE_PYTHON_TESTS})
FetchContent_Declare(
nanobind
GIT_REPOSITORY https://github.com/wjakob/nanobind.git
GIT_TAG v2.0.0
GIT_TAG v2.1.0
)
FetchContent_MakeAvailable(nanobind)
nanobind_build_library(nanobind-static)
Expand Down
20 changes: 20 additions & 0 deletions tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,26 @@ TEST_F(python_init_fixture, NanobindAdapterDataDynStrides) {
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST_F(python_init_fixture, NanobindAdapterReadOnly) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<-1, -1>, nb::ro> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray);
using element_t = gridtools::sid::element_type<decltype(sid)>;
static_assert(std::is_same_v<element_t, int const>);

const auto s_origin = sid_get_origin(sid);
const auto s_strides = sid_get_strides(sid);
const auto s_ptr = s_origin();

EXPECT_EQ(s_ptr, data);
EXPECT_EQ(strides[0], gridtools::get<0>(s_strides));
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST_F(python_init_fixture, NanobindAdapterStaticStridesMatch) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
Expand Down

0 comments on commit 6c600e4

Please sign in to comment.