diff --git a/src/everest/config/everest_config.py b/src/everest/config/everest_config.py index 2c352b9e384..853e3857c9a 100644 --- a/src/everest/config/everest_config.py +++ b/src/everest/config/everest_config.py @@ -146,8 +146,7 @@ class EverestConfig(BaseModelWithPropertySupport): # type: ignore default=OptimizationConfig(), description="Optimizer options", ) - model: ModelConfig | None = Field( - default=ModelConfig(), + model: ModelConfig = Field( description="Configuration of the Everest model", ) @@ -396,13 +395,11 @@ def validate_model_data_file_exists(self) -> Self: # pylint: disable=E0213 @model_validator(mode="after") def validate_maintained_forward_models(self) -> Self: install_data = self.install_data - model = self.model - realizations = model.realizations if model else [0] with InstallDataContext(install_data, self.config_path) as context: - for realization in realizations: + for realization in self.model.realizations: context.add_links_for_realization(realization) - validate_forward_model_configs(self.forward_model, self.install_jobs) + validate_forward_model_configs(self.forward_model, self.install_jobs) return self @model_validator(mode="after") @@ -702,6 +699,7 @@ def with_defaults(cls, **kwargs): "controls": [], "objective_functions": [], "config_path": ".", + "model": {"realizations": [0]}, } return EverestConfig.model_validate({**defaults, **kwargs}) diff --git a/src/everest/config/model_config.py b/src/everest/config/model_config.py index 9fe95affe02..81d69aa83c8 100644 --- a/src/everest/config/model_config.py +++ b/src/everest/config/model_config.py @@ -5,10 +5,10 @@ class ModelConfig(BaseModel, extra="forbid"): # type: ignore realizations: list[NonNegativeInt] = Field( - default_factory=lambda: [], description="""List of realizations to use in optimization ensemble. 
Typically, this is a list [0, 1, ..., n-1] of all realizations in the ensemble.""", + min_length=1, ) data_file: str | None = Field( default=None, @@ -27,6 +27,9 @@ class ModelConfig(BaseModel, extra="forbid"): # type: ignore @model_validator(mode="before") @classmethod def remove_deprecated(cls, values): + if values is None: + return values + if values.get("report_steps") is not None: ConfigWarning.warn( "report_steps no longer has any effect and can be removed." @@ -34,19 +37,15 @@ def remove_deprecated(cls, values): values.pop("report_steps") return values - @model_validator(mode="before") - @classmethod - def validate_realizations_weights_same_cardinaltiy(cls, values): # pylint: disable=E0213 - weights = values.get("realizations_weights") - reals = values.get("realizations") - + @model_validator(mode="after") + def validate_realizations_weights_same_cardinaltiy(self): # pylint: disable=E0213 + weights = self.realizations_weights if not weights: - return values + return self - if len(weights) != len(reals): + if len(weights) != len(self.realizations): raise ValueError( "Specified realizations_weights must have one" " weight per specified realization in realizations" ) - - return values + return self diff --git a/tests/everest/entry_points/test_everest_entry.py b/tests/everest/entry_points/test_everest_entry.py index d2efd7d81d6..02a00022c8c 100644 --- a/tests/everest/entry_points/test_everest_entry.py +++ b/tests/everest/entry_points/test_everest_entry.py @@ -504,15 +504,13 @@ def test_complete_status_for_normal_run_monitor( return_value={"status": ServerStatus.never_run, "message": None}, ) def test_validate_ert_config_before_starting_everest_server( - server_is_running_mock, server_status_mock, tmpdir, monkeypatch + server_is_running_mock, server_status_mock, copy_math_func_test_data_to_tmp ): - path = tmpdir / "new_folder" - os.makedirs(path) - monkeypatch.chdir(path) - config_file = path / "minimal_config.yml" + config_file = "config_minimal.yml" 
everest_config = EverestConfig.with_defaults() + everest_config.model.realizations = [] everest_config.dump(config_file) - everest_config.config_path = Path(config_file).absolute() - error = "Expected realizations when analysing data installation source" - with pytest.raises(SystemExit, match=f"Config validation error: {error}"): + everest_config.config_path = Path(config_file) + + with pytest.raises(SystemExit): everest_entry([str(everest_config.config_path)]) diff --git a/tests/everest/test_config_validation.py b/tests/everest/test_config_validation.py index 4b348075399..67acc1d1a3b 100644 --- a/tests/everest/test_config_validation.py +++ b/tests/everest/test_config_validation.py @@ -5,6 +5,7 @@ from argparse import ArgumentParser from pathlib import Path from typing import Any +from unittest.mock import patch import pytest from pydantic import ValidationError @@ -894,8 +895,8 @@ def test_that_missing_required_fields_cause_error(): error_dicts = e.value.errors() # Expect missing error for: - # controls, objective_functions, config_path - assert len(error_dicts) == 3 + # controls, objective_functions, config_path, model + assert len(error_dicts) == 4 config_with_defaults = EverestConfig.with_defaults() config_args = {} @@ -903,6 +904,7 @@ def test_that_missing_required_fields_cause_error(): "controls", "objective_functions", "config_path", + "model", ] for key in required_argnames: @@ -960,24 +962,26 @@ def test_that_non_existing_workflow_jobs_cause_error(): ], ) def test_warning_forward_model_write_objectives(objective, forward_model, warning_msg): - if warning_msg is not None: - with pytest.warns(ConfigWarning, match=warning_msg): - EverestConfig.with_defaults( - objective_functions=[{"name": o} for o in objective], - forward_model=forward_model, - ) - else: - with warnings.catch_warnings(): - warnings.simplefilter("error", category=ConfigWarning) - EverestConfig.with_defaults( - objective_functions=[{"name": o} for o in objective], - 
forward_model=forward_model, - ) - - -def test_deprecated_keyword(): + # model.realizations is non-empty, so this test would otherwise trigger full validation of the forward model schema; we don't want that for this test + with patch("everest.config.everest_config.validate_forward_model_configs"): + if warning_msg is not None: + with pytest.warns(ConfigWarning, match=warning_msg): + EverestConfig.with_defaults( + objective_functions=[{"name": o} for o in objective], + forward_model=forward_model, + ) + else: + with warnings.catch_warnings(): + warnings.simplefilter("error", category=ConfigWarning) + EverestConfig.with_defaults( + objective_functions=[{"name": o} for o in objective], + forward_model=forward_model, + ) + + +def test_deprecated_keyword_report_steps(): with pytest.warns(ConfigWarning, match="report_steps .* can be removed"): - ModelConfig(**{"report_steps": []}) + ModelConfig(**{"realizations": [0], "report_steps": []}) def test_load_file_non_existing():