Skip to content

Commit

Permalink
quick hack
Browse files Browse the repository at this point in the history
  • Loading branch information
mberaha committed Sep 21, 2023
1 parent a0bb760 commit 4cb9727
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_BUILD_TYPE Release)

set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH})
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -funroll-loops -fopenmp -ftree-vectorize -Wno-deprecated")
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -funroll-loops -ftree-vectorize -Wno-deprecated")
set(CMAKE_CXX_FLAGS_DEBUG "-Og")
set(CMAKE_FIND_PACKAGE_PREFER_CONFIG TRUE)
set(CMAKE_FIND_PACKAGE_TARGETS_GLOBAL TRUE)
Expand Down
5 changes: 4 additions & 1 deletion python/bayesmixpy/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .shell_utils import get_env_file, run_shell
from .io_utils import maybe_print_proto_to_file, read_many_protos_from_file
from .proto.algorithm_state_pb2 import AlgorithmState




Expand Down Expand Up @@ -104,6 +104,9 @@ def run_mcmc(
See https://bayesmix.readthedocs.io/en/latest/protos.html#algorithm_state.proto
for furhter details on the AlgorithmState object
"""
if return_chains is not None:
from .proto.algorithm_state_pb2 import AlgorithmState

load_dotenv(get_env_file())
BAYESMIX_EXE = os.getenv("BAYESMIX_EXE")
if BAYESMIX_EXE is None:
Expand Down
36 changes: 28 additions & 8 deletions python/tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,47 +37,67 @@ def test_run_mcmc():
data = get_data()
grid = get_grid()

eval_dens, nclus, clus, best_clus = run_mcmc(
eval_dens, nclus, clus, best_clus, chains = run_mcmc(
"NNIG", "DP", data, GO_PARAMS, DP_PARAMS,
ALGO_PARAMS, grid, return_clusters=False,
return_num_clusters=False, return_best_clus=False)
return_num_clusters=False, return_best_clus=False,
return_chains=False)

assert eval_dens.shape[0] == 5
assert eval_dens.shape[1] == len(grid)
assert nclus is None
assert clus is None
assert best_clus is None
assert chains is None

eval_dens, nclus, clus, best_clus = run_mcmc(
eval_dens, nclus, clus, best_clus, chains = run_mcmc(
"NNIG", "DP", data, GO_PARAMS, DP_PARAMS,
ALGO_PARAMS, None, return_clusters=False,
return_num_clusters=True, return_best_clus=False)
return_num_clusters=True, return_best_clus=False,
return_chains=False)

assert eval_dens is None
assert nclus is not None
assert len(nclus) == 5
assert clus is None
assert best_clus is None
assert chains is None

eval_dens, nclus, clus, best_clus = run_mcmc(
eval_dens, nclus, clus, best_clus, chains = run_mcmc(
"NNIG", "DP", data, GO_PARAMS, DP_PARAMS,
ALGO_PARAMS, None, return_clusters=True,
return_num_clusters=False, return_best_clus=False)
return_num_clusters=False, return_best_clus=False,
return_chains=False)
assert chains is None

assert eval_dens is None
assert nclus is None
assert clus is not None
assert clus.shape[0] == 5
assert clus.shape[1] == len(data)
assert best_clus is None
assert chains is None

eval_dens, nclus, clus, best_clus = run_mcmc(
eval_dens, nclus, clus, best_clus, chains = run_mcmc(
"NNIG", "DP", data, GO_PARAMS, DP_PARAMS,
ALGO_PARAMS, None, return_clusters=False,
return_num_clusters=False, return_best_clus=True)
return_num_clusters=False, return_best_clus=True,
return_chains=False)

assert eval_dens is None
assert nclus is None
assert clus is None
assert best_clus is not None
assert len(best_clus) == len(data)
assert chains is None


eval_dens, nclus, clus, best_clus, chains = run_mcmc(
"NNIG", "DP", data, GO_PARAMS, DP_PARAMS,
ALGO_PARAMS, None, return_clusters=False,
return_num_clusters=False, return_best_clus=False)
assert eval_dens is None
assert nclus is None
assert clus is None
assert best_clus is None
assert len(chains) == 5

0 comments on commit 4cb9727

Please sign in to comment.