Skip to content

Commit

Permalink
Non optional everest model and nonzero realizations (#9577)
Browse files Browse the repository at this point in the history
For EverestConfig require Model and realizations.len() > 0
  • Loading branch information
StephanDeHoop authored Jan 7, 2025
1 parent 00b015f commit eb9b04a
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 44 deletions.
10 changes: 4 additions & 6 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ class EverestConfig(BaseModelWithPropertySupport): # type: ignore
default=OptimizationConfig(),
description="Optimizer options",
)
model: ModelConfig | None = Field(
default=ModelConfig(),
model: ModelConfig = Field(
description="Configuration of the Everest model",
)

Expand Down Expand Up @@ -396,13 +395,11 @@ def validate_model_data_file_exists(self) -> Self: # pylint: disable=E0213
@model_validator(mode="after")
def validate_maintained_forward_models(self) -> Self:
install_data = self.install_data
model = self.model
realizations = model.realizations if model else [0]

with InstallDataContext(install_data, self.config_path) as context:
for realization in realizations:
for realization in self.model.realizations:
context.add_links_for_realization(realization)
validate_forward_model_configs(self.forward_model, self.install_jobs)
validate_forward_model_configs(self.forward_model, self.install_jobs)
return self

@model_validator(mode="after")
Expand Down Expand Up @@ -702,6 +699,7 @@ def with_defaults(cls, **kwargs):
"controls": [],
"objective_functions": [],
"config_path": ".",
"model": {"realizations": [0]},
}

return EverestConfig.model_validate({**defaults, **kwargs})
Expand Down
21 changes: 10 additions & 11 deletions src/everest/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

class ModelConfig(BaseModel, extra="forbid"): # type: ignore
realizations: list[NonNegativeInt] = Field(
default_factory=lambda: [],
description="""List of realizations to use in optimization ensemble.
Typically, this is a list [0, 1, ..., n-1] of all realizations in the ensemble.""",
min_length=1,
)
data_file: str | None = Field(
default=None,
Expand All @@ -27,26 +27,25 @@ class ModelConfig(BaseModel, extra="forbid"): # type: ignore
@model_validator(mode="before")
@classmethod
def remove_deprecated(cls, values):
if values is None:
return values

if values.get("report_steps") is not None:
ConfigWarning.warn(
"report_steps no longer has any effect and can be removed."
)
values.pop("report_steps")
return values

@model_validator(mode="before")
@classmethod
def validate_realizations_weights_same_cardinaltiy(cls, values): # pylint: disable=E0213
weights = values.get("realizations_weights")
reals = values.get("realizations")

@model_validator(mode="after")
def validate_realizations_weights_same_cardinaltiy(self): # pylint: disable=E0213
weights = self.realizations_weights
if not weights:
return values
return self

if len(weights) != len(reals):
if len(weights) != len(self.realizations):
raise ValueError(
"Specified realizations_weights must have one"
" weight per specified realization in realizations"
)

return values
return self
14 changes: 6 additions & 8 deletions tests/everest/entry_points/test_everest_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,15 +504,13 @@ def test_complete_status_for_normal_run_monitor(
return_value={"status": ServerStatus.never_run, "message": None},
)
def test_validate_ert_config_before_starting_everest_server(
server_is_running_mock, server_status_mock, tmpdir, monkeypatch
server_is_running_mock, server_status_mock, copy_math_func_test_data_to_tmp
):
path = tmpdir / "new_folder"
os.makedirs(path)
monkeypatch.chdir(path)
config_file = path / "minimal_config.yml"
config_file = "config_minimal.yml"
everest_config = EverestConfig.with_defaults()
everest_config.model.realizations = []
everest_config.dump(config_file)
everest_config.config_path = Path(config_file).absolute()
error = "Expected realizations when analysing data installation source"
with pytest.raises(SystemExit, match=f"Config validation error: {error}"):
everest_config.config_path = Path(config_file)

with pytest.raises(SystemExit):
everest_entry([str(everest_config.config_path)])
42 changes: 23 additions & 19 deletions tests/everest/test_config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from argparse import ArgumentParser
from pathlib import Path
from typing import Any
from unittest.mock import patch

import pytest
from pydantic import ValidationError
Expand Down Expand Up @@ -894,15 +895,16 @@ def test_that_missing_required_fields_cause_error():
error_dicts = e.value.errors()

# Expect missing error for:
# controls, objective_functions, config_path
assert len(error_dicts) == 3
# controls, objective_functions, config_path, model
assert len(error_dicts) == 4

config_with_defaults = EverestConfig.with_defaults()
config_args = {}
required_argnames = [
"controls",
"objective_functions",
"config_path",
"model",
]

for key in required_argnames:
Expand Down Expand Up @@ -960,24 +962,26 @@ def test_that_non_existing_workflow_jobs_cause_error():
],
)
def test_warning_forward_model_write_objectives(objective, forward_model, warning_msg):
if warning_msg is not None:
with pytest.warns(ConfigWarning, match=warning_msg):
EverestConfig.with_defaults(
objective_functions=[{"name": o} for o in objective],
forward_model=forward_model,
)
else:
with warnings.catch_warnings():
warnings.simplefilter("error", category=ConfigWarning)
EverestConfig.with_defaults(
objective_functions=[{"name": o} for o in objective],
forward_model=forward_model,
)


def test_deprecated_keyword():
# model.realizations is non-empty, so this test would otherwise run full validation of the forward model schema; we don't want that for this test
with patch("everest.config.everest_config.validate_forward_model_configs"):
if warning_msg is not None:
with pytest.warns(ConfigWarning, match=warning_msg):
EverestConfig.with_defaults(
objective_functions=[{"name": o} for o in objective],
forward_model=forward_model,
)
else:
with warnings.catch_warnings():
warnings.simplefilter("error", category=ConfigWarning)
EverestConfig.with_defaults(
objective_functions=[{"name": o} for o in objective],
forward_model=forward_model,
)


def test_deprecated_keyword_report_steps():
with pytest.warns(ConfigWarning, match="report_steps .* can be removed"):
ModelConfig(**{"report_steps": []})
ModelConfig(**{"realizations": [0], "report_steps": []})


def test_load_file_non_existing():
Expand Down

0 comments on commit eb9b04a

Please sign in to comment.