From 958bef4fb46f756030c55050eef1b1004bd34480 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Tue, 14 Jan 2025 15:17:15 +0100 Subject: [PATCH] Extend plugins --- src/everest/config/everest_config.py | 17 ++++++++++++--- src/everest/config/simulator_config.py | 12 +++++------ tests/everest/test_detached.py | 29 +++++++++++++++++++++++++- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/src/everest/config/everest_config.py b/src/everest/config/everest_config.py index 0561da1c6cf..9539408cfd2 100644 --- a/src/everest/config/everest_config.py +++ b/src/everest/config/everest_config.py @@ -25,9 +25,10 @@ field_validator, model_validator, ) +from pydantic_core.core_schema import ValidationInfo from ruamel.yaml import YAML, YAMLError -from ert.config import ErtConfig +from ert.config import ErtConfig, QueueConfig from ert.config.parsing import BaseModelWithContextSupport from ert.config.parsing.base_model_context import init_context from ert.plugins import ErtPluginManager @@ -274,7 +275,7 @@ def validate_queue_system(self) -> Self: # pylint: disable=E0213 return self @model_validator(mode="after") - def validate_forward_model_job_name_installed(self) -> Self: # pylint: disable=E0213 + def validate_forward_model_job_name_installed(self, info: ValidationInfo) -> Self: # pylint: disable=E0213 install_jobs = self.install_jobs forward_model_jobs = self.forward_model if install_jobs is None: @@ -816,7 +817,17 @@ def load_file(config_file: str) -> "EverestConfig": @classmethod def with_plugins(cls, config_dict): - with init_context({"activate_script": ErtPluginManager().activate_script()}): + site_config = ErtConfig.read_site_config() + ert_config: ErtConfig = ErtConfig.with_plugins().from_dict( + config_dict=site_config + ) + queue_config = QueueConfig.from_dict(site_config) + context = { + "activate_script": ErtPluginManager().activate_script(), + "queue_system": queue_config.queue_options, + "pre_installed_jobs": ert_config.installed_forward_model_steps, + } + with init_context(context): return EverestConfig(**config_dict) @staticmethod diff --git a/src/everest/config/simulator_config.py b/src/everest/config/simulator_config.py index 0aa44f9fb12..740f3fb8fda 100644 --- a/src/everest/config/simulator_config.py +++ b/src/everest/config/simulator_config.py @@ -1,21 +1,21 @@ from typing import Any from pydantic import ( - BaseModel, Field, NonNegativeInt, PositiveInt, field_validator, model_validator, ) +from pydantic_core.core_schema import ValidationInfo +from ert.config.parsing import BaseModelWithContextSupport from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, SlurmQueueOptions, TorqueQueueOptions, ) -from ert.plugins import ErtPluginManager simulator_example = {"queue_system": {"name": "local", "max_running": 3}} @@ -33,7 +33,7 @@ def check_removed_config(queue_system): ) -class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore +class SimulatorConfig(BaseModelWithContextSupport, extra="forbid"): # type: ignore cores_per_node: PositiveInt | None = Field( default=None, description="""defines the number of CPUs when running @@ -94,11 +94,11 @@ class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore @field_validator("queue_system", mode="before") @classmethod - def default_local_queue(cls, v): + def default_local_queue(cls, v, info: ValidationInfo): if v is None: + if info.context: + return info.context[info.field_name] return LocalQueueOptions(max_running=8) - elif "activate_script" not in v and ErtPluginManager().activate_script(): - v["activate_script"] = ErtPluginManager().activate_script() return v @model_validator(mode="before") diff --git a/tests/everest/test_detached.py b/tests/everest/test_detached.py index 89d377fa7a4..f6194816986 100644 --- a/tests/everest/test_detached.py +++ b/tests/everest/test_detached.py @@ -7,7 +7,7 @@ import requests import everest -from ert.config import ErtConfig +from ert.config import ErtConfig, QueueSystem from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, @@ -325,6 +325,33 @@ def test_queue_options_site_config(queue_options, use_plugin, monkeypatch, min_c assert config.server.queue_system.activate_script == expected_result +@pytest.mark.parametrize("use_plugin", (True, False)) +@pytest.mark.parametrize( + "queue_options", + [ + {"queue_system": {"name": "slurm"}}, + {}, + ], +) +def test_simulator_queue_system_site_config( + queue_options, use_plugin, monkeypatch, min_config +): + if queue_options: + expected_result = SlurmQueueOptions # User specified + elif use_plugin: + expected_result = LsfQueueOptions # Mock site config + else: + expected_result = LocalQueueOptions # Default value + if use_plugin: + monkeypatch.setattr( + everest.config.everest_config.ErtConfig, + "read_site_config", + MagicMock(return_value={"QUEUE_SYSTEM": QueueSystem.LSF}), + ) + config = EverestConfig.with_plugins({"simulator": queue_options} | min_config) + assert isinstance(config.simulator.queue_system, expected_result) + + @pytest.mark.timeout(5) # Simulation might not finish @pytest.mark.integration_test @pytest.mark.xdist_group(name="starts_everest")