From 22bcc6e2719727995f7a1b63d8c12bde7a440de3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Tue, 14 Jan 2025 15:17:15 +0100 Subject: [PATCH] Extend plugins --- src/everest/config/everest_config.py | 14 +++++++++++-- src/everest/config/simulator_config.py | 12 +++++------ tests/everest/test_detached.py | 29 +++++++++++++++++++++++++- 3 files changed, 46 insertions(+), 9 deletions(-) diff --git a/src/everest/config/everest_config.py b/src/everest/config/everest_config.py index 0561da1c6cf..bbd2d29640f 100644 --- a/src/everest/config/everest_config.py +++ b/src/everest/config/everest_config.py @@ -27,7 +27,7 @@ ) from ruamel.yaml import YAML, YAMLError -from ert.config import ErtConfig +from ert.config import ErtConfig, QueueConfig from ert.config.parsing import BaseModelWithContextSupport from ert.config.parsing.base_model_context import init_context from ert.plugins import ErtPluginManager @@ -816,7 +816,17 @@ def load_file(config_file: str) -> "EverestConfig": @classmethod def with_plugins(cls, config_dict): - with init_context({"activate_script": ErtPluginManager().activate_script()}): + site_config = ErtConfig.read_site_config() + ert_config: ErtConfig = ErtConfig.with_plugins().from_dict( + config_dict=site_config + ) + queue_config = QueueConfig.from_dict(site_config) + context = { + "activate_script": ErtPluginManager().activate_script(), + "queue_system": queue_config.queue_options, + "pre_installed_jobs": ert_config.installed_forward_model_steps, + } + with init_context(context): return EverestConfig(**config_dict) @staticmethod diff --git a/src/everest/config/simulator_config.py b/src/everest/config/simulator_config.py index 0aa44f9fb12..740f3fb8fda 100644 --- a/src/everest/config/simulator_config.py +++ b/src/everest/config/simulator_config.py @@ -1,21 +1,21 @@ from typing import Any from pydantic import ( - BaseModel, Field, NonNegativeInt, PositiveInt, field_validator, model_validator, ) +from pydantic_core.core_schema import ValidationInfo +from ert.config.parsing import BaseModelWithContextSupport from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, SlurmQueueOptions, TorqueQueueOptions, ) -from ert.plugins import ErtPluginManager simulator_example = {"queue_system": {"name": "local", "max_running": 3}} @@ -33,7 +33,7 @@ def check_removed_config(queue_system): ) -class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore +class SimulatorConfig(BaseModelWithContextSupport, extra="forbid"): # type: ignore cores_per_node: PositiveInt | None = Field( default=None, description="""defines the number of CPUs when running @@ -94,11 +94,11 @@ class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore @field_validator("queue_system", mode="before") @classmethod - def default_local_queue(cls, v): + def default_local_queue(cls, v, info: ValidationInfo): if v is None: + if info.context: + return info.context[info.field_name] return LocalQueueOptions(max_running=8) - elif "activate_script" not in v and ErtPluginManager().activate_script(): - v["activate_script"] = ErtPluginManager().activate_script() return v @model_validator(mode="before") diff --git a/tests/everest/test_detached.py b/tests/everest/test_detached.py index 89d377fa7a4..f6194816986 100644 --- a/tests/everest/test_detached.py +++ b/tests/everest/test_detached.py @@ -7,7 +7,7 @@ import requests import everest -from ert.config import ErtConfig +from ert.config import ErtConfig, QueueSystem from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, @@ -325,6 +325,33 @@ def test_queue_options_site_config(queue_options, use_plugin, monkeypatch, min_c assert config.server.queue_system.activate_script == expected_result +@pytest.mark.parametrize("use_plugin", (True, False)) +@pytest.mark.parametrize( + "queue_options", + [ + {"queue_system": {"name": "slurm"}}, + {}, + ], +) +def test_simulator_queue_system_site_config( + queue_options, use_plugin, monkeypatch, min_config +): + if queue_options: + expected_result = SlurmQueueOptions # User specified + elif use_plugin: + expected_result = LsfQueueOptions # Mock site config + else: + expected_result = LocalQueueOptions # Default value + if use_plugin: + monkeypatch.setattr( + everest.config.everest_config.ErtConfig, + "read_site_config", + MagicMock(return_value={"QUEUE_SYSTEM": QueueSystem.LSF}), + ) + config = EverestConfig.with_plugins({"simulator": queue_options} | min_config) + assert isinstance(config.simulator.queue_system, expected_result) + + @pytest.mark.timeout(5) # Simulation might not finish @pytest.mark.integration_test @pytest.mark.xdist_group(name="starts_everest")