Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to clear WRITEABLE flag from PyArray #462

Merged
merged 2 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ impl<T: Element, D: Dimension> PyArray<T, D> {

/// Creates a NumPy array backed by `array` and ties its ownership to the Python object `container`.
///
/// The resulting NumPy array will be writeable from Python space. If this is undesireable, use
/// [PyReadwriteArray::make_nonwriteable].
///
/// # Safety
///
/// `container` is set as a base object of the returned array which must not be dropped until `container` is dropped.
Expand Down
32 changes: 31 additions & 1 deletion src/borrow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
//! ```rust
//! # use std::panic::{catch_unwind, AssertUnwindSafe};
//! #
//! use numpy::{PyArray1, PyArrayMethods};
//! use numpy::{PyArray1, PyArrayMethods, npyffi::flags};
//! use ndarray::Zip;
//! use pyo3::{Python, Bound};
//!
Expand Down Expand Up @@ -175,6 +175,7 @@ use crate::array::{PyArray, PyArrayMethods};
use crate::convert::NpyIndex;
use crate::dtype::Element;
use crate::error::{BorrowError, NotContiguousError};
use crate::npyffi::flags;
use crate::untyped_array::PyUntypedArrayMethods;

use shared::{acquire, acquire_mut, release, release_mut};
Expand Down Expand Up @@ -453,6 +454,18 @@ where
unsafe { &*(self as *const Self as *const Self::Target) }
}
}
impl<'py, T, D> From<PyReadwriteArray<'py, T, D>> for PyReadonlyArray<'py, T, D>
where
T: Element,
D: Dimension,
{
fn from(value: PyReadwriteArray<'py, T, D>) -> Self {
let array = value.array.clone();
::std::mem::drop(value);
Self::try_new(array)
.expect("releasing an exclusive reference should immediately permit a shared reference")
}
}

impl<'py, T: Element, D: Dimension> FromPyObject<'py> for PyReadwriteArray<'py, T, D> {
fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> {
Expand Down Expand Up @@ -494,6 +507,23 @@ where
{
unsafe { self.array.get_mut(index) }
}

/// Clear the [`WRITEABLE` flag][writeable] from the underlying NumPy array.
///
/// Calling this will prevent any further [PyReadwriteArray]s from being taken out. Python
/// space can reset this flag, unless the additional flag [`OWNDATA`][owndata] is unset. Such
/// an array can be created from Rust space by using [PyArray::borrow_from_array_bound].
///
/// [writeable]: https://numpy.org/doc/stable/reference/c-api/array.html#c.NPY_ARRAY_WRITEABLE
/// [owndata]: https://numpy.org/doc/stable/reference/c-api/array.html#c.NPY_ARRAY_OWNDATA
pub fn make_nonwriteable(self) -> PyReadonlyArray<'py, T, D> {
// SAFETY: consuming the only extant mutable reference guarantees we cannot invalidate an
// existing reference, nor allow the caller to keep hold of one.
unsafe {
(*self.as_array_ptr()).flags &= !flags::NPY_ARRAY_WRITEABLE;
}
self.into()
}
}

#[cfg(feature = "nalgebra")]
Expand Down
14 changes: 14 additions & 0 deletions tests/borrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,20 @@ fn resize_using_exclusive_borrow() {
});
}

#[test]
fn can_make_python_array_nonwriteable() {
Python::with_gil(|py| {
let array = PyArray1::<f64>::zeros_bound(py, 10, false);
let locals = [("array", &array)].into_py_dict_bound(py);
array.readwrite().make_nonwriteable();
assert!(!py
.eval_bound("array.flags.writeable", None, Some(&locals))
.unwrap()
.extract::<bool>()
.unwrap())
})
}

#[cfg(feature = "nalgebra")]
#[test]
fn matrix_from_numpy() {
Expand Down