diff --git a/bindings/pylibROM/algo/greedy/pyGreedySampler.cpp b/bindings/pylibROM/algo/greedy/pyGreedySampler.cpp index 29e2f70..5112765 100644 --- a/bindings/pylibROM/algo/greedy/pyGreedySampler.cpp +++ b/bindings/pylibROM/algo/greedy/pyGreedySampler.cpp @@ -123,7 +123,14 @@ void init_GreedySampler(pybind11::module_ &m) { .def("setPointRelativeError", (void (GreedySampler::*) (double))&GreedySampler::setPointRelativeError) .def("setPointErrorIndicator", (void (GreedySampler::*) (double,int)) &GreedySampler::setPointErrorIndicator) .def("getNearestNonSampledPoint", (int (GreedySampler::*) (CAROM::Vector)) &GreedySampler::getNearestNonSampledPoint) - .def("getNearestROM", &GreedySampler::getNearestROM) + .def("getNearestROM", [](GreedySampler& self, Vector point) -> std::unique_ptr { + std::shared_ptr result = self.getNearestROM(point); + if (!result) + { + return nullptr; + } + return std::make_unique(*(result.get())); + }) .def("getParameterPointDomain", &GreedySampler::getParameterPointDomain) .def("getSampledParameterPoints", &GreedySampler::getSampledParameterPoints) .def("save", &GreedySampler::save) diff --git a/tests/test_pyGreedyCustomSampler.py b/tests/test_pyGreedyCustomSampler.py index 5b3ade5..3245694 100644 --- a/tests/test_pyGreedyCustomSampler.py +++ b/tests/test_pyGreedyCustomSampler.py @@ -149,6 +149,8 @@ def test_greedy_save_and_load(): closestROM = caromGreedySampler.getNearestROM(pointToFindNearestROM) closestROMLoad = caromGreedySamplerLoad.getNearestROM(pointToFindNearestROM) + # there were no points sampled, so closestROM should be None + assert closestROM is None assert closestROM == closestROMLoad nextPointToSample = caromGreedySampler.getNextParameterPoint() @@ -158,5 +160,38 @@ def test_greedy_save_and_load(): assert nextPointToSample.item(0) == nextPointToSampleLoad.item(0) +def test_greedy_save_and_load_with_sample(): + paramPoints = [1.0, 2.0, 3.0, 99.0, 100., 101.0] + + caromGreedySampler = greedy.GreedyCustomSampler(paramPoints, False, 0.1, 1, 1, 3, 4, "", "", False, 1, True) + + nextPointToSample = caromGreedySampler.getNextParameterPoint() + assert nextPointToSample.dim() == 1 + assert nextPointToSample.item(0) == 3.0 + + # save after sampling a point to test if sampled points are restored + caromGreedySampler.save("greedy_test") + + caromGreedySamplerLoad = greedy.GreedyCustomSampler("greedy_test") + caromGreedySamplerLoad.save("greedy_test_LOAD") + + pointToFindNearestROM = linalg.Vector(1, False) + pointToFindNearestROM[0] = 1.0 + + closestROM = caromGreedySampler.getNearestROM(pointToFindNearestROM) + closestROMLoad = caromGreedySamplerLoad.getNearestROM(pointToFindNearestROM) + + assert closestROM is not None + assert closestROM.dim() == 1 + assert closestROM.dim() == closestROMLoad.dim() + assert closestROM.item(0) == 1.0 + assert closestROM.item(0) == closestROMLoad.item(0) + + nextPointToSample = caromGreedySampler.getNextParameterPoint() + nextPointToSampleLoad = caromGreedySamplerLoad.getNextParameterPoint() + + assert nextPointToSample.dim() == nextPointToSampleLoad.dim() + assert nextPointToSample.item(0) == nextPointToSampleLoad.item(0) + if __name__ == '__main__': pytest.main()