diff --git a/src/array.rs b/src/array.rs index 73670c158..40b8eb361 100644 --- a/src/array.rs +++ b/src/array.rs @@ -122,7 +122,7 @@ pub type PyArrayDyn = PyArray; /// Returns a handle to NumPy's multiarray module. pub fn get_array_module<'py>(py: Python<'py>) -> PyResult> { - PyModule::import_bound(py, npyffi::array::MOD_NAME) + PyModule::import_bound(py, npyffi::array::mod_name(py)?) } impl DerefToPyAny for PyArray {} diff --git a/src/npyffi/array.rs b/src/npyffi/array.rs index a0971c2a3..9689e89df 100644 --- a/src/npyffi/array.rs +++ b/src/npyffi/array.rs @@ -14,7 +14,43 @@ use pyo3::{ use crate::npyffi::*; -pub(crate) const MOD_NAME: &str = "numpy._core.multiarray"; +pub(crate) fn numpy_core_name(py: Python<'_>) -> PyResult<&'static str> { + static MOD_NAME: GILOnceCell<&'static str> = GILOnceCell::new(); + + MOD_NAME + .get_or_try_init(py, || { + // numpy 2 renamed to numpy._core + + // strategy mirrored from https://github.com/pybind/pybind11/blob/af67e87393b0f867ccffc2702885eea12de063fc/include/pybind11/numpy.h#L175-L195 + + let numpy = PyModule::import_bound(py, "numpy")?; + let version_string = numpy.getattr("__version__")?; + + let numpy_lib = PyModule::import_bound(py, "numpy.lib")?; + let numpy_version = numpy_lib + .getattr("NumpyVersion")? + .call1((version_string,))?; + let major_version: u8 = numpy_version.getattr("major")?.extract()?; + + Ok(if major_version >= 2 { + "numpy._core" + } else { + "numpy.core" + }) + }) + .copied() +} + +pub(crate) fn mod_name(py: Python<'_>) -> PyResult<&'static str> { + static MOD_NAME: GILOnceCell = GILOnceCell::new(); + MOD_NAME + .get_or_try_init(py, || { + let numpy_core = numpy_core_name(py)?; + Ok(format!("{}.multiarray", numpy_core)) + }) + .map(String::as_str) +} + const CAPSULE_NAME: &str = "_ARRAY_API"; /// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html) @@ -49,7 +85,7 @@ impl PyArrayAPI { unsafe fn get<'py>(&self, py: Python<'py>, offset: isize) -> *const *const c_void { let api = self .0 - .get_or_try_init(py, || get_numpy_api(py, MOD_NAME, CAPSULE_NAME)) + .get_or_try_init(py, || get_numpy_api(py, mod_name(py)?, CAPSULE_NAME)) .expect("Failed to access NumPy array API capsule"); api.offset(offset) diff --git a/src/npyffi/ufunc.rs b/src/npyffi/ufunc.rs index 9f90e73f3..729a85f52 100644 --- a/src/npyffi/ufunc.rs +++ b/src/npyffi/ufunc.rs @@ -6,7 +6,16 @@ use pyo3::{ffi::PyObject, sync::GILOnceCell}; use crate::npyffi::*; -const MOD_NAME: &str = "numpy.core.umath"; +fn mod_name(py: Python<'_>) -> PyResult<&'static str> { + static MOD_NAME: GILOnceCell = GILOnceCell::new(); + MOD_NAME + .get_or_try_init(py, || { + let numpy_core = super::array::numpy_core_name(py)?; + Ok(format!("{}.umath", numpy_core)) + }) + .map(String::as_str) +} + const CAPSULE_NAME: &str = "_UFUNC_API"; /// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html) @@ -23,7 +32,7 @@ impl PyUFuncAPI { unsafe fn get<'py>(&self, py: Python<'py>, offset: isize) -> *const *const c_void { let api = self .0 - .get_or_try_init(py, || get_numpy_api(py, MOD_NAME, CAPSULE_NAME)) + .get_or_try_init(py, || get_numpy_api(py, mod_name(py)?, CAPSULE_NAME)) .expect("Failed to access NumPy ufunc API capsule"); api.offset(offset)