Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Numpy 2.x #429

Closed
wants to merge 11 commits into from
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@ nalgebra = { version = "0.32", default-features = false, features = ["std"] }

[package.metadata.docs.rs]
all-features = true

[features]
default = ["numpy-1", "numpy-2"]
adamreichold marked this conversation as resolved.
Show resolved Hide resolved
numpy-1 = []
numpy-2 = []
4 changes: 2 additions & 2 deletions src/borrow/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rustc_hash::FxHashMap;
use crate::array::get_array_module;
use crate::cold;
use crate::error::BorrowError;
use crate::npyffi::{PyArrayObject, PyArray_Check, NPY_ARRAY_WRITEABLE};
use crate::npyffi::{PyArrayObject, PyArray_Check, PyDataType_ELSIZE, NPY_ARRAY_WRITEABLE};

/// Defines the shared C API used for borrow checking
///
Expand Down Expand Up @@ -403,7 +403,7 @@ fn data_range(array: *mut PyArrayObject) -> (*mut c_char, *mut c_char) {
let shape = unsafe { from_raw_parts((*array).dimensions as *mut usize, nd) };
let strides = unsafe { from_raw_parts((*array).strides, nd) };

let itemsize = unsafe { (*(*array).descr).elsize } as isize;
let itemsize = unsafe { PyDataType_ELSIZE((*array).descr) } as isize;

let mut start = 0;
let mut end = 0;
Expand Down
6 changes: 4 additions & 2 deletions src/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ use pyo3::{sync::GILProtected, Bound, Py, Python};
use rustc_hash::FxHashMap;

use crate::dtype::{Element, PyArrayDescr, PyArrayDescrMethods};
use crate::npyffi::{PyArray_DatetimeDTypeMetaData, NPY_DATETIMEUNIT, NPY_TYPES};
use crate::npyffi::{
PyArray_DatetimeDTypeMetaData, PyDataType_C_METADATA, NPY_DATETIMEUNIT, NPY_TYPES,
};

/// Represents the [datetime units][datetime-units] supported by NumPy
///
Expand Down Expand Up @@ -230,7 +232,7 @@ impl TypeDescriptors {

// SAFETY: `self.npy_type` is either `NPY_DATETIME` or `NPY_TIMEDELTA` which implies the type of `c_metadata`.
unsafe {
let metadata = &mut *((*dtype.as_dtype_ptr()).c_metadata
let metadata = &mut *(PyDataType_C_METADATA(dtype.as_dtype_ptr())
as *mut PyArray_DatetimeDTypeMetaData);

metadata.meta.base = unit;
Expand Down
44 changes: 24 additions & 20 deletions src/dtype.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use std::mem::size_of;
use std::os::raw::{
c_char, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort,
};
use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort};
use std::ptr;

#[cfg(feature = "half")]
Expand All @@ -19,8 +17,9 @@ use pyo3::{
use pyo3::{sync::GILOnceCell, Py};

use crate::npyffi::{
NpyTypes, PyArray_Descr, NPY_ALIGNED_STRUCT, NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES,
PY_ARRAY_API,
FlagType, NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
PyDataType_FLAGS, PyDataType_NAMES, PyDataType_SUBARRAY, NPY_ALIGNED_STRUCT,
NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, PY_ARRAY_API,
};

pub use num_complex::{Complex32, Complex64};
Expand Down Expand Up @@ -256,7 +255,7 @@ impl PyArrayDescr {
/// Equivalent to [`numpy.dtype.flags`][dtype-flags].
///
/// [dtype-flags]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.flags.html
pub fn flags(&self) -> c_char {
pub fn flags(&self) -> FlagType {
self.as_borrowed().flags()
}

Expand Down Expand Up @@ -397,7 +396,7 @@ pub trait PyArrayDescrMethods<'py>: Sealed {
/// [dtype-itemsiize]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.itemsize.html

fn itemsize(&self) -> usize {
unsafe { *self.as_dtype_ptr() }.elsize.max(0) as _
PyDataType_ELSIZE(self.as_dtype_ptr()).max(0) as _
}

/// Returns the required alignment (bytes) of this type descriptor according to the compiler.
Expand All @@ -406,7 +405,7 @@ pub trait PyArrayDescrMethods<'py>: Sealed {
///
/// [dtype-alignment]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.alignment.html
fn alignment(&self) -> usize {
unsafe { *self.as_dtype_ptr() }.alignment.max(0) as _
PyDataType_ALIGNMENT(self.as_dtype_ptr()).max(0) as _
}

/// Returns an ASCII character indicating the byte-order of this type descriptor object.
Expand Down Expand Up @@ -447,8 +446,8 @@ pub trait PyArrayDescrMethods<'py>: Sealed {
/// Equivalent to [`numpy.dtype.flags`][dtype-flags].
///
/// [dtype-flags]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.flags.html
fn flags(&self) -> c_char {
unsafe { *self.as_dtype_ptr() }.flags
fn flags(&self) -> FlagType {
PyDataType_FLAGS(self.as_dtype_ptr())
}

/// Returns the number of dimensions if this type descriptor represents a sub-array, and zero otherwise.
Expand All @@ -460,7 +459,7 @@ pub trait PyArrayDescrMethods<'py>: Sealed {
if !self.has_subarray() {
return 0;
}
unsafe { PyTuple_Size((*((*self.as_dtype_ptr()).subarray)).shape).max(0) as _ }
unsafe { PyTuple_Size((*PyDataType_SUBARRAY(self.as_dtype_ptr())).shape).max(0) as _ }
}

/// Returns the type descriptor for the base element of subarrays, regardless of their dimension or shape.
Expand Down Expand Up @@ -505,13 +504,13 @@ pub trait PyArrayDescrMethods<'py>: Sealed {
/// Returns true if the type descriptor is a sub-array.
fn has_subarray(&self) -> bool {
// equivalent to PyDataType_HASSUBARRAY(self)
unsafe { !(*self.as_dtype_ptr()).subarray.is_null() }
!PyDataType_SUBARRAY(self.as_dtype_ptr()).is_null()
}

/// Returns true if the type descriptor is a structured type.
fn has_fields(&self) -> bool {
// equivalent to PyDataType_HASFIELDS(self)
unsafe { !(*self.as_dtype_ptr()).names.is_null() }
!PyDataType_NAMES(self.as_dtype_ptr()).is_null()
}

/// Returns true if type descriptor byteorder is native, or `None` if not applicable.
Expand Down Expand Up @@ -581,8 +580,11 @@ impl<'py> PyArrayDescrMethods<'py> for Bound<'py, PyArrayDescr> {
self.clone()
} else {
unsafe {
Bound::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).base.cast())
.downcast_into_unchecked()
Bound::from_borrowed_ptr(
self.py(),
(*PyDataType_SUBARRAY(self.as_dtype_ptr())).base.cast(),
)
.downcast_into_unchecked()
}
}
}
Expand All @@ -592,17 +594,19 @@ impl<'py> PyArrayDescrMethods<'py> for Bound<'py, PyArrayDescr> {
Vec::new()
} else {
// NumPy guarantees that shape is a tuple of non-negative integers so this should never panic.
unsafe { Borrowed::from_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).shape) }
.extract()
.unwrap()
unsafe {
Borrowed::from_ptr(self.py(), (*PyDataType_SUBARRAY(self.as_dtype_ptr())).shape)
}
.extract()
.unwrap()
}
}

fn names(&self) -> Option<Vec<String>> {
if !self.has_fields() {
return None;
}
let names = unsafe { Borrowed::from_ptr(self.py(), (*self.as_dtype_ptr()).names) };
let names = unsafe { Borrowed::from_ptr(self.py(), PyDataType_NAMES(self.as_dtype_ptr())) };
names.extract().ok()
}

Expand All @@ -612,7 +616,7 @@ impl<'py> PyArrayDescrMethods<'py> for Bound<'py, PyArrayDescr> {
"cannot get field information: type descriptor has no fields",
));
}
let dict = unsafe { Borrowed::from_ptr(self.py(), (*self.as_dtype_ptr()).fields) };
let dict = unsafe { Borrowed::from_ptr(self.py(), PyDataType_FIELDS(self.as_dtype_ptr())) };
let dict = unsafe { dict.downcast_unchecked::<PyDict>() };
// NumPy guarantees that fields are tuples of proper size and type, so this should never panic.
let tuple = dict
Expand Down
Loading