Skip to content

Commit

Permalink
Update bindings and test for getNearestROM
Browse files Browse the repository at this point in the history
  • Loading branch information
ckendrick committed Jan 3, 2024
1 parent f97fa9d commit 4547bd0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
9 changes: 8 additions & 1 deletion bindings/pylibROM/algo/greedy/pyGreedySampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,14 @@ void init_GreedySampler(pybind11::module_ &m) {
.def("setPointRelativeError", (void (GreedySampler::*) (double))&GreedySampler::setPointRelativeError)
.def("setPointErrorIndicator", (void (GreedySampler::*) (double,int)) &GreedySampler::setPointErrorIndicator)
.def("getNearestNonSampledPoint", (int (GreedySampler::*) (CAROM::Vector)) &GreedySampler::getNearestNonSampledPoint)
.def("getNearestROM", &GreedySampler::getNearestROM)
.def("getNearestROM", [](GreedySampler& self, Vector point) -> std::unique_ptr<Vector> {
std::shared_ptr<Vector> result = self.getNearestROM(point);
if (!result)
{
return nullptr;
}
return std::make_unique<Vector>(*(result.get()));
})
.def("getParameterPointDomain", &GreedySampler::getParameterPointDomain)
.def("getSampledParameterPoints", &GreedySampler::getSampledParameterPoints)
.def("save", &GreedySampler::save)
Expand Down
35 changes: 35 additions & 0 deletions tests/test_pyGreedyCustomSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def test_greedy_save_and_load():
closestROM = caromGreedySampler.getNearestROM(pointToFindNearestROM)
closestROMLoad = caromGreedySamplerLoad.getNearestROM(pointToFindNearestROM)

# there were no points sampled, so closestROM should be None
assert closestROM is None
assert closestROM == closestROMLoad

nextPointToSample = caromGreedySampler.getNextParameterPoint()
Expand All @@ -158,5 +160,38 @@ def test_greedy_save_and_load():
assert nextPointToSample.item(0) == nextPointToSampleLoad.item(0)


def test_greedy_save_and_load_with_sample():
paramPoints = [1.0, 2.0, 3.0, 99.0, 100., 101.0]

caromGreedySampler = greedy.GreedyCustomSampler(paramPoints, False, 0.1, 1, 1, 3, 4, "", "", False, 1, True)

nextPointToSample = caromGreedySampler.getNextParameterPoint()
assert nextPointToSample.dim() == 1
assert nextPointToSample.item(0) == 3.0

# save after sampling a point to test if sampled points are restored
caromGreedySampler.save("greedy_test")

caromGreedySamplerLoad = greedy.GreedyCustomSampler("greedy_test")
caromGreedySamplerLoad.save("greedy_test_LOAD")

pointToFindNearestROM = linalg.Vector(1, False)
pointToFindNearestROM[0] = 1.0

closestROM = caromGreedySampler.getNearestROM(pointToFindNearestROM)
closestROMLoad = caromGreedySamplerLoad.getNearestROM(pointToFindNearestROM)

assert closestROM is not None
assert closestROM.dim() == 1
assert closestROM.dim() == closestROMLoad.dim()
assert closestROM.item(0) == 3.0
assert closestROM.item(0) == closestROMLoad.item(0)

nextPointToSample = caromGreedySampler.getNextParameterPoint()
nextPointToSampleLoad = caromGreedySamplerLoad.getNextParameterPoint()

assert nextPointToSample.dim() == nextPointToSampleLoad.dim()
assert nextPointToSample.item(0) == nextPointToSampleLoad.item(0)

if __name__ == '__main__':
pytest.main()

0 comments on commit 4547bd0

Please sign in to comment.