Skip to content

Commit

Permalink
Fix checkpointing bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jngrad committed Aug 23, 2024
1 parent 1ef5b12 commit 940373f
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/python/espressomd/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def get_last_checkpoint_index(self):
def save(self, checkpoint_index=None):
"""
Saves all registered python objects in the given checkpoint directory
using cPickle.
using pickle.
"""
# get attributes of registered objects
Expand Down
23 changes: 12 additions & 11 deletions src/python/espressomd/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,19 @@ def __init__(self, **kwargs):
if "sip" in kwargs:
super().__init__(**kwargs)
self._setup_atexit()
return
super().__init__(_regular_constructor=True, **kwargs)
if has_features("CUDA"):
self.cuda_init_handle = cuda_init.CudaInitHandle()
if has_features("WALBERLA"):
self._lb = None
self._ekcontainer = None
self._ase_interface = None
else:
super().__init__(_regular_constructor=True, **kwargs)
if has_features("CUDA"):
self.cuda_init_handle = cuda_init.CudaInitHandle()
if has_features("WALBERLA"):
self._lb = None
self._ekcontainer = None

# lock class
self.call_method("lock_system_creation")
self._setup_atexit()
# lock class
self.call_method("lock_system_creation")
self._setup_atexit()

self._ase_interface = None

def _setup_atexit(self):
import atexit
Expand Down
2 changes: 1 addition & 1 deletion testsuite/python/save_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@
system.electrostatics.solver = p3m
p3m.charge_neutrality_tolerance = 5e-12

if "ase" in sys.modules:
if has_ase and "ase" in sys.modules:
system.ase = espressomd.plugins.ase.ASEInterface(
type_mapping={0: "H", 1: "O", 10: "Cl"},
)
Expand Down
13 changes: 7 additions & 6 deletions testsuite/python/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class CheckpointTest(ut.TestCase):
checkpoint = espressomd.checkpointing.Checkpoint(
**config.get_checkpoint_params())
checkpoint.load(0)
checkpoint.save(1)
path_cpt_root = pathlib.Path(checkpoint.checkpoint_dir)

@classmethod
Expand All @@ -78,7 +79,7 @@ def setUpClass(cls):
cls.ref_periodicity = np.array([False, False, False])

@utx.skipIfMissingFeatures(["WALBERLA"])
@ut.skipIf(not has_lb_mode, "Skipping test due to missing LB feature.")
@ut.skipIf(not has_lb_mode, "Skipping test due to missing LB mode.")
def test_lb_fluid(self):
lbf = system.lb
cpt_mode = 0 if 'LB.ASCII' in modes else 1
Expand Down Expand Up @@ -164,7 +165,7 @@ def test_lb_fluid(self):
np.copy(lbf[:, :, :].is_boundary.astype(int)), 0)

@utx.skipIfMissingFeatures(["WALBERLA"])
@ut.skipIf(not has_lb_mode, "Skipping test due to missing EK feature.")
@ut.skipIf(not has_lb_mode, "Skipping test due to missing EK mode.")
def test_ek_species(self):
cpt_mode = 0 if 'LB.ASCII' in modes else 1
cpt_root = pathlib.Path(self.checkpoint.checkpoint_dir)
Expand Down Expand Up @@ -269,7 +270,7 @@ def generator(value, shape):

@utx.skipIfMissingFeatures(["WALBERLA"])
@ut.skipIf('LB.GPU' in modes, 'VTK not implemented for LB GPU')
@ut.skipIf(not has_lb_mode, "Skipping test due to missing LB feature.")
@ut.skipIf(not has_lb_mode, "Skipping test due to missing LB mode.")
def test_lb_vtk(self):
lbf = system.lb
self.assertEqual(len(lbf.vtk_writers), 2)
Expand Down Expand Up @@ -316,7 +317,7 @@ def test_lb_vtk(self):
(vtk_root / filename.format(2)).unlink(missing_ok=True)

@utx.skipIfMissingFeatures(["WALBERLA"])
@ut.skipIf(not has_lb_mode, "Skipping test due to missing EK feature.")
@ut.skipIf(not has_lb_mode, "Skipping test due to missing EK mode.")
def test_ek_vtk(self):
ek_species = system.ekcontainer[0]
vtk_suffix = config.test_name
Expand Down Expand Up @@ -507,7 +508,7 @@ def test_shape_based_constraints_serialization(self):
self.assertGreater(np.linalg.norm(np.copy(p3.f) - old_force), 1e6)

@utx.skipIfMissingFeatures(["WALBERLA"])
@ut.skipIf(not has_lb_mode, "Skipping test due to missing LB feature.")
@ut.skipIf(not has_lb_mode, "Skipping test due to missing LB mode.")
@ut.skipIf('THERM.LB' not in modes, 'LB thermostat not in modes')
def test_thermostat_LB(self):
thmst = system.thermostat.lb
Expand Down Expand Up @@ -916,7 +917,7 @@ def test_scafacos_dipoles(self):
self.assertEqual(state[key], reference[key], msg=f'for {key}')

def test_comfixed(self):
self.assertEqual(list(system.comfixed.types), [0, 2])
self.assertEqual(set(system.comfixed.types), {0, 2})

@utx.skipIfMissingFeatures('COLLISION_DETECTION')
def test_collision_detection(self):
Expand Down

0 comments on commit 940373f

Please sign in to comment.