diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index 9f292eba..a4ffe6a5 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -850,7 +850,7 @@ bool RegisterFloatDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass // the base type directly when dropping Python 3.9 support. Safe_PyObjectPtr bases( - PyTuple_Pack(1, reinterpret_cast(&PyGenericArrType_Type))); + PyTuple_Pack(1, reinterpret_cast(&PyFloatingArrType_Type))); PyObject* type = PyType_FromSpecWithBases(&CustomFloatType::type_spec, bases.get()); if (!type) { diff --git a/ml_dtypes/_src/intn_numpy.h b/ml_dtypes/_src/intn_numpy.h index e184e0b0..39e7ee73 100644 --- a/ml_dtypes/_src/intn_numpy.h +++ b/ml_dtypes/_src/intn_numpy.h @@ -769,7 +769,7 @@ bool RegisterIntNDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass // the base type directly when dropping Python 3.9 support. Safe_PyObjectPtr bases( - PyTuple_Pack(1, reinterpret_cast(&PyGenericArrType_Type))); + PyTuple_Pack(1, reinterpret_cast(&PyIntegerArrType_Type))); PyObject* type = PyType_FromSpecWithBases(&IntNTypeDescriptor::type_spec, bases.get()); if (!type) { diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py index 7eb313bb..e06e1cec 100644 --- a/ml_dtypes/tests/custom_float_test.py +++ b/ml_dtypes/tests/custom_float_test.py @@ -713,7 +713,7 @@ def testDeepCopyDoesNotAlterHash(self, float_type): def testArray(self, float_type): x = np.array([[1, 2, 4]], dtype=float_type) self.assertEqual(float_type, x.dtype) - self.assertEqual("[[1 2 4]]", str(x)) + self.assertEqual("[[1. 2. 4.]]", str(x)) np.testing.assert_equal(x, x) numpy_assert_allclose(x, x, float_type=float_type) self.assertTrue((x == x).all()) diff --git a/ml_dtypes/tests/intn_test.py b/ml_dtypes/tests/intn_test.py index 86ab5a81..df02ab71 100644 --- a/ml_dtypes/tests/intn_test.py +++ b/ml_dtypes/tests/intn_test.py @@ -274,7 +274,7 @@ def testDeepCopyDoesNotAlterHash(self, scalar_type): def testArray(self, scalar_type): if scalar_type == int2: x = np.array([[-2, 1, 0, 1]], dtype=scalar_type) - self.assertEqual("[[-2 1 0 1]]", str(x)) + self.assertEqual("[[-2 1 0 1]]", str(x)) else: x = np.array([[1, 2, 3]], dtype=scalar_type) self.assertEqual("[[1 2 3]]", str(x))