Skip to content

Commit

Permalink
fixup! Simplify simulation arguements
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Jun 14, 2024
1 parent 15ed636 commit 040a5f2
Show file tree
Hide file tree
Showing 23 changed files with 81 additions and 171 deletions.
14 changes: 6 additions & 8 deletions src/ert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,20 +387,18 @@ def get_ert_parser(parser: Optional[ArgumentParser] = None) -> ArgumentParser:
alternative_option="--current-ensemble",
dest="current_ensemble",
default="default",
help="Deprecated: This argument is deprecated and will be "
"removed in future versions. Use --current-ensemble instead.",
help="Deprecated: This argument is deprecated and has no effect.",
)
ensemble_smoother_parser.add_argument(
"--current-ensemble",
type=valid_name,
default="default",
help="Name of the ensemble where the results for the experiment "
"using the prior parameters will be stored.",
help="This argument is deprecated and has no effect.",
)
ensemble_smoother_parser.add_argument(
"--target-case",
type=valid_name,
default="posterior",
type=valid_name_format,
default="iter-%d",
action=DeprecatedAction,
alternative_option="--target-ensemble",
dest="target_ensemble",
Expand All @@ -409,8 +407,8 @@ def get_ert_parser(parser: Optional[ArgumentParser] = None) -> ArgumentParser:
)
ensemble_smoother_parser.add_argument(
"--target-ensemble",
type=valid_name,
default="posterior",
type=valid_name_format,
default="iter_%d",
dest="target_ensemble",
help="Name of the ensemble where the results for the "
"updated parameters will be stored.",
Expand Down
2 changes: 1 addition & 1 deletion src/ert/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def run_cli(args: Namespace, plugin_manager: Optional[ErtPluginManager] = None)
status_queue,
)
except ValueError as e:
raise ErtCliError(e) from e
raise ErtCliError(f"{args.mode} was not valid, failed with: {e}") from e

if args.port_range is None and model.queue_system == QueueSystem.LOCAL:
args.port_range = range(49152, 51819)
Expand Down
6 changes: 4 additions & 2 deletions src/ert/cli/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,13 @@ def _setup_multiple_data_assimilation(
weights=args.weights,
restart_run=restart_run,
prior_ensemble_id=prior_ensemble,
starting_iteration=storage.get_ensemble(prior_ensemble).iteration + 1
if restart_run
else 0,
minimum_required_realizations=config.analysis_config.minimum_required_realizations,
ensemble_size=config.model_config.num_realizations,
stop_long_running=config.analysis_config.stop_long_running,
experiment_name=args.experiment_name,
starting_iteration=args.starting_iteration,
),
config,
storage,
Expand All @@ -269,7 +271,7 @@ def _setup_iterative_ensemble_smoother(
args, config.model_config.num_realizations
).tolist(),
target_ensemble=_iterative_ensemble_format(config, args),
num_iterations=_num_iterations(config, args),
number_of_iterations=_num_iterations(config, args),
minimum_required_realizations=config.analysis_config.minimum_required_realizations,
ensemble_size=config.model_config.num_realizations,
num_retries_per_iter=config.analysis_config.num_retries_per_iter,
Expand Down
8 changes: 1 addition & 7 deletions src/ert/gui/simulation/multiple_data_assimilation_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class Arguments:
restart_run: bool
prior_ensemble_id: str # UUID not serializable in json
experiment_name: str
starting_iteration: int


class MultipleDataAssimilationPanel(ExperimentConfigPanel):
Expand Down Expand Up @@ -166,9 +165,9 @@ def updateVisualizationOfNormalizedWeights() -> None:
normalized_weights_model.setValue(
", ".join(f"{x:.2f}" for x in normalized_weights)
)
self.weights_valid = True
except ValueError:
normalized_weights_model.setValue("The weights are invalid!")
self.weights_valid = True
else:
normalized_weights_model.setValue("The weights are invalid!")

Expand Down Expand Up @@ -203,11 +202,6 @@ def get_experiment_arguments(self) -> Arguments:
if self._experiment_name_field.text()
else self._experiment_name_field.placeholderText()
),
starting_iteration=(
str(self._ensemble_selector.selected_ensemble.iteration + 1)
if self._restart_box.isChecked()
else 0
),
)

def setWeights(self, weights: Any) -> None:
Expand Down
7 changes: 5 additions & 2 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def __init__(
self._exception: Optional[Exception] = None
self._error_messages: MutableSequence[str] = []
self._queue_config: QueueConfig = queue_config
self._initial_realizations_mask: List[bool] = []
self._initial_realizations_mask: List[bool] = copy.copy(
simulation_arguments.active_realizations
)
self._completed_realizations_mask: List[bool] = []
self.support_restart: bool = True
self.ert_config = config
Expand All @@ -176,9 +178,10 @@ def __init__(
# mapping from iteration number to ensemble id
self._iter_map: Dict[int, str] = {}
self._context_env_keys: List[str] = []
self.random_seed: int = _seed_sequence(self._simulation_arguments.random_seed)
self.random_seed: int = _seed_sequence(simulation_arguments.random_seed)
self.rng = np.random.default_rng(self.random_seed)
self.substitution_list = config.substitution_list

self.run_paths = Runpaths(
jobname_format=config.model_config.jobname_format_string,
runpath_format=config.model_config.runpath_format_string,
Expand Down
6 changes: 2 additions & 4 deletions src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def run_experiment(
self, evaluator_server_config: EvaluatorServerConfig
) -> RunContext:
self.checkHaveSufficientRealizations(
self._simulation_arguments.active_realizations.count(True),
self.active_realizations.count(True),
self.minimum_required_realizations,
)

Expand All @@ -87,9 +87,7 @@ def run_experiment(
prior_context = RunContext(
ensemble=prior,
runpaths=self.run_paths,
initial_mask=np.array(
self._simulation_arguments.active_realizations, dtype=bool
),
initial_mask=np.array(self.active_realizations, dtype=bool),
iteration=0,
)

Expand Down
4 changes: 1 addition & 3 deletions src/ert/run_models/evaluate_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ def run_experiment(
prior_context = RunContext(
ensemble=ensemble,
runpaths=self.run_paths,
initial_mask=np.array(
self._simulation_arguments.active_realizations, dtype=bool
),
initial_mask=np.array(self.active_realizations, dtype=bool),
iteration=ensemble.iteration,
)

Expand Down
11 changes: 5 additions & 6 deletions src/ert/run_models/multiple_data_assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,15 @@ def run_experiment(
prior_context = RunContext(
ensemble=prior,
runpaths=self.run_paths,
initial_mask=np.array(
self._simulation_arguments.active_realizations, dtype=bool
),
initial_mask=np.array(self.active_realizations, dtype=bool),
iteration=prior.iteration,
)
self.prev_successful_realizations = (
prior.get_realization_mask_without_failure().sum()
)
if self.starting_iteration != prior.iteration + 1:
if self.start_iteration != prior.iteration + 1:
raise ValueError(
f"Experiment misconfigured, got starting iteration: {self.starting_iteration},"
f"Experiment misconfigured, got starting iteration: {self.start_iteration},"
f"restart iteration = {prior.iteration + 1}"
)
except (KeyError, ValueError) as err:
Expand Down Expand Up @@ -242,7 +240,8 @@ def parseWeights(weights: str) -> List[float]:
result.append(f)
except ValueError as e:
raise ValueError(f"Warning: cannot parse weight {element}") from e

if not result:
raise ValueError(f"Invalid weights: {weights}")
weights = [weight for weight in result if abs(weight) != 0.0]

Check failure on line 245 in src/ert/run_models/multiple_data_assimilation.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Incompatible types in assignment (expression has type "list[float]", variable has type "str")

length = sum(1.0 / x for x in weights)

Check failure on line 247 in src/ert/run_models/multiple_data_assimilation.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Generator has incompatible item type "float"; expected "bool"

Check failure on line 247 in src/ert/run_models/multiple_data_assimilation.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Unsupported operand types for / ("float" and "str")
Expand Down
15 changes: 2 additions & 13 deletions src/ert/run_models/run_arguments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import List, Optional


@dataclass
Expand Down Expand Up @@ -52,17 +52,6 @@ class ESMDARunArguments(SimulationArguments):
@dataclass
class SIESRunArguments(SimulationArguments):
target_ensemble: str
num_iterations: int
num_retries_per_iter: int
ensemble_type: str = "IES"
number_of_iterations: int = 3


RunArgumentsType = Union[
SingleTestRunArguments,
EnsembleExperimentRunArguments,
EvaluateEnsembleRunArguments,
ESRunArguments,
ESMDARunArguments,
SIESRunArguments,
]
ensemble_type: str = "IES"
6 changes: 3 additions & 3 deletions src/ert/storage/local_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
if TYPE_CHECKING:
from ert.config.parameter_config import ParameterConfig
from ert.run_models.run_arguments import (
RunArgumentsType,
SimulationArguments,
)
from ert.storage.local_ensemble import LocalEnsemble
from ert.storage.local_storage import LocalStorage
Expand Down Expand Up @@ -100,7 +100,7 @@ def create(
parameters: Optional[List[ParameterConfig]] = None,
responses: Optional[List[ResponseConfig]] = None,
observations: Optional[Dict[str, xr.Dataset]] = None,
simulation_arguments: Optional[RunArgumentsType] = None,
simulation_arguments: Optional[SimulationArguments] = None,
name: Optional[str] = None,
) -> LocalExperiment:
"""
Expand All @@ -120,7 +120,7 @@ def create(
List of response configurations.
observations : dict of str: xr.Dataset, optional
Observations dictionary.
simulation_arguments : RunArgumentsType, optional
simulation_arguments : SimulationArguments, optional
Simulation arguments for the experiment.
name : str, optional
Experiment name. Defaults to current date if None.
Expand Down
11 changes: 6 additions & 5 deletions src/ert/storage/local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

if TYPE_CHECKING:
from ert.config import ParameterConfig, ResponseConfig
from ert.run_models.run_arguments import RunArgumentsType
from ert.run_models.run_arguments import SimulationArguments

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -147,7 +147,7 @@ def get_experiment(self, uuid: UUID) -> LocalExperiment:

return self._experiments[uuid]

def get_ensemble(self, uuid: UUID) -> LocalEnsemble:
def get_ensemble(self, uuid: Union[UUID, str]) -> LocalEnsemble:
"""
Retrieves an ensemble by UUID.
Expand All @@ -160,7 +160,8 @@ def get_ensemble(self, uuid: UUID) -> LocalEnsemble:
local_ensemble : LocalEnsemble
The ensemble associated with the given UUID.
"""

if isinstance(uuid, str):
uuid = UUID(uuid)
return self._ensembles[uuid]

def get_ensemble_by_name(self, name: str) -> LocalEnsemble:
Expand Down Expand Up @@ -297,7 +298,7 @@ def create_experiment(
parameters: Optional[List[ParameterConfig]] = None,
responses: Optional[List[ResponseConfig]] = None,
observations: Optional[Dict[str, xr.Dataset]] = None,
simulation_arguments: Optional[RunArgumentsType] = None,
simulation_arguments: Optional[SimulationArguments] = None,
name: Optional[str] = None,
) -> LocalExperiment:
"""
Expand All @@ -311,7 +312,7 @@ def create_experiment(
The responses for the experiment.
observations : dict of str to Dataset, optional
The observations for the experiment.
simulation_arguments : RunArgumentsType, optional
simulation_arguments : SimulationArguments, optional
The simulation arguments for the experiment.
name : str, optional
The name of the experiment.
Expand Down
11 changes: 2 additions & 9 deletions tests/integration_tests/analysis/test_adaptive_localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,9 @@


def run_cli_ES_with_case(poly_config):
config_name = poly_config.split(".")[0]
prior_sample_name = "prior_sample" + "_" + config_name
posterior_sample_name = "posterior_sample" + "_" + config_name
run_cli(
ENSEMBLE_SMOOTHER_MODE,
"--disable-monitor",
"--current-case",
prior_sample_name,
"--target-case",
posterior_sample_name,
"--realizations",
"1-50",
poly_config,
Expand All @@ -30,8 +23,8 @@ def run_cli_ES_with_case(poly_config):
)
storage_path = ErtConfig.from_file(poly_config).ens_path
with open_storage(storage_path) as storage:
prior_ensemble = storage.get_ensemble_by_name(prior_sample_name)
posterior_ensemble = storage.get_ensemble_by_name(posterior_sample_name)
prior_ensemble = storage.get_ensemble_by_name("iter-0")
posterior_ensemble = storage.get_ensemble_by_name("iter-1")
return prior_ensemble, posterior_ensemble


Expand Down
26 changes: 8 additions & 18 deletions tests/integration_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,15 @@ def test_that_posterior_has_lower_variance_than_prior():
run_cli(
ENSEMBLE_SMOOTHER_MODE,
"--disable-monitor",
"--current-case",
"default",
"--target-case",
"target",
"--realizations",
"1-50",
"poly.ert",
)
facade = LibresFacade.from_config_file("poly.ert")
with open_storage(facade.enspath) as storage:
prior_ensemble = storage.get_ensemble_by_name("default")
prior_ensemble = storage.get_ensemble_by_name("iter-0")
df_default = prior_ensemble.load_all_gen_kw_data()
posterior_ensemble = storage.get_ensemble_by_name("target")
posterior_ensemble = storage.get_ensemble_by_name("iter-1")
df_target = posterior_ensemble.load_all_gen_kw_data()

# The std for the ensemble should decrease
Expand Down Expand Up @@ -142,16 +138,14 @@ def sample_prior(nx, ny):
ENSEMBLE_SMOOTHER_MODE,
"--disable-monitor",
"snake_oil_surface.ert",
"--target-case",
"es_udpate",
)

ert_config = ErtConfig.from_file("snake_oil_surface.ert")

storage = open_storage(ert_config.ens_path)

ens_prior = storage.get_ensemble_by_name("default")
ens_posterior = storage.get_ensemble_by_name("es_udpate")
ens_prior = storage.get_ensemble_by_name("iter-0")
ens_posterior = storage.get_ensemble_by_name("iter-1")

# Check that surfaces defined in INIT_FILES are not changed by ERT
surf_prior = ens_prior.load_parameters("TOP", list(range(ensemble_size)))["values"]
Expand Down Expand Up @@ -182,15 +176,13 @@ def test_update_multiple_param():
ENSEMBLE_SMOOTHER_MODE,
"--disable-monitor",
"snake_oil.ert",
"--target-case",
"posterior",
)

ert_config = ErtConfig.from_file("snake_oil.ert")

storage = open_storage(ert_config.ens_path)
prior_ensemble = storage.get_ensemble_by_name("default")
posterior_ensemble = storage.get_ensemble_by_name("posterior")
prior_ensemble = storage.get_ensemble_by_name("iter-0")
posterior_ensemble = storage.get_ensemble_by_name("iter-1")

prior_array = _all_parameters(prior_ensemble, list(range(10)))
posterior_array = _all_parameters(posterior_ensemble, list(range(10)))
Expand Down Expand Up @@ -480,15 +472,13 @@ def _evaluate(coeffs, x):
ENSEMBLE_SMOOTHER_MODE,
"--disable-monitor",
"poly.ert",
"--target-case",
"posterior",
)

ert_config = ErtConfig.from_file("poly.ert")

with open_storage(ert_config.ens_path) as storage:
prior = storage.get_ensemble_by_name("default")
posterior = storage.get_ensemble_by_name("posterior")
prior = storage.get_ensemble_by_name("iter-0")
posterior = storage.get_ensemble_by_name("iter-1")

assert all(
posterior.get_ensemble_state()[idx]
Expand Down
Loading

0 comments on commit 040a5f2

Please sign in to comment.