Read-only array exchange via the buffer protocol
This PR adapts nanobind so that it can both receive and return arrays
representing read-only memory. This is communicated via the standard
``const`` type modifier, e.g. by replacing ``nb::ndarray<float, ...>``
with ``nb::ndarray<const float, ...>``.

The PR also adapts accessors (``data()``, ``operator()``) so that they
only return constant references/pointers in that case.

The change is for now specific to the buffer protocol used to exchange
data with NumPy. The DLPack interface (with PyTorch, Tensorflow, etc.)
ignores the read-only annotation. This may change at some point in the
future when an upcoming DLPack version with a read-only bit is more
widely deployed.
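
The user-visible part of this change is the ``const`` scalar annotation. As a standalone sketch (not nanobind's actual class), the following shows how a const scalar template parameter can gate mutable access at compile time, which is the mechanism the accessor changes in this PR rely on; ``array_view`` and its members are illustrative names, not nanobind API.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Standalone sketch (NOT the real nanobind class): a const Scalar template
// parameter disables mutable access at compile time, mirroring what this PR
// does for nb::ndarray<const T, ...>.
template <typename Scalar> class array_view {
    // const Scalar => store (and accept) a const pointer
    using pointer = std::conditional_t<std::is_const_v<Scalar>, const void *, void *>;

public:
    array_view(pointer data, std::size_t size) : m_data(data), m_size(size) { }

    // Read access is always available
    const Scalar *data() const { return (const Scalar *) m_data; }

    // Write access only exists when Scalar is not const (the same SFINAE
    // technique as the PR's data() accessor)
    template <typename T = Scalar, std::enable_if_t<!std::is_const_v<T>, int> = 1>
    Scalar *data() { return (Scalar *) m_data; }

    std::size_t size() const { return m_size; }

    // Analogous to the 'ro' bit this PR threads through ndarray_create()
    static constexpr bool ro = std::is_const_v<Scalar>;

private:
    pointer m_data;
    std::size_t m_size;
};
```

With ``array_view<const float>``, the mutable ``data()`` overload simply does not exist, so writes are rejected at compile time rather than at runtime.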
wjakob committed May 25, 2023
1 parent 5efcdfa commit 5e6b210
Showing 8 changed files with 148 additions and 48 deletions.
29 changes: 24 additions & 5 deletions docs/api_extra.rst
@@ -496,16 +496,16 @@ in a :ref:`separate section <ndarrays>`.

Return the stride of dimension `i`.

-.. cpp:function:: int64_t* shape_ptr() const
+.. cpp:function:: const int64_t* shape_ptr() const

   Return a pointer to the shape array. Note that the return type is
-  ``int64_t*``, which may be unexpected as the scalar version
+  ``const int64_t*``, which may be unexpected as the scalar version
   :cpp:func:`shape()` casts its result to a ``size_t``.

This is a consequence of the DLPack tensor representation that uses
signed 64-bit integers for all of these fields.

-.. cpp:function:: int64_t* stride_ptr() const
+.. cpp:function:: const int64_t* stride_ptr() const

   Return a pointer to the stride array.

@@ -526,11 +526,12 @@ in a :ref:`separate section <ndarrays>`.

.. cpp:function:: const Scalar * data() const

-   Return a mutable pointer to the array data.
+   Return a const pointer to the array data.

.. cpp:function:: Scalar * data()

-   Return a const pointer to the array data.
+   Return a mutable pointer to the array data. Only enabled when `Scalar` is
+   not itself ``const``.

.. cpp:function:: template <typename... Ts> auto& operator()(Ts... indices)

@@ -604,6 +605,24 @@ The :cpp:class:`ndarray\<..\> <ndarray>` class admits optional template
parameters. They constrain the type of array arguments that may be passed to a
function.

The following are supported:

Data type
+++++++++

The data type of the underlying scalar element:

- ``[u]int8_t`` up to ``[u]int64_t`` and other variations (``unsigned long long``, etc.)
- ``float``, ``double``
- ``bool``

Annotate the data type with ``const`` to indicate a read-only array. Note that
only the buffer protocol/NumPy interface considers ``const``-ness at the
moment; data exchange with other array libraries will ignore this annotation.

nanobind does not support non-standard types as documented in the section on
:ref:`dtype limitations <dtype_restrictions>`.

Shape
+++++

3 changes: 3 additions & 0 deletions docs/changelog.rst
@@ -24,6 +24,9 @@ Version 1.3.0 (TBD)
* Refined compiler and linker flags across platforms to ensure compact binaries
especially in ``NB_STATIC`` builds. (commit `5ead9f
<https://github.com/wjakob/nanobind/commit/5ead9ff348a2ef0df8231e6480607a5b0623a16b>`__)
* The :cpp:class:`nb::ndarray\<..\> <ndarray>` class can now use the buffer
  protocol to receive and return arrays representing read-only memory. (PR
  `#217 <https://github.com/wjakob/nanobind/pull/217>`__)
* Reduced the number of exception-related exports to further crunch
``libnanobind``. (commit `763962
<https://github.com/wjakob/nanobind/commit/763962b8ce76414148089ef6a68cff97d7cc66ce>`__).
34 changes: 23 additions & 11 deletions docs/ndarray.rst
@@ -115,6 +115,10 @@ The following constraints are available
- A scalar type (``float``, ``uint8_t``, etc.) constrains the representation
of the ndarray.

- This scalar type can be further annotated with ``const``, which is necessary
if you plan to call nanobind functions with arrays that do not permit write
access.

- The :cpp:class:`nb::shape <shape>` annotation simultaneously constrains the
number of array dimensions and the size per dimension. A :cpp:var:`nb::any
<any>` entry leaves the corresponding dimension unconstrained.
@@ -176,8 +180,9 @@ conversion. This, e.g., makes it possible to call a function expecting a
``float32`` array with ``float64`` data. Implicit conversions create
temporary ndarrays containing a copy of the data, which can be
undesirable. To suppress them, add a
-:cpp:func:`nb::arg("ndarray").noconvert() <arg::noconvert>` or
-:cpp:func:`"ndarray"_a.noconvert() <arg::noconvert>`
+:cpp:func:`nb::arg("my_array_arg").noconvert() <arg::noconvert>` or
+:cpp:func:`"my_array_arg"_a.noconvert() <arg::noconvert>`
function binding annotation.

Binding functions that return arrays
@@ -188,23 +193,26 @@ to CPU/GPU memory, and what framework (NumPy/..) should be used to encapsulate
the data.

The following simple binding declaration shows how to return a ``2x4``
-NumPy floating point matrix.
+NumPy floating point matrix that does not permit write access.

.. code-block:: cpp

-   float data[] = { 1, 2, 3, 4, 5, 6, 7, 8 };
+   // at top level
+   const float data[] = { 1, 2, 3, 4, 5, 6, 7, 8 };

-   m.def("ret_numpy", []() {
-       size_t shape[2] = { 2, 4 };
-       return nb::ndarray<nb::numpy, float, nb::shape<2, nb::any>>(
-           data, /* ndim = */ 2, shape);
-   });
+   NB_MODULE(my_ext, m) {
+       m.def("ret_numpy", []() {
+           size_t shape[2] = { 2, 4 };
+           return nb::ndarray<nb::numpy, const float, nb::shape<2, nb::any>>(
+               data, /* ndim = */ 2, shape);
+       });
+   }
The auto-generated docstring of this function is:

.. code-block:: python

-   ret_numpy() -> np.ndarray[float32, shape=(2, *)]
+   ret_numpy() -> np.ndarray[float32, writable=False, shape=(2, *)]
Calling it in Python yields

@@ -230,7 +238,7 @@ values:

Note that shape and order annotations like :cpp:class:`nb::shape <shape>` and
:cpp:class:`nb::c_contig <c_contig>` enter into the docstring, but nanobind
-wont spend time on additional checks. It trusts that your method returns what
+won't spend time on additional checks. It trusts that your method returns what
it declares. Furthermore, non-CPU ndarrays must explicitly indicate the
device type and device ID using special parameters of the :cpp:func:`ndarray()
<ndarray::ndarray()>` constructor shown below. Device types indicated via
@@ -324,3 +332,7 @@ nanobind's :cpp:class:`nb::ndarray\<...\> <ndarray>` is based on the `DLPack
<https://github.com/dmlc/dlpack>`__ array exchange protocol, which causes it to
be more restrictive. Presently supported dtypes include signed/unsigned
integers, floating point values, and boolean values.

Nanobind can receive and return read-only arrays via the buffer protocol used
to exchange data with NumPy. The DLPack interface currently ignores this
annotation.
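
To illustrate why the annotation matters on the receiving side: a binding declared with a const scalar type only ever reads through its pointer, so it can legally accept memory that must not be written (e.g. a NumPy array with ``writeable=False``). The free function below stands in for such a binding; the name ``sum`` and its signature are illustrative, not nanobind API.

```cpp
#include <cassert>
#include <cstddef>

// Illustration only: a binding declared as nb::ndarray<const float, ...>
// receives memory it may read but not write. This hypothetical function
// stands in for such a binding.
float sum(const float *data, std::size_t n) {
    float accum = 0.f;
    for (std::size_t i = 0; i < n; ++i)
        accum += data[i];   // read-only access is all that is required
    return accum;
}
```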
4 changes: 2 additions & 2 deletions include/nanobind/nb_lib.h
@@ -389,8 +389,8 @@ NB_CORE ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
NB_CORE ndarray_handle *ndarray_create(void *value, size_t ndim,
                                       const size_t *shape, PyObject *owner,
                                       const int64_t *strides,
-                                       dlpack::dtype *dtype, int32_t device,
-                                       int32_t device_id);
+                                       dlpack::dtype *dtype, bool ro,
+                                       int32_t device, int32_t device_id);

/// Increase the reference count of the given ndarray object; returns a pointer
/// to the underlying DLTensor
61 changes: 43 additions & 18 deletions include/nanobind/ndarray.h
@@ -112,6 +112,7 @@ struct ndarray_req {
size_t *shape = nullptr;
bool req_shape = false;
bool req_dtype = false;
bool req_ro = false;
char req_order = '\0';
uint8_t req_device = 0;
};
@@ -126,35 +127,44 @@ template <typename T> struct ndarray_arg<T, enable_if_t<std::is_floating_point_v
static constexpr size_t size = 0;

    static constexpr auto name =
-        const_name("dtype=float") + const_name<sizeof(T) * 8>();
+        const_name("dtype=float") +
+        const_name<sizeof(T) * 8>() +
+        const_name<std::is_const_v<T>>(", writable=False", "");

static void apply(ndarray_req &tr) {
tr.dtype = dtype<T>();
tr.req_dtype = true;
tr.req_ro = std::is_const_v<T>;
}
};

template <typename T> struct ndarray_arg<T, enable_if_t<std::is_integral_v<T> && !std::is_same_v<T, bool>>> {
static constexpr size_t size = 0;

    static constexpr auto name =
-        const_name("dtype=") + const_name<std::is_unsigned_v<T>>("u", "") +
-        const_name("int") + const_name<sizeof(T) * 8>();
+        const_name("dtype=") +
+        const_name<std::is_unsigned_v<T>>("u", "") +
+        const_name("int") + const_name<sizeof(T) * 8>() +
+        const_name<std::is_const_v<T>>(", writable=False", "");

static void apply(ndarray_req &tr) {
tr.dtype = dtype<T>();
tr.req_dtype = true;
tr.req_ro = std::is_const_v<T>;
}
};

template <typename T> struct ndarray_arg<T, enable_if_t<std::is_same_v<T, bool>>> {
static constexpr size_t size = 0;

-    static constexpr auto name = const_name("dtype=bool");
+    static constexpr auto name =
+        const_name("dtype=bool") +
+        const_name<std::is_const_v<T>>(", writable=False", "");

static void apply(ndarray_req &tr) {
tr.dtype = dtype<T>();
tr.req_dtype = true;
tr.req_ro = std::is_const_v<T>;
}
};
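
The ``const_name<std::is_const_v<T>>(", writable=False", "")`` terms above pick a docstring suffix at compile time. As a hedged sketch in ordinary runtime code (nanobind's real implementation composes these strings with its compile-time ``const_name`` machinery), the same decision looks like this; ``dtype_signature`` is a hypothetical helper, not part of nanobind.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>

// Hypothetical runtime equivalent of the compile-time signature logic that
// the PR adds to every ndarray_arg specialization.
template <typename T> std::string dtype_signature() {
    using U = std::remove_const_t<T>;
    std::string s = "dtype=";
    if (std::is_same_v<U, bool>) {
        s += "bool";
    } else if (std::is_floating_point_v<U>) {
        s += "float" + std::to_string(sizeof(U) * 8);
    } else {
        if (std::is_unsigned_v<U>)
            s += "u";
        s += "int" + std::to_string(sizeof(U) * 8);
    }
    if (std::is_const_v<T>)
        s += ", writable=False";   // the read-only marker added by this PR
    return s;
}
```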

@@ -248,17 +258,17 @@ template <typename... Args> class ndarray {
m_dltensor = *detail::ndarray_inc_ref(handle);
}

-    ndarray(void *value,
+    ndarray(std::conditional_t<std::is_const_v<Scalar>, const void *, void *> value,
            size_t ndim,
            const size_t *shape,
            handle owner = nanobind::handle(),
            const int64_t *strides = nullptr,
            dlpack::dtype dtype = nanobind::dtype<Scalar>(),
            int32_t device_type = device::cpu::value,
            int32_t device_id = 0) {
-        m_handle =
-            detail::ndarray_create(value, ndim, shape, owner.ptr(), strides,
-                                   &dtype, device_type, device_id);
+        m_handle = detail::ndarray_create(
+            (void *) value, ndim, shape, owner.ptr(), strides, &dtype,
+            std::is_const_v<Scalar>, device_type, device_id);
m_dltensor = *detail::ndarray_inc_ref(m_handle);
}

@@ -296,8 +306,8 @@ template <typename... Args> class ndarray {
size_t ndim() const { return (size_t) m_dltensor.ndim; }
size_t shape(size_t i) const { return (size_t) m_dltensor.shape[i]; }
int64_t stride(size_t i) const { return m_dltensor.strides[i]; }
-    int64_t* shape_ptr() const { return m_dltensor.shape; }
-    int64_t* stride_ptr() const { return m_dltensor.strides; }
+    const int64_t* shape_ptr() const { return m_dltensor.shape; }
+    const int64_t* stride_ptr() const { return m_dltensor.strides; }
bool is_valid() const { return m_handle != nullptr; }
int32_t device_type() const { return m_dltensor.device.device_type; }
int32_t device_id() const { return m_dltensor.device.device_id; }
@@ -317,10 +327,27 @@ template <typename... Args> class ndarray {
return (const Scalar *)((const uint8_t *) m_dltensor.data + m_dltensor.byte_offset);
}

-    Scalar *data() { return (Scalar *)((uint8_t *) m_dltensor.data + m_dltensor.byte_offset); }
+    template <typename T = Scalar, std::enable_if_t<!std::is_const_v<T>, int> = 1>
+    Scalar *data() {
+        return (Scalar *) ((uint8_t *) m_dltensor.data +
+                           m_dltensor.byte_offset);
+    }

template <typename T = Scalar,
std::enable_if_t<!std::is_const_v<T>, int> = 1, typename... Ts>
NB_INLINE auto &operator()(Ts... indices) {
return *(Scalar *) ((uint8_t *) m_dltensor.data +
byte_offset(indices...));
}

template <typename... Ts> NB_INLINE const auto & operator()(Ts... indices) const {
return *(const Scalar *) ((const uint8_t *) m_dltensor.data +
byte_offset(indices...));
}

private:
template <typename... Ts>
-    NB_INLINE auto& operator()(Ts... indices) {
+    NB_INLINE int64_t byte_offset(Ts... indices) const {
static_assert(
!std::is_same_v<Scalar, void>,
"To use nb::ndarray::operator(), you must add a scalar type "
@@ -331,15 +358,13 @@ template <typename... Args> class ndarray {
"annotation to the ndarray template parameters.");
static_assert(sizeof...(Ts) == Info::shape_type::size,
"nb::ndarray::operator(): invalid number of arguments");

-        int64_t counter = 0, index = 0;
+        size_t counter = 0;
+        int64_t index = 0;
        ((index += int64_t(indices) * m_dltensor.strides[counter++]), ...);
-        return (Scalar &) *(
-            (uint8_t *) m_dltensor.data + m_dltensor.byte_offset +
-            index * sizeof(typename Info::scalar_type));
+        return (int64_t) m_dltensor.byte_offset +
+               index * sizeof(typename Info::scalar_type);
    }

-private:
detail::ndarray_handle *m_handle = nullptr;
dlpack::dltensor m_dltensor;
};
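
The new private ``byte_offset()`` helper factors the strided index arithmetic out of both ``operator()`` overloads. A self-contained sketch of that computation (free-standing here, with the strides and element size passed explicitly rather than read from the DLPack tensor):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Sketch of the arithmetic byte_offset() performs: a C++17 fold expression
// accumulates index*stride products (strides are in elements, per DLPack),
// and the resulting element index is scaled to bytes.
template <typename... Ts>
std::int64_t byte_offset(const std::int64_t *strides, std::size_t elem_size,
                         Ts... indices) {
    std::size_t counter = 0;
    std::int64_t index = 0;
    ((index += std::int64_t(indices) * strides[counter++]), ...);
    return index * (std::int64_t) elem_size;
}
```

For a C-contiguous ``2x4`` float array the element strides are ``{4, 1}``, so indexing ``(1, 2)`` lands at byte ``(1*4 + 2*1) * sizeof(float)``.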
16 changes: 10 additions & 6 deletions src/nb_ndarray.cpp
@@ -20,6 +20,7 @@ struct ndarray_handle {
bool free_shape;
bool free_strides;
bool call_deleter;
bool ro;
};

static void nb_ndarray_dealloc(PyObject *self) {
@@ -107,7 +108,7 @@ static int nd_ndarray_tpbuffer(PyObject *exporter, Py_buffer *view, int) {

view->ndim = t.ndim;
view->len = len;
-    view->readonly = false;
+    view->readonly = self->th->ro;
view->suboffsets = nullptr;
view->internal = nullptr;
view->strides = strides.release();
@@ -157,11 +158,12 @@ static PyTypeObject *nd_ndarray_tp() noexcept {
return tp;
}

-static PyObject *dlpack_from_buffer_protocol(PyObject *o) {
+static PyObject *dlpack_from_buffer_protocol(PyObject *o, bool ro) {
    scoped_pymalloc<Py_buffer> view;
    scoped_pymalloc<managed_dltensor> mt;

-    if (PyObject_GetBuffer(o, view.get(), PyBUF_RECORDS)) {
+    if (PyObject_GetBuffer(o, view.get(),
+                           ro ? PyBUF_RECORDS_RO : PyBUF_RECORDS)) {
PyErr_Clear();
return nullptr;
}
@@ -307,7 +309,7 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,

// Try creating a ndarray via the buffer protocol
if (!capsule.is_valid())
-        capsule = steal(dlpack_from_buffer_protocol(o));
+        capsule = steal(dlpack_from_buffer_protocol(o, req->req_ro));

if (!capsule.is_valid())
return nullptr;
@@ -452,6 +454,7 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
result->owner = nullptr;
result->free_shape = false;
result->call_deleter = true;
result->ro = req->req_ro;
if (is_pycapsule) {
result->self = nullptr;
} else {
@@ -514,8 +517,8 @@ void ndarray_dec_ref(ndarray_handle *th) noexcept {

ndarray_handle *ndarray_create(void *value, size_t ndim, const size_t *shape_in,
                               PyObject *owner, const int64_t *strides_in,
-                               dlpack::dtype *dtype, int32_t device_type,
-                               int32_t device_id) {
+                               dlpack::dtype *dtype, bool ro,
+                               int32_t device_type, int32_t device_id) {
/* DLPack mandates 256-byte alignment of the 'DLTensor::data' field, but
PyTorch unfortunately ignores the 'byte_offset' value.. :-( */
#if 0
@@ -570,6 +573,7 @@ ndarray_handle *ndarray_create(void *value, size_t ndim, const size_t *shape_in,
result->free_shape = true;
result->free_strides = true;
result->call_deleter = false;
result->ro = ro;
Py_XINCREF(owner);
return result.release();
}
