Skip to content

Commit

Permalink
Experiment with making float/int types inherit from np.floating/np.in…
Browse files Browse the repository at this point in the history
…teger.

PiperOrigin-RevId: 719426592
  • Loading branch information
Jake VanderPlas authored and The ml_dtypes Authors committed Jan 24, 2025
1 parent f1439a9 commit e3235e2
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 10 deletions.
2 changes: 1 addition & 1 deletion ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ bool RegisterFloatDtype(PyObject* numpy) {
// bases must be a tuple for Python 3.9 and earlier. Change to just pass
// the base type directly when dropping Python 3.9 support.
Safe_PyObjectPtr bases(
PyTuple_Pack(1, reinterpret_cast<PyObject*>(&PyGenericArrType_Type)));
PyTuple_Pack(1, reinterpret_cast<PyObject*>(&PyFloatingArrType_Type)));
PyObject* type =
PyType_FromSpecWithBases(&CustomFloatType<T>::type_spec, bases.get());
if (!type) {
Expand Down
8 changes: 4 additions & 4 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,10 @@ bool Initialize() {
return false;
}

if (!RegisterIntNDtype<int2>(numpy.get()) ||
!RegisterIntNDtype<uint2>(numpy.get()) ||
!RegisterIntNDtype<int4>(numpy.get()) ||
!RegisterIntNDtype<uint4>(numpy.get())) {
if (!RegisterIntNDtype<int2>(numpy.get(), /* is_signed= */ true) ||
!RegisterIntNDtype<uint2>(numpy.get(), /* is_signed= */ false) ||
!RegisterIntNDtype<int4>(numpy.get(), /* is_signed= */ true) ||
!RegisterIntNDtype<uint4>(numpy.get(), /* is_signed= */ false)) {
return false;
}

Expand Down
8 changes: 5 additions & 3 deletions ml_dtypes/_src/intn_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -765,11 +765,13 @@ bool RegisterIntNUFuncs(PyObject* numpy) {
}

template <typename T>
bool RegisterIntNDtype(PyObject* numpy) {
bool RegisterIntNDtype(PyObject* numpy, bool is_signed) {
// bases must be a tuple for Python 3.9 and earlier. Change to just pass
// the base type directly when dropping Python 3.9 support.
Safe_PyObjectPtr bases(
PyTuple_Pack(1, reinterpret_cast<PyObject*>(&PyGenericArrType_Type)));
Safe_PyObjectPtr bases(PyTuple_Pack(
1, is_signed
? (reinterpret_cast<PyObject*>(&PySignedIntegerArrType_Type))
: (reinterpret_cast<PyObject*>(&PyUnsignedIntegerArrType_Type))));
PyObject* type =
PyType_FromSpecWithBases(&IntNTypeDescriptor<T>::type_spec, bases.get());
if (!type) {
Expand Down
2 changes: 1 addition & 1 deletion ml_dtypes/tests/custom_float_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def testDeepCopyDoesNotAlterHash(self, float_type):
def testArray(self, float_type):
x = np.array([[1, 2, 4]], dtype=float_type)
self.assertEqual(float_type, x.dtype)
self.assertEqual("[[1 2 4]]", str(x))
self.assertEqual("[[1. 2. 4.]]", str(x))
np.testing.assert_equal(x, x)
numpy_assert_allclose(x, x, float_type=float_type)
self.assertTrue((x == x).all())
Expand Down
2 changes: 1 addition & 1 deletion ml_dtypes/tests/intn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def testDeepCopyDoesNotAlterHash(self, scalar_type):
def testArray(self, scalar_type):
if scalar_type == int2:
x = np.array([[-2, 1, 0, 1]], dtype=scalar_type)
self.assertEqual("[[-2 1 0 1]]", str(x))
self.assertEqual("[[-2 1 0 1]]", str(x))
else:
x = np.array([[1, 2, 3]], dtype=scalar_type)
self.assertEqual("[[1 2 3]]", str(x))
Expand Down

0 comments on commit e3235e2

Please sign in to comment.