Skip to content

Commit

Permalink
test if these needed
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Nov 5, 2024
1 parent 25dde66 commit 91e1877
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
1 change: 0 additions & 1 deletion cobaya/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def run(info_or_yaml_or_file: Union[InputDict, str, os.PathLike],
# as early as possible, e.g. to check if resuming possible or `force` needed.
if no_mpi or test:
mpi.set_mpi_disabled()
mpi.sync_processes()
with mpi.ProcessState("run"):
flags = {packages_path_input: packages_path, "debug": debug,
"stop_at_error": stop_at_error, "resume": resume, "force": force,
Expand Down
16 changes: 16 additions & 0 deletions cobaya/typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, Any, Optional, Union, Type, TypedDict, Literal, Mapping, \
Callable, Sequence, Iterable
from types import MappingProxyType
import contextlib
import typing
import numbers
import numpy as np
Expand Down Expand Up @@ -112,6 +113,21 @@ class InputDict(ModelDict, total=False):
enforce_type_checking = None


@contextlib.contextmanager
def type_checking(value: bool):
"""
Context manager to temporarily set typing.enforce_type_checking to a specific value.
Restores the original value when exiting the context.
"""
global enforce_type_checking
original_value = enforce_type_checking
enforce_type_checking = value
try:
yield
finally:
enforce_type_checking = original_value


def validate_type(expected_type: type, value: Any, path: str = ''):
"""
Checks for soft compatibility of a value with a type.
Expand Down
5 changes: 3 additions & 2 deletions tests/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from cobaya.tools import KL_norm
from cobaya.yaml import yaml_load
from .common_sampler import body_of_sampler_test, body_of_test_speeds
from cobaya.typing import type_checking

pytestmark = pytest.mark.mpi

Expand Down Expand Up @@ -39,7 +40,7 @@ def test_mcmc(tmpdir, temperature, do_plots, packages_path=None):

def check_gaussian(sampler_instance):
if not len(sampler_instance.collection) or \
not len(sampler_instance.collection[int(sampler_instance.n() / 2):]):
not len(sampler_instance.collection[int(sampler_instance.n() / 2):]):
return
proposer = KL_norm(
S1=sampler_instance.model.likelihood["gaussian_mixture"].covs[0],
Expand Down Expand Up @@ -162,7 +163,7 @@ def test_mcmc_sync():
logger.info('Test error synchronization')
if mpi.rank() == 0:
info['sampler']['mcmc'] = {'max_samples': 'bad_val'}
with NoLogging(logging.ERROR), pytest.raises(TypeError):
with NoLogging(logging.ERROR), pytest.raises(TypeError), type_checking(False):
run(info)
else:
with pytest.raises(mpi.OtherProcessError):
Expand Down

0 comments on commit 91e1877

Please sign in to comment.