Skip to content

Commit

Permalink
Added missing information to serialization proxy (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman authored May 28, 2021
1 parent 6472507 commit 76f55e7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
15 changes: 15 additions & 0 deletions serialization/src/TorchForceProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,26 @@ void TorchForceProxy::serialize(const void* object, SerializationNode& node) con
node.setIntProperty("version", 1);
const TorchForce& force = *reinterpret_cast<const TorchForce*>(object);
node.setStringProperty("file", force.getFile());
node.setIntProperty("forceGroup", force.getForceGroup());
node.setBoolProperty("usesPeriodic", force.usesPeriodicBoundaryConditions());
SerializationNode& globalParams = node.createChildNode("GlobalParameters");
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParams.createChildNode("Parameter").setStringProperty("name", force.getGlobalParameterName(i)).setDoubleProperty("default", force.getGlobalParameterDefaultValue(i));
}
}

void* TorchForceProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
TorchForce* force = new TorchForce(node.getStringProperty("file"));
if (node.hasProperty("forceGroup"))
force->setForceGroup(node.getIntProperty("forceGroup", 0));
if (node.hasProperty("usesPeriodic"))
force->setUsesPeriodicBoundaryConditions(node.getBoolProperty("usesPeriodic"));
for (const SerializationNode& child : node.getChildren()) {
if (child.getName() == "GlobalParameters")
for (auto& parameter : child.getChildren())
force->addGlobalParameter(parameter.getStringProperty("name"), parameter.getDoubleProperty("default"));
}
return force;
}
11 changes: 11 additions & 0 deletions serialization/tests/TestSerializeTorchForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ void testSerialization() {
// Create a Force.

TorchForce force("module.pt");
force.setForceGroup(3);
force.addGlobalParameter("x", 1.3);
force.addGlobalParameter("y", 2.221);
force.setUsesPeriodicBoundaryConditions(true);

// Serialize and then deserialize it.

Expand All @@ -57,6 +61,13 @@ void testSerialization() {

TorchForce& force2 = *copy;
ASSERT_EQUAL(force.getFile(), force2.getFile());
ASSERT_EQUAL(force.getForceGroup(), force2.getForceGroup());
ASSERT_EQUAL(force.getNumGlobalParameters(), force2.getNumGlobalParameters());
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
ASSERT_EQUAL(force.getGlobalParameterName(i), force2.getGlobalParameterName(i));
ASSERT_EQUAL(force.getGlobalParameterDefaultValue(i), force2.getGlobalParameterDefaultValue(i));
}
ASSERT_EQUAL(force.usesPeriodicBoundaryConditions(), force2.usesPeriodicBoundaryConditions());
}

int main() {
Expand Down

0 comments on commit 76f55e7

Please sign in to comment.