diff --git a/serialization/src/TorchForceProxy.cpp b/serialization/src/TorchForceProxy.cpp index cb3a3af9..1d41a396 100644 --- a/serialization/src/TorchForceProxy.cpp +++ b/serialization/src/TorchForceProxy.cpp @@ -45,11 +45,26 @@ void TorchForceProxy::serialize(const void* object, SerializationNode& node) con node.setIntProperty("version", 1); const TorchForce& force = *reinterpret_cast(object); node.setStringProperty("file", force.getFile()); + node.setIntProperty("forceGroup", force.getForceGroup()); + node.setBoolProperty("usesPeriodic", force.usesPeriodicBoundaryConditions()); + SerializationNode& globalParams = node.createChildNode("GlobalParameters"); + for (int i = 0; i < force.getNumGlobalParameters(); i++) { + globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i)); + } } void* TorchForceProxy::deserialize(const SerializationNode& node) const { if (node.getIntProperty("version") != 1) throw OpenMMException("Unsupported version number"); TorchForce* force = new TorchForce(node.getStringProperty("file")); + if (node.hasProperty("forceGroup")) + force->setForceGroup(node.getIntProperty("forceGroup", 0)); + if (node.hasProperty("usesPeriodic")) + force->setUsesPeriodicBoundaryConditions(node.getBoolProperty("usesPeriodic")); + for (const SerializationNode& child : node.getChildren()) { + if (child.getName() == "GlobalParameters") + for (auto& parameter : child.getChildren()) + force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default")); + } return force; } diff --git a/serialization/tests/TestSerializeTorchForce.cpp b/serialization/tests/TestSerializeTorchForce.cpp index 4d4cc0a3..3a58019f 100644 --- a/serialization/tests/TestSerializeTorchForce.cpp +++ b/serialization/tests/TestSerializeTorchForce.cpp @@ -46,6 +46,10 @@ void testSerialization() { // Create a Force. TorchForce force("module.pt"); + force.setForceGroup(3); + force.addGlobalParameter("x", 1.3); + force.addGlobalParameter("y", 2.221); + force.setUsesPeriodicBoundaryConditions(true); // Serialize and then deserialize it. @@ -57,6 +61,13 @@ void testSerialization() { TorchForce& force2 = *copy; ASSERT_EQUAL(force.getFile(), force2.getFile()); + ASSERT_EQUAL(force.getForceGroup(), force2.getForceGroup()); + ASSERT_EQUAL(force.getNumGlobalParameters(), force2.getNumGlobalParameters()); + for (int i = 0; i < force.getNumGlobalParameters(); i++) { + ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i)); + ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i)); + } + ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions()); } int main() {