Skip to content

Commit

Permalink
Merge pull request #203 from vfdev-5:reenable-mt-tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 716020733
  • Loading branch information
The ml_dtypes Authors committed Jan 16, 2025
2 parents 5f1240a + 4ac1090 commit f1439a9
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 34 deletions.
48 changes: 28 additions & 20 deletions ml_dtypes/tests/custom_float_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
# from multi_thread_utils import multi_threaded
from multi_thread_utils import multi_threaded
import numpy as np

bfloat16 = ml_dtypes.bfloat16
Expand Down Expand Up @@ -221,12 +221,16 @@ def dtype_is_signed(dtype):
}


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# pylint: disable=g-complex-comprehension
# @multi_threaded(
# num_workers=3,
# skip_tests=["testDiv", "testRoundTripNumpyTypes", "testRoundTripToNumpy"],
# )
@multi_threaded(
num_workers=3,
skip_tests=[
"testDiv",
"testPickleable",
"testRoundTripNumpyTypes",
"testRoundTripToNumpy",
],
)
@parameterized.named_parameters(
(
{"testcase_name": "_" + dtype.__name__, "float_type": dtype}
Expand Down Expand Up @@ -661,21 +665,25 @@ def testDtypeFromString(self, float_type):
]


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# pylint: disable=g-complex-comprehension
# @multi_threaded(
# num_workers=3,
# skip_tests=[
# "testBinaryUfunc",
# "testConformNumpyComplex",
# "testFloordivCornerCases",
# "testDivmodCornerCases",
# "testSpacing",
# "testUnaryUfunc",
# "testCasts",
# "testLdexp",
# ],
# )
@multi_threaded(
num_workers=3,
skip_tests=[
"testBinaryPredicateUfunc",
"testBinaryUfunc",
"testCasts",
"testConformNumpyComplex",
"testDivmod",
"testDivmodCornerCases",
"testFloordivCornerCases",
"testFrexp",
"testLdexp",
"testModf",
"testPredicateUfunc",
"testSpacing",
"testUnaryUfunc",
],
)
@parameterized.named_parameters(
(
{"testcase_name": "_" + dtype.__name__, "float_type": dtype}
Expand Down
5 changes: 2 additions & 3 deletions ml_dtypes/tests/finfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
# from multi_thread_utils import multi_threaded
from multi_thread_utils import multi_threaded
import numpy as np

ALL_DTYPES = [
Expand Down Expand Up @@ -55,8 +55,7 @@
}


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# @multi_threaded(num_workers=3)
@multi_threaded(num_workers=3)
class FinfoTest(parameterized.TestCase):

def assertNanEqual(self, x, y):
Expand Down
5 changes: 2 additions & 3 deletions ml_dtypes/tests/iinfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
# from multi_thread_utils import multi_threaded
from multi_thread_utils import multi_threaded
import numpy as np


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# @multi_threaded(num_workers=3)
@multi_threaded(num_workers=3)
class IinfoTest(parameterized.TestCase):

def testIinfoInt2(self):
Expand Down
8 changes: 3 additions & 5 deletions ml_dtypes/tests/intn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
# from multi_thread_utils import multi_threaded
from multi_thread_utils import multi_threaded
import numpy as np

int2 = ml_dtypes.int2
Expand All @@ -48,9 +48,8 @@ def ignore_warning(**kw):
yield


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# Tests for the Python scalar type
# @multi_threaded(num_workers=3)
@multi_threaded(num_workers=3)
class ScalarTest(parameterized.TestCase):

@parameterized.product(scalar_type=INTN_TYPES)
Expand Down Expand Up @@ -247,9 +246,8 @@ def testCanCast(self, a, b):
)


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# Tests for the Python scalar type
# @multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"])
@multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"])
class ArrayTest(parameterized.TestCase):

@parameterized.product(scalar_type=INTN_TYPES)
Expand Down
5 changes: 2 additions & 3 deletions ml_dtypes/tests/metadata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

from absl.testing import absltest
import ml_dtypes
# from multi_thread_utils import multi_threaded
from multi_thread_utils import multi_threaded


# TODO(jakevdp): re-enable multi-threaded tests after 0.5.0 release.
# @multi_threaded(num_workers=3)
@multi_threaded(num_workers=3)
class CustomFloatTest(absltest.TestCase):

def test_version_matches_package_metadata(self):
Expand Down

0 comments on commit f1439a9

Please sign in to comment.