Skip to content

Commit

Permalink
Extend plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Jan 14, 2025
1 parent 2e40b3f commit 22bcc6e
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
14 changes: 12 additions & 2 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from ruamel.yaml import YAML, YAMLError

from ert.config import ErtConfig
from ert.config import ErtConfig, QueueConfig
from ert.config.parsing import BaseModelWithContextSupport
from ert.config.parsing.base_model_context import init_context
from ert.plugins import ErtPluginManager
Expand Down Expand Up @@ -816,7 +816,17 @@ def load_file(config_file: str) -> "EverestConfig":

@classmethod
def with_plugins(cls, config_dict):
with init_context({"activate_script": ErtPluginManager().activate_script()}):
site_config = ErtConfig.read_site_config()
ert_config: ErtConfig = ErtConfig.with_plugins().from_dict(
config_dict=site_config
)
queue_config = QueueConfig.from_dict(site_config)
context = {
"activate_script": ErtPluginManager().activate_script(),
"queue_system": queue_config.queue_options,
"pre_installed_jobs": ert_config.installed_forward_model_steps,
}
with init_context(context):
return EverestConfig(**config_dict)

@staticmethod
Expand Down
12 changes: 6 additions & 6 deletions src/everest/config/simulator_config.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from typing import Any

from pydantic import (
BaseModel,
Field,
NonNegativeInt,
PositiveInt,
field_validator,
model_validator,
)
from pydantic_core.core_schema import ValidationInfo

from ert.config.parsing import BaseModelWithContextSupport
from ert.config.queue_config import (
LocalQueueOptions,
LsfQueueOptions,
SlurmQueueOptions,
TorqueQueueOptions,
)
from ert.plugins import ErtPluginManager

simulator_example = {"queue_system": {"name": "local", "max_running": 3}}

Expand All @@ -33,7 +33,7 @@ def check_removed_config(queue_system):
)


class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore
class SimulatorConfig(BaseModelWithContextSupport, extra="forbid"): # type: ignore
cores_per_node: PositiveInt | None = Field(
default=None,
description="""defines the number of CPUs when running
Expand Down Expand Up @@ -94,11 +94,11 @@ class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore

@field_validator("queue_system", mode="before")
@classmethod
def default_local_queue(cls, v):
def default_local_queue(cls, v, info: ValidationInfo):
if v is None:
if info.context:
return info.context[info.field_name]
return LocalQueueOptions(max_running=8)
elif "activate_script" not in v and ErtPluginManager().activate_script():
v["activate_script"] = ErtPluginManager().activate_script()
return v

@model_validator(mode="before")
Expand Down
29 changes: 28 additions & 1 deletion tests/everest/test_detached.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import requests

import everest
from ert.config import ErtConfig
from ert.config import ErtConfig, QueueSystem
from ert.config.queue_config import (
LocalQueueOptions,
LsfQueueOptions,
Expand Down Expand Up @@ -325,6 +325,33 @@ def test_queue_options_site_config(queue_options, use_plugin, monkeypatch, min_c
assert config.server.queue_system.activate_script == expected_result


@pytest.mark.parametrize("use_plugin", (True, False))
@pytest.mark.parametrize(
"queue_options",
[
{"queue_system": {"name": "slurm"}},
{},
],
)
def test_simulator_queue_system_site_config(
queue_options, use_plugin, monkeypatch, min_config
):
if queue_options:
expected_result = SlurmQueueOptions # User specified
elif use_plugin:
expected_result = LsfQueueOptions # Mock site config
else:
expected_result = LocalQueueOptions # Default value
if use_plugin:
monkeypatch.setattr(
everest.config.everest_config.ErtConfig,
"read_site_config",
MagicMock(return_value={"QUEUE_SYSTEM": QueueSystem.LSF}),
)
config = EverestConfig.with_plugins({"simulator": queue_options} | min_config)
assert isinstance(config.simulator.queue_system, expected_result)


@pytest.mark.timeout(5) # Simulation might not finish
@pytest.mark.integration_test
@pytest.mark.xdist_group(name="starts_everest")
Expand Down

0 comments on commit 22bcc6e

Please sign in to comment.