diff --git a/CMakeLists.txt b/CMakeLists.txt index e7a3c2781..c89eb483e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_BUILD_TYPE Release) set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH}) -set(CMAKE_CXX_FLAGS_RELEASE "-O3 -funroll-loops -fopenmp -ftree-vectorize -Wno-deprecated") +set(CMAKE_CXX_FLAGS_RELEASE "-O3 -funroll-loops -ftree-vectorize -Wno-deprecated") set(CMAKE_CXX_FLAGS_DEBUG "-Og") set(CMAKE_FIND_PACKAGE_PREFER_CONFIG TRUE) set(CMAKE_FIND_PACKAGE_TARGETS_GLOBAL TRUE) diff --git a/python/bayesmixpy/run.py b/python/bayesmixpy/run.py index 34f0b19e2..c330299c5 100644 --- a/python/bayesmixpy/run.py +++ b/python/bayesmixpy/run.py @@ -9,7 +9,7 @@ from .shell_utils import get_env_file, run_shell from .io_utils import maybe_print_proto_to_file, read_many_protos_from_file -from .proto.algorithm_state_pb2 import AlgorithmState + @@ -104,6 +104,9 @@ def run_mcmc( See https://bayesmix.readthedocs.io/en/latest/protos.html#algorithm_state.proto for furhter details on the AlgorithmState object """ + if return_chains is not None: + from .proto.algorithm_state_pb2 import AlgorithmState + load_dotenv(get_env_file()) BAYESMIX_EXE = os.getenv("BAYESMIX_EXE") if BAYESMIX_EXE is None: diff --git a/python/tests/test_run.py b/python/tests/test_run.py index fcf66e066..5bb76ce89 100644 --- a/python/tests/test_run.py +++ b/python/tests/test_run.py @@ -37,32 +37,38 @@ def test_run_mcmc(): data = get_data() grid = get_grid() - eval_dens, nclus, clus, best_clus = run_mcmc( + eval_dens, nclus, clus, best_clus, chains = run_mcmc( "NNIG", "DP", data, GO_PARAMS, DP_PARAMS, ALGO_PARAMS, grid, return_clusters=False, - return_num_clusters=False, return_best_clus=False) + return_num_clusters=False, return_best_clus=False, + return_chains=False) assert eval_dens.shape[0] == 5 assert eval_dens.shape[1] == len(grid) assert nclus is None assert clus is None assert best_clus is None + assert chains is None - eval_dens, nclus, clus, best_clus = run_mcmc( + eval_dens, nclus, clus, best_clus, chains = run_mcmc( "NNIG", "DP", data, GO_PARAMS, DP_PARAMS, ALGO_PARAMS, None, return_clusters=False, - return_num_clusters=True, return_best_clus=False) + return_num_clusters=True, return_best_clus=False, + return_chains=False) assert eval_dens is None assert nclus is not None assert len(nclus) == 5 assert clus is None assert best_clus is None + assert chains is None - eval_dens, nclus, clus, best_clus = run_mcmc( + eval_dens, nclus, clus, best_clus, chains = run_mcmc( "NNIG", "DP", data, GO_PARAMS, DP_PARAMS, ALGO_PARAMS, None, return_clusters=True, - return_num_clusters=False, return_best_clus=False) + return_num_clusters=False, return_best_clus=False, + return_chains=False) + assert chains is None assert eval_dens is None assert nclus is None @@ -70,14 +76,28 @@ def test_run_mcmc(): assert clus.shape[0] == 5 assert clus.shape[1] == len(data) assert best_clus is None + assert chains is None - eval_dens, nclus, clus, best_clus = run_mcmc( + eval_dens, nclus, clus, best_clus, chains = run_mcmc( "NNIG", "DP", data, GO_PARAMS, DP_PARAMS, ALGO_PARAMS, None, return_clusters=False, - return_num_clusters=False, return_best_clus=True) + return_num_clusters=False, return_best_clus=True, + return_chains=False) assert eval_dens is None assert nclus is None assert clus is None assert best_clus is not None assert len(best_clus) == len(data) + assert chains is None + + + eval_dens, nclus, clus, best_clus, chains = run_mcmc( + "NNIG", "DP", data, GO_PARAMS, DP_PARAMS, + ALGO_PARAMS, None, return_clusters=False, + return_num_clusters=False, return_best_clus=False) + assert eval_dens is None + assert nclus is None + assert clus is None + assert best_clus is None + assert len(chains) == 5