diff --git a/clif/testing/python/t3_test.py b/clif/testing/python/t3_test.py index 0dbb1a9..d2eaefe 100644 --- a/clif/testing/python/t3_test.py +++ b/clif/testing/python/t3_test.py @@ -14,14 +14,16 @@ """Tests for clif.testing.python.t3.""" +import pickle import sys from absl.testing import absltest +from absl.testing import parameterized from clif.testing.python import t3 -class T3Test(absltest.TestCase): +class T3Test(parameterized.TestCase): def testEnumBasic(self): self.assertEqual(t3._Old.TOP1, 1) @@ -37,12 +39,16 @@ def testEnumBasic(self): def testEnumModuleName(self): # This is necessary for proper pickling. - if t3.__pyclif_codegen_mode__ == 'pybind11': - self.assertEqual(t3._Old.__module__.__name__, t3.__name__) - self.assertEqual(t3._New.__module__.__name__, t3.__name__) - else: - self.assertEqual(t3._Old.__module__, t3.__name__) - self.assertEqual(t3._New.__module__, t3.__name__) + self.assertEqual(t3._Old.__module__, t3.__name__) + self.assertEqual(t3._New.__module__, t3.__name__) + + @parameterized.parameters( + t3._Old.TOP1, t3._Old.TOPn, t3._New.TOP, t3._New.BOTTOM + ) + def testPickleRoundtrip(self, orig): + serialized = pickle.dumps(orig) + restored = pickle.loads(serialized) + self.assertEqual(restored, orig) def testEnumBases(self): old_bases = t3._Old.__bases__