diff --git a/src/array.rs b/src/array.rs index 40b8eb361..cdaead703 100644 --- a/src/array.rs +++ b/src/array.rs @@ -420,6 +420,9 @@ impl PyArray { /// Creates a NumPy array backed by `array` and ties its ownership to the Python object `container`. /// + /// The resulting NumPy array will be writeable from Python space. If this is undesireable, use + /// [PyReadwriteArray::make_nonwriteable]. + /// /// # Safety /// /// `container` is set as a base object of the returned array which must not be dropped until `container` is dropped. diff --git a/src/borrow/mod.rs b/src/borrow/mod.rs index a1cb203e3..bcfb4028c 100644 --- a/src/borrow/mod.rs +++ b/src/borrow/mod.rs @@ -15,7 +15,7 @@ //! ```rust //! # use std::panic::{catch_unwind, AssertUnwindSafe}; //! # -//! use numpy::{PyArray1, PyArrayMethods}; +//! use numpy::{PyArray1, PyArrayMethods, npyffi::flags}; //! use ndarray::Zip; //! use pyo3::{Python, Bound}; //! @@ -175,6 +175,7 @@ use crate::array::{PyArray, PyArrayMethods}; use crate::convert::NpyIndex; use crate::dtype::Element; use crate::error::{BorrowError, NotContiguousError}; +use crate::npyffi::flags; use crate::untyped_array::PyUntypedArrayMethods; use shared::{acquire, acquire_mut, release, release_mut}; @@ -453,6 +454,18 @@ where unsafe { &*(self as *const Self as *const Self::Target) } } } +impl<'py, T, D> From> for PyReadonlyArray<'py, T, D> +where + T: Element, + D: Dimension, +{ + fn from(value: PyReadwriteArray<'py, T, D>) -> Self { + let array = value.array.clone(); + ::std::mem::drop(value); + Self::try_new(array) + .expect("releasing an exclusive reference should immediately permit a shared reference") + } +} impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyReadwriteArray<'py, T, D> { fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult { @@ -494,6 +507,23 @@ where { unsafe { self.array.get_mut(index) } } + + /// Clear the [`WRITEABLE` flag][writeable] from the underlying NumPy array. + /// + /// Calling this will prevent any further [PyReadwriteArray]s from being taken out. Python + /// space can reset this flag, unless the additional flag [`OWNDATA`][owndata] is unset. Such + /// an array can be created from Rust space by using [PyArray::borrow_from_array_bound]. + /// + /// [writeable]: https://numpy.org/doc/stable/reference/c-api/array.html#c.NPY_ARRAY_WRITEABLE + /// [owndata]: https://numpy.org/doc/stable/reference/c-api/array.html#c.NPY_ARRAY_OWNDATA + pub fn make_nonwriteable(self) -> PyReadonlyArray<'py, T, D> { + // SAFETY: consuming the only extant mutable reference guarantees we cannot invalidate an + // existing reference, nor allow the caller to keep hold of one. + unsafe { + (*self.as_array_ptr()).flags &= !flags::NPY_ARRAY_WRITEABLE; + } + self.into() + } } #[cfg(feature = "nalgebra")] diff --git a/tests/borrow.rs b/tests/borrow.rs index 08e6abcc2..356046c65 100644 --- a/tests/borrow.rs +++ b/tests/borrow.rs @@ -348,6 +348,20 @@ fn resize_using_exclusive_borrow() { }); } +#[test] +fn can_make_python_array_nonwriteable() { + Python::with_gil(|py| { + let array = PyArray1::::zeros_bound(py, 10, false); + let locals = [("array", &array)].into_py_dict_bound(py); + array.readwrite().make_nonwriteable(); + assert!(!py + .eval_bound("array.flags.writeable", None, Some(&locals)) + .unwrap() + .extract::() + .unwrap()) + }) +} + #[cfg(feature = "nalgebra")] #[test] fn matrix_from_numpy() {