Skip to content

Commit

Permalink
Experiment with making float/int types inherit from np.floating/np.in…
Browse files Browse the repository at this point in the history
…teger.

PiperOrigin-RevId: 719426592
  • Loading branch information
Jake VanderPlas authored and The ml_dtypes Authors committed Jan 24, 2025
1 parent f1439a9 commit ad01eeb
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ bool RegisterFloatDtype(PyObject* numpy) {
// bases must be a tuple for Python 3.9 and earlier. Change to just pass
// the base type directly when dropping Python 3.9 support.
Safe_PyObjectPtr bases(
PyTuple_Pack(1, reinterpret_cast<PyObject*>(&PyGenericArrType_Type)));
PyTuple_Pack(1, reinterpret_cast<PyObject*>(&PyFloatingArrType_Type)));
PyObject* type =
PyType_FromSpecWithBases(&CustomFloatType<T>::type_spec, bases.get());
if (!type) {
Expand Down
2 changes: 1 addition & 1 deletion ml_dtypes/_src/intn_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ bool RegisterIntNDtype(PyObject* numpy) {
// bases must be a tuple for Python 3.9 and earlier. Change to just pass
// the base type directly when dropping Python 3.9 support.
Safe_PyObjectPtr bases(
PyTuple_Pack(1, reinterpret_cast<PyObject*>(&PyGenericArrType_Type)));
PyTuple_Pack(1, reinterpret_cast<PyObject*>(&PyIntegerArrType_Type)));
PyObject* type =
PyType_FromSpecWithBases(&IntNTypeDescriptor<T>::type_spec, bases.get());
if (!type) {
Expand Down
2 changes: 1 addition & 1 deletion ml_dtypes/tests/custom_float_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def testDeepCopyDoesNotAlterHash(self, float_type):
def testArray(self, float_type):
x = np.array([[1, 2, 4]], dtype=float_type)
self.assertEqual(float_type, x.dtype)
self.assertEqual("[[1 2 4]]", str(x))
self.assertEqual("[[1. 2. 4.]]", str(x))
np.testing.assert_equal(x, x)
numpy_assert_allclose(x, x, float_type=float_type)
self.assertTrue((x == x).all())
Expand Down
2 changes: 1 addition & 1 deletion ml_dtypes/tests/intn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def testDeepCopyDoesNotAlterHash(self, scalar_type):
def testArray(self, scalar_type):
if scalar_type == int2:
x = np.array([[-2, 1, 0, 1]], dtype=scalar_type)
self.assertEqual("[[-2 1 0 1]]", str(x))
self.assertEqual("[[-2 1 0 1]]", str(x))
else:
x = np.array([[1, 2, 3]], dtype=scalar_type)
self.assertEqual("[[1 2 3]]", str(x))
Expand Down

0 comments on commit ad01eeb

Please sign in to comment.