diff --git a/docs/getting_started/howto/plugin_system.rst b/docs/getting_started/howto/plugin_system.rst index 36ef395ab18..d158550317e 100644 --- a/docs/getting_started/howto/plugin_system.rst +++ b/docs/getting_started/howto/plugin_system.rst @@ -120,6 +120,8 @@ Implement the hook specification as follows to register the workflow job ``CSV_E def installable_workflow_jobs() -> Dict[str, str]: return {"": "CSV_EXPORT"} +.. _legacy_ert_workflow_jobs: + 2. **Using the legacy_ertscript_workflow hook** The second approach does not require creating a workflow job configuration file up-front, diff --git a/docs/reference/workflows/workflows.rst b/docs/reference/workflows/workflows.rst index 0ee40bc08b1..8e6c0a5bf22 100644 --- a/docs/reference/workflows/workflows.rst +++ b/docs/reference/workflows/workflows.rst @@ -10,3 +10,90 @@ and go through all the realizations in one loop, forward model jobs run in paral The executable invoked by the workflow job can be an executable you have written yourself - in any language, or it can be an existing Linux command like e.g. :code:`cp` or :code:`mv`. + +Internal workflow jobs +====================== + +.. warning:: + Internal workflow jobs are under development and the API is subject to changes + +Internal workflow jobs is a way to call custom python scripts as workflows. In order +to use this, create a class which inherits from `ErtScript`: + +.. code-block:: python + + from ert import ErtScript + + class MyJob(ErtScript): + def run(self): + print("Hello World") + +ERT will initialize this class and call the `run` function when the workflow is called, +either through hooks, or through the gui/cli. + +The `run` function can be called with a number of arguments, depending on the context the workflow +is called. There are three distinct ways to call the `run` function: + +1. If the `run` function is using `*args` in the `run` function, only the arguments from the user +configuration is passed to the workflow: + +.. code-block:: python + + class MyJob(ErtScript): + def run(self, *args): + print(f"Provided user arguments: {args}") + +1. If the `run` function is using `positional arguments in the `run` function. In this case no fixtures +will be added: + +.. warning:: + This is not recommended, you are adviced to use either option 1 or option 3 + +.. code-block:: python + + class MyJob(ErtScript): + def run(self, my_arg_1, my_arg_2): + print(f"Provided user arguments: {my_arg_1, my_arg_2}") + +.. note:: + The name of the argument is not required to be `args`, that is just convention. + +3. There are a number of specially named arguments the user can call which gives access to internal +state of the experiment that is running: + +.. glossary:: + + ert_config + This gives access to the full configuration of the running experiment + + storage + This gives access to the storage of the current session + + ensemble + This gives access to the current ensemble, making it possible to load responses and parameters + + workflow_args + This gives access to the arguments from the user configuration file + +.. note:: + The current ensemble will depend on the context. For hooked workflows the ensemble will be: + `PRE_SIMULATION`: parameters and no reponses in ensemble + `POST_SIMULATION`: parameters and responses in ensemble + `PRE_FIRST_UPDATE`/`PRE_UPDATE`: parameters and responses in ensemble + `POST_UPDATE`: parameters and responses in ensemble + The ensemble will switch after the `POST_UPDATE` hook, and will move from prior -> posterior + +.. code-block:: python + + + class MyJob(ErtScript): + def run( + self, + workflow_args: List, + ert_config: ErtConfig, + ensemble: Ensemble, + storage: Storage, + ): + print(f"Provided user arguments: {workflow_args}") + +For how to load internal workflow jobs into ERT, see: :ref:`installing workflows ` diff --git a/src/ert/cli/main.py b/src/ert/cli/main.py index 55c949702dc..fd34c0f92ae 100644 --- a/src/ert/cli/main.py +++ b/src/ert/cli/main.py @@ -12,7 +12,6 @@ from ert.cli.monitor import Monitor from ert.cli.workflow import execute_workflow from ert.config import ErtConfig, QueueSystem -from ert.enkf_main import EnKFMain from ert.ensemble_evaluator import EvaluatorServerConfig from ert.mode_definitions import ( ENSEMBLE_EXPERIMENT_MODE, @@ -57,7 +56,6 @@ def run_cli(args: Namespace, plugin_manager: Optional[ErtPluginManager] = None) for fm_step in ert_config.forward_model_steps: logger.info("Config contains forward model step %s", fm_step.name) - ert = EnKFMain(ert_config) if not ert_config.observation_keys and args.mode not in [ ENSEMBLE_EXPERIMENT_MODE, TEST_RUN_MODE, @@ -82,7 +80,7 @@ def run_cli(args: Namespace, plugin_manager: Optional[ErtPluginManager] = None) storage = open_storage(ert_config.ens_path, "w") if args.mode == WORKFLOW_MODE: - execute_workflow(ert, storage, args.name) + execute_workflow(ert_config, storage, args.name) return status_queue: queue.SimpleQueue[StatusEvents] = queue.SimpleQueue() diff --git a/src/ert/cli/workflow.py b/src/ert/cli/workflow.py index 016364bf3e5..a879b810571 100644 --- a/src/ert/cli/workflow.py +++ b/src/ert/cli/workflow.py @@ -6,19 +6,21 @@ from ert.job_queue import WorkflowRunner if TYPE_CHECKING: - from ert.enkf_main import EnKFMain + from ert.config import ErtConfig from ert.storage import Storage -def execute_workflow(ert: EnKFMain, storage: Storage, workflow_name: str) -> None: +def execute_workflow( + ert_config: ErtConfig, storage: Storage, workflow_name: str +) -> None: logger = logging.getLogger(__name__) try: - workflow = ert.ert_config.workflows[workflow_name] + workflow = ert_config.workflows[workflow_name] except KeyError: msg = "Workflow {} is not in the list of available workflows" logger.error(msg.format(workflow_name)) return - runner = WorkflowRunner(workflow, ert, storage) + runner = WorkflowRunner(workflow, storage, ert_config=ert_config) runner.run_blocking() if not all(v["completed"] for v in runner.workflowReport().values()): logger.error(f"Workflow {workflow_name} failed!") diff --git a/src/ert/config/ert_plugin.py b/src/ert/config/ert_plugin.py index 78d750e8350..0a32cd7b565 100644 --- a/src/ert/config/ert_plugin.py +++ b/src/ert/config/ert_plugin.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC from typing import Any, List @@ -9,7 +11,7 @@ class CancelPluginException(Exception): class ErtPlugin(ErtScript, ABC): - def getArguments(self, parent: Any = None) -> List[Any]: # noqa: PLR6301 + def getArguments(self, args: List[Any]) -> List[Any]: # noqa: PLR6301 return [] def getName(self) -> str: diff --git a/src/ert/config/ert_script.py b/src/ert/config/ert_script.py index 2a569a6e357..5cffa60695c 100644 --- a/src/ert/config/ert_script.py +++ b/src/ert/config/ert_script.py @@ -5,14 +5,18 @@ import logging import sys import traceback +import warnings from abc import abstractmethod -from types import ModuleType -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Type +from types import MappingProxyType, ModuleType +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union + +from typing_extensions import deprecated if TYPE_CHECKING: - from ert.enkf_main import EnKFMain + from ert.config import ErtConfig from ert.storage import Ensemble, Storage + Fixtures = Union[ErtConfig, Ensemble, Storage] logger = logging.getLogger(__name__) @@ -28,19 +32,17 @@ class ErtScript: def __init__( self, - ert: EnKFMain, - storage: Storage, - ensemble: Optional[Ensemble] = None, ) -> None: - self.__ert = ert - self.__storage = storage - self.__ensemble = ensemble - self.__is_cancelled = False self.__failed = False self._stdoutdata = "" self._stderrdata = "" + # Deprecated: + self._ert = None + self._ensemble = None + self._storage = None + @abstractmethod def run(self, *arg: Any, **kwarg: Any) -> Any: """ @@ -68,21 +70,30 @@ def stderrdata(self) -> str: self._stderrdata = self._stderrdata.decode() return self._stderrdata - def ert(self) -> "EnKFMain": + @deprecated("Use fixtures to the run function instead") + def ert(self) -> Optional[ErtConfig]: logger.info(f"Accessing EnKFMain from workflow: {self.__class__.__name__}") - return self.__ert - - @property - def storage(self) -> Storage: - return self.__storage + return self._ert @property def ensemble(self) -> Optional[Ensemble]: - return self.__ensemble + warnings.warn( + "The ensemble property is deprecated, use the fixture to the run function instead", + DeprecationWarning, + stacklevel=1, + ) + logger.info(f"Accessing ensemble from workflow: {self.__class__.__name__}") + return self._ensemble - @ensemble.setter - def ensemble(self, ensemble: Ensemble) -> None: - self.__ensemble = ensemble + @property + def storage(self) -> Optional[Storage]: + warnings.warn( + "The storage property is deprecated, use the fixture to the run function instead", + DeprecationWarning, + stacklevel=1, + ) + logger.info(f"Accessing storage from workflow: {self.__class__.__name__}") + return self._storage def isCancelled(self) -> bool: return self.__is_cancelled @@ -102,7 +113,9 @@ def initializeAndRun( self, argument_types: List[Type[Any]], argument_values: List[str], + fixtures: Optional[Dict[str, Any]] = None, ) -> Any: + fixtures = {} if fixtures is None else fixtures arguments = [] for index, arg_value in enumerate(argument_values): arg_type = argument_types[index] if index < len(argument_types) else str @@ -111,8 +124,24 @@ def initializeAndRun( arguments.append(arg_type(arg_value)) # type: ignore else: arguments.append(None) - + fixtures["workflow_args"] = arguments try: + func_args = inspect.signature(self.run).parameters + # If the user has specified *args, we skip injecting fixtures, and just + # pass the user configured arguments + if not any([p.kind == p.VAR_POSITIONAL for p in func_args.values()]): + try: + arguments = self.insert_fixtures(func_args, fixtures) + except ValueError as e: + # This is here for backwards compatibility, the user does not have *argv + # but positional arguments. Can not be mixed with using fixtures. + logger.warning( + f"Mixture of fixtures and positional arguments, err: {e}" + ) + # Part of deprecation + self._ert = fixtures.get("ert_config") + self._ensemble = fixtures.get("ensemble") + self._storage = fixtures.get("storage") return self.run(*arguments) except AttributeError as e: error_msg = str(e) @@ -137,6 +166,25 @@ def initializeAndRun( # Need to have unique modules in case of identical object naming in scripts __module_count = 0 + def insert_fixtures( + self, + func_args: MappingProxyType[str, inspect.Parameter], + fixtures: Dict[str, Fixtures], + ) -> List[Any]: + arguments = [] + errors = [] + for val in func_args: + if val in fixtures: + arguments.append(fixtures[val]) + else: + errors.append(val) + if errors: + raise ValueError( + f"Plugin: {self.__class__.__name__} misconfigured, arguments: {errors} " + f"not found in fixtures: {list(fixtures)}" + ) + return arguments + def output_stack_trace(self, error: str = "") -> None: stack_trace = error or "".join(traceback.format_exception(*sys.exc_info())) sys.stderr.write( @@ -150,7 +198,7 @@ def output_stack_trace(self, error: str = "") -> None: @staticmethod def loadScriptFromFile( path: str, - ) -> Callable[["EnKFMain", "Storage"], "ErtScript"]: + ) -> Callable[[], "ErtScript"]: module_name = f"ErtScriptModule_{ErtScript.__module_count}" ErtScript.__module_count += 1 @@ -171,7 +219,7 @@ def loadScriptFromFile( @staticmethod def __findErtScriptImplementations( module: ModuleType, - ) -> Callable[["EnKFMain", "Storage"], "ErtScript"]: + ) -> Callable[[], "ErtScript"]: result = [] for _, member in inspect.getmembers( module, diff --git a/src/ert/config/external_ert_script.py b/src/ert/config/external_ert_script.py index 24f8321c6da..4138e5435a5 100644 --- a/src/ert/config/external_ert_script.py +++ b/src/ert/config/external_ert_script.py @@ -3,18 +3,14 @@ import codecs import sys from subprocess import PIPE, Popen -from typing import TYPE_CHECKING, Any, Optional +from typing import Any, Optional from .ert_script import ErtScript -if TYPE_CHECKING: - from ert.enkf_main import EnKFMain - from ert.storage import Storage - class ExternalErtScript(ErtScript): - def __init__(self, ert: EnKFMain, storage: Storage, executable: str): - super().__init__(ert, storage, None) + def __init__(self, executable: str): + super().__init__() self.__executable = executable self.__job: Optional[Popen[bytes]] = None diff --git a/src/ert/dark_storage/__init__.py b/src/ert/dark_storage/__init__.py index 4c850745a3e..e5942e0d996 100644 --- a/src/ert/dark_storage/__init__.py +++ b/src/ert/dark_storage/__init__.py @@ -1,4 +1,3 @@ """ -Dark Storage is an API towards data provided by the legacy EnKFMain object and -the `storage/` directory. +Dark Storage is an API towards data provided the `storage/` directory. """ diff --git a/src/ert/enkf_main.py b/src/ert/enkf_main.py index 4003e37c606..5301ae5d157 100644 --- a/src/ert/enkf_main.py +++ b/src/ert/enkf_main.py @@ -12,7 +12,6 @@ from numpy.random import SeedSequence from .config import ParameterConfig -from .job_queue import WorkflowRunner from .run_context import RunContext from .runpaths import Runpaths from .substitution_list import SubstitutionList @@ -20,8 +19,8 @@ if TYPE_CHECKING: import numpy.typing as npt - from .config import ErtConfig, HookRuntime - from .storage import Ensemble, Storage + from .config import ErtConfig + from .storage import Ensemble logger = logging.getLogger(__name__) @@ -125,24 +124,6 @@ def _seed_sequence(seed: Optional[int]) -> int: return int_seed -class EnKFMain: - def __init__(self, config: "ErtConfig", read_only: bool = False) -> None: - self.ert_config = config - self.update_configuration = None - - def __repr__(self) -> str: - return f"EnKFMain(size: {self.ert_config.model_config.num_realizations}, config: {self.ert_config})" - - def runWorkflows( - self, - runtime: HookRuntime, - storage: Optional[Storage] = None, - ensemble: Optional[Ensemble] = None, - ) -> None: - for workflow in self.ert_config.hooked_workflows[runtime]: - WorkflowRunner(workflow, self, storage, ensemble).run_blocking() - - def sample_prior( ensemble: Ensemble, active_realizations: Iterable[int], diff --git a/src/ert/gui/ertwidgets/summarypanel.py b/src/ert/gui/ertwidgets/summarypanel.py index fd1551f3b89..e0b6b60ea20 100644 --- a/src/ert/gui/ertwidgets/summarypanel.py +++ b/src/ert/gui/ertwidgets/summarypanel.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING, Any, List, Tuple from qtpy.QtCore import Qt @@ -14,7 +16,7 @@ from ert.gui.ertwidgets.models.ertsummary import ErtSummary if TYPE_CHECKING: - from ert.enkf_main import EnKFMain + from ert.config import ErtConfig class SummaryTemplate: @@ -55,8 +57,8 @@ def getText(self) -> str: class SummaryPanel(QFrame): - def __init__(self, ert: "EnKFMain"): - self.ert = ert + def __init__(self, config: ErtConfig): + self.config = config QFrame.__init__(self) self.setMinimumWidth(250) @@ -77,7 +79,7 @@ def __init__(self, ert: "EnKFMain"): self.updateSummary() def updateSummary(self) -> None: - summary = ErtSummary(self.ert.ert_config) + summary = ErtSummary(self.config) forward_model_list = summary.getForwardModels() plural_s = "" diff --git a/src/ert/gui/main.py b/src/ert/gui/main.py index b42b4a88a68..49aba79148f 100755 --- a/src/ert/gui/main.py +++ b/src/ert/gui/main.py @@ -18,7 +18,6 @@ from qtpy.QtWidgets import QApplication, QWidget from ert.config import ConfigValidationError, ConfigWarning, ErtConfig -from ert.enkf_main import EnKFMain from ert.gui.ertwidgets import SummaryPanel from ert.gui.main_window import ErtMainWindow from ert.gui.simulation import ExperimentPanel @@ -111,7 +110,6 @@ def _start_initial_gui_window( ).from_file(args.config) local_storage_set_ert_config(ert_config) - ert = EnKFMain(ert_config) except ConfigValidationError as error: config_warnings = [ cast(ConfigWarning, w.message).info @@ -166,7 +164,9 @@ def _start_initial_gui_window( for msg in config_warnings: logger.info("Warning shown in gui '%s'", msg) storage = open_storage(ert_config.ens_path, mode="w") - _main_window = _setup_main_window(ert, args, log_handler, storage, plugin_manager) + _main_window = _setup_main_window( + ert_config, args, log_handler, storage, plugin_manager + ) if deprecations or config_warnings: def continue_action() -> None: @@ -229,41 +229,44 @@ def _clicked_about_button(about_dialog: QWidget) -> None: def _setup_main_window( - ert: EnKFMain, + config: ErtConfig, args: Namespace, log_handler: GUILogHandler, storage: Storage, plugin_manager: Optional[ErtPluginManager] = None, ) -> ErtMainWindow: # window reference must be kept until app.exec returns: - facade = LibresFacade(ert) + facade = LibresFacade(config) config_file = args.config - config = ert.ert_config window = ErtMainWindow(config_file, plugin_manager) window.notifier.set_storage(storage) - window.setWidget(ExperimentPanel(ert, window.notifier, config_file)) + window.setWidget( + ExperimentPanel( + config, window.notifier, config_file, facade.get_ensemble_size() + ) + ) + plugin_handler = PluginHandler( - ert, window.notifier, - [wfj for wfj in ert.ert_config.workflow_jobs.values() if wfj.is_plugin()], + [wfj for wfj in config.workflow_jobs.values() if wfj.is_plugin()], window, ) window.addDock( "Configuration summary", - SummaryPanel(ert), + SummaryPanel(config), area=Qt.DockWidgetArea.BottomDockWidgetArea, ) window.addTool(PlotTool(config_file, window)) - window.addTool(ExportTool(ert, window.notifier)) - window.addTool(WorkflowsTool(ert, window.notifier)) + window.addTool(ExportTool(config, window.notifier)) + window.addTool(WorkflowsTool(config, window.notifier)) window.addTool( ManageExperimentsTool( config, window.notifier, config.model_config.num_realizations ) ) - window.addTool(PluginsTool(plugin_handler, window.notifier)) - window.addTool(RunAnalysisTool(ert, window.notifier)) + window.addTool(PluginsTool(plugin_handler, window.notifier, config)) + window.addTool(RunAnalysisTool(config, window.notifier)) window.addTool(LoadResultsTool(facade, window.notifier)) event_viewer = EventViewerTool(log_handler) window.addTool(event_viewer) diff --git a/src/ert/gui/simulation/experiment_panel.py b/src/ert/gui/simulation/experiment_panel.py index 3e0457f3337..850e513f5ca 100644 --- a/src/ert/gui/simulation/experiment_panel.py +++ b/src/ert/gui/simulation/experiment_panel.py @@ -20,9 +20,7 @@ QWidget, ) -from ert.enkf_main import EnKFMain from ert.gui.ertnotifier import ErtNotifier -from ert.libres_facade import LibresFacade from ert.mode_definitions import ( ENSEMBLE_SMOOTHER_MODE, ES_MDA_MODE, @@ -30,6 +28,7 @@ ) from ert.run_models.model_factory import create_model +from ...config import ErtConfig from .ensemble_experiment_panel import EnsembleExperimentPanel from .ensemble_smoother_panel import EnsembleSmootherPanel from .evaluate_ensemble_panel import EvaluateEnsemblePanel @@ -44,12 +43,17 @@ class ExperimentPanel(QWidget): - def __init__(self, ert: EnKFMain, notifier: ErtNotifier, config_file: str): + def __init__( + self, + config: ErtConfig, + notifier: ErtNotifier, + config_file: str, + ensemble_size: int, + ): QWidget.__init__(self) self._notifier = notifier - self.ert = ert - self.facade = LibresFacade(ert) - ensemble_size = self.facade.get_ensemble_size() + self.config = config + run_path = config.model_config.runpath_format_string self._config_file = config_file self.setObjectName("experiment_panel") @@ -94,38 +98,35 @@ def __init__(self, ert: EnKFMain, notifier: ErtNotifier, config_file: str): self._experiment_widgets = OrderedDict() self.addExperimentConfigPanel( - SingleTestRunPanel(self.facade.run_path, notifier), + SingleTestRunPanel(run_path, notifier), True, ) self.addExperimentConfigPanel( - EnsembleExperimentPanel(ensemble_size, self.facade.run_path, notifier), + EnsembleExperimentPanel(ensemble_size, run_path, notifier), True, ) self.addExperimentConfigPanel( - EvaluateEnsemblePanel(ensemble_size, self.facade.run_path, notifier), + EvaluateEnsemblePanel(ensemble_size, run_path, notifier), True, ) - config = self.facade.config experiment_type_valid = ( config.ensemble_config.parameter_configs and config.observations ) - analysis_config = self.facade.config.analysis_config + analysis_config = config.analysis_config self.addExperimentConfigPanel( - EnsembleSmootherPanel( - analysis_config, self.facade.run_path, notifier, ensemble_size - ), + EnsembleSmootherPanel(analysis_config, run_path, notifier, ensemble_size), experiment_type_valid, ) self.addExperimentConfigPanel( MultipleDataAssimilationPanel( - analysis_config, self.facade.run_path, notifier, ensemble_size + analysis_config, run_path, notifier, ensemble_size ), experiment_type_valid, ) self.addExperimentConfigPanel( IteratedEnsembleSmootherPanel( - analysis_config, self.facade.run_path, notifier, ensemble_size + analysis_config, run_path, notifier, ensemble_size ), experiment_type_valid, ) @@ -199,11 +200,10 @@ def run_experiment(self) -> None: ): abort = False QApplication.setOverrideCursor(Qt.CursorShape.WaitCursor) - config = self.facade.config event_queue = SimpleQueue() try: model = create_model( - config, + self.config, self._notifier.storage, args, event_queue, @@ -287,7 +287,7 @@ def run_experiment(self) -> None: event_queue, self._notifier, self.parent(), - output_path=self.ert.ert_config.analysis_config.log_path, + output_path=self.config.analysis_config.log_path, ) self.run_button.setEnabled(False) self.run_button.setText(EXPERIMENT_IS_RUNNING_BUTTON_MESSAGE) diff --git a/src/ert/gui/tools/export/export_tool.py b/src/ert/gui/tools/export/export_tool.py index 1447b73224a..f69f09f1e45 100644 --- a/src/ert/gui/tools/export/export_tool.py +++ b/src/ert/gui/tools/export/export_tool.py @@ -1,10 +1,15 @@ +from __future__ import annotations + import logging +from typing import TYPE_CHECKING from weakref import ref from qtpy.QtGui import QIcon from qtpy.QtWidgets import QMessageBox -from ert.enkf_main import EnKFMain +if TYPE_CHECKING: + from ert.config import ErtConfig + from ert.gui.ertnotifier import ErtNotifier from ert.gui.ertwidgets.closabledialog import ClosableDialog from ert.gui.tools import Tool @@ -13,11 +18,16 @@ class ExportTool(Tool): - def __init__(self, ert: EnKFMain, notifier: ErtNotifier): + def __init__(self, config: ErtConfig, notifier: ErtNotifier): super().__init__("Export data", QIcon("img:share.svg")) self.__export_widget = None self.__dialog = None - self.__exporter = Exporter(ert, notifier) + self.__exporter = Exporter( + config.workflow_jobs.get("CSV_EXPORT2"), + config.workflow_jobs.get("EXPORT_RUNPATH"), + notifier, + config.runpath_file, + ) self.setEnabled(self.__exporter.is_valid()) def trigger(self) -> None: diff --git a/src/ert/gui/tools/plugins/plugin.py b/src/ert/gui/tools/plugins/plugin.py index a9ede2e5ebc..5e725287587 100644 --- a/src/ert/gui/tools/plugins/plugin.py +++ b/src/ert/gui/tools/plugins/plugin.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, List, Optional +import inspect +from typing import TYPE_CHECKING, Any, Dict, List, Optional from ert import ErtScript @@ -8,15 +9,11 @@ from qtpy.QtWidgets import QWidget from ert.config import ErtPlugin, WorkflowJob - from ert.enkf_main import EnKFMain from ert.gui.ertnotifier import ErtNotifier class Plugin: - def __init__( - self, ert: "EnKFMain", notifier: "ErtNotifier", workflow_job: "WorkflowJob" - ): - self.__ert = ert + def __init__(self, notifier: "ErtNotifier", workflow_job: "WorkflowJob"): self.__notifier = notifier self.__workflow_job = workflow_job self.__parent_window: Optional[QWidget] = None @@ -27,11 +24,7 @@ def __init__( def __loadPlugin(self) -> "ErtPlugin": script_obj = ErtScript.loadScriptFromFile(self.__workflow_job.script) - script = script_obj( - self.__ert, - self.__notifier._storage, - ensemble=self.__notifier.current_ensemble, - ) + script = script_obj() return script def getName(self) -> str: @@ -40,13 +33,21 @@ def getName(self) -> str: def getDescription(self) -> str: return self.__description - def getArguments(self) -> List[Any]: + def getArguments(self, fixtures: Dict[str, Any]) -> List[Any]: """ Returns a list of arguments. Either from GUI or from arbitrary code. If the user for example cancels in the GUI a CancelPluginException is raised. """ script = self.__loadPlugin() - return script.getArguments(self.__parent_window) + fixtures["parent"] = self.__parent_window + func_args = inspect.signature(script.getArguments).parameters + arguments = script.insert_fixtures(func_args, fixtures) + + # Part of deprecation + script._ert = fixtures.get("ert_config") + script._ensemble = fixtures.get("ensemble") + script._storage = fixtures.get("storage") + return script.getArguments(*arguments) def setParentWindow(self, parent_window: Optional[QWidget]) -> None: self.__parent_window = parent_window @@ -54,8 +55,8 @@ def setParentWindow(self, parent_window: Optional[QWidget]) -> None: def getParentWindow(self) -> Optional[QWidget]: return self.__parent_window - def ert(self) -> EnKFMain: - return self.__ert + def ert(self) -> None: + raise NotImplementedError("No such property") @property def storage(self): diff --git a/src/ert/gui/tools/plugins/plugin_handler.py b/src/ert/gui/tools/plugins/plugin_handler.py index 21299f0dcf8..736ef18d254 100644 --- a/src/ert/gui/tools/plugins/plugin_handler.py +++ b/src/ert/gui/tools/plugins/plugin_handler.py @@ -4,30 +4,27 @@ if TYPE_CHECKING: from ert.config import WorkflowJob - from ert.enkf_main import EnKFMain from ert.gui.ertnotifier import ErtNotifier class PluginHandler: def __init__( self, - ert: "EnKFMain", notifier: "ErtNotifier", plugin_jobs: List["WorkflowJob"], parent_window, ): - self.__ert = ert self.__plugins = [] for job in plugin_jobs: - plugin = Plugin(self.__ert, notifier, job) + plugin = Plugin(notifier, job) self.__plugins.append(plugin) plugin.setParentWindow(parent_window) self.__plugins = sorted(self.__plugins, key=Plugin.getName) - def ert(self) -> "EnKFMain": - return self.__ert + def ert(self) -> None: + raise NotImplementedError("No such property") def __iter__(self) -> Iterator[Plugin]: index = 0 diff --git a/src/ert/gui/tools/plugins/plugin_runner.py b/src/ert/gui/tools/plugins/plugin_runner.py index 1d53d67ef35..79c15ebd810 100644 --- a/src/ert/gui/tools/plugins/plugin_runner.py +++ b/src/ert/gui/tools/plugins/plugin_runner.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import time -from typing import TYPE_CHECKING, Any, Callable, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from _ert.threading import ErtThread from ert.config import CancelPluginException @@ -8,13 +10,16 @@ from .process_job_dialog import ProcessJobDialog if TYPE_CHECKING: + from ert.config import ErtConfig + from .plugin import Plugin class PluginRunner: - def __init__(self, plugin: "Plugin") -> None: + def __init__(self, plugin: "Plugin", ert_config: ErtConfig, storage) -> None: super().__init__() - + self.ert_config = ert_config + self.storage = storage self.__plugin = plugin self.__plugin_finished_callback = lambda: None @@ -27,16 +32,22 @@ def run(self) -> None: try: plugin = self.__plugin - arguments = plugin.getArguments() + arguments = plugin.getArguments( + fixtures={"storage": self.storage, "ert_config": self.ert_config} + ) dialog = ProcessJobDialog(plugin.getName(), plugin.getParentWindow()) dialog.setObjectName("process_job_dialog") dialog.cancelConfirmed.connect(self.cancel) - + fixtures = { + k: getattr(self, k) + for k in ["storage", "ert_config"] + if getattr(self, k) + } workflow_job_thread = ErtThread( name="ert_gui_workflow_job_thread", target=self.__runWorkflowJob, - args=(plugin, arguments), + args=(arguments, fixtures), daemon=True, should_raise=False, ) @@ -56,11 +67,9 @@ def run(self) -> None: print("Plugin cancelled before execution!") def __runWorkflowJob( - self, plugin: "Plugin", arguments: Optional[List[Any]] - ) -> None: - self.__result = self._runner.run( - plugin.ert(), plugin.storage, plugin.ensemble, arguments - ) + self, arguments: Optional[List[Any]], fixtures: Dict[str, Any] + ): + self.__result = self._runner.run(arguments, fixtures=fixtures) def __pollRunner(self, dialog: ProcessJobDialog) -> None: self.wait() diff --git a/src/ert/gui/tools/plugins/plugins_tool.py b/src/ert/gui/tools/plugins/plugins_tool.py index e8a076e3ff2..00e9302392f 100644 --- a/src/ert/gui/tools/plugins/plugins_tool.py +++ b/src/ert/gui/tools/plugins/plugins_tool.py @@ -10,13 +10,19 @@ from .plugin_runner import PluginRunner if TYPE_CHECKING: + from ert.config import ErtConfig from ert.gui.ertnotifier import ErtNotifier from .plugin_handler import PluginHandler class PluginsTool(Tool): - def __init__(self, plugin_handler: PluginHandler, notifier: ErtNotifier) -> None: + def __init__( + self, + plugin_handler: PluginHandler, + notifier: ErtNotifier, + ert_config: ErtConfig, + ) -> None: enabled = len(plugin_handler) > 0 self.notifier = notifier super().__init__( @@ -30,7 +36,7 @@ def __init__(self, plugin_handler: PluginHandler, notifier: ErtNotifier) -> None menu = QMenu() for plugin in plugin_handler: - plugin_runner = PluginRunner(plugin) + plugin_runner = PluginRunner(plugin, ert_config, notifier.storage) plugin_runner.setPluginFinishedCallback(self.trigger) self.__plugins[plugin] = plugin_runner diff --git a/src/ert/gui/tools/run_analysis/run_analysis_tool.py b/src/ert/gui/tools/run_analysis/run_analysis_tool.py index f04126fc33a..0f063dbf47b 100644 --- a/src/ert/gui/tools/run_analysis/run_analysis_tool.py +++ b/src/ert/gui/tools/run_analysis/run_analysis_tool.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import functools import uuid from contextlib import contextmanager -from typing import Iterator, Optional +from typing import TYPE_CHECKING, Iterator, Optional import numpy as np from qtpy.QtCore import QObject, Qt, QThread, Signal, Slot @@ -10,7 +12,7 @@ from ert.analysis import ErtAnalysisError, smoother_update from ert.analysis.event import AnalysisEvent, AnalysisStatusEvent, AnalysisTimeEvent -from ert.enkf_main import EnKFMain, _seed_sequence +from ert.enkf_main import _seed_sequence from ert.gui.ertnotifier import ErtNotifier from ert.gui.ertwidgets.statusdialog import StatusDialog from ert.gui.tools import Tool @@ -18,6 +20,9 @@ from ert.run_models.event import RunModelEvent, RunModelStatusEvent, RunModelTimeEvent from ert.storage import Ensemble +if TYPE_CHECKING: + from ert.config import ErtConfig + class Analyse(QObject): finished = Signal(str, str) @@ -32,12 +37,12 @@ class Analyse(QObject): def __init__( self, - ert: EnKFMain, + ert_config: ErtConfig, target_ensemble: Ensemble, source_ensemble: Ensemble, ) -> None: QObject.__init__(self) - self._ert = ert + self._ert_config = ert_config self._target_ensemble = target_ensemble self._source_ensemble = source_ensemble @@ -46,7 +51,7 @@ def run(self) -> None: """Runs analysis using target and source ensembles. Returns whether the analysis was successful.""" error: Optional[str] = None - config = self._ert.ert_config + config = self._ert_config rng = np.random.default_rng(_seed_sequence(config.random_seed)) update_settings = config.analysis_config.observation_settings update_id = uuid.uuid4() @@ -88,9 +93,9 @@ def send_smoother_event(self, run_id: uuid.UUID, event: AnalysisEvent) -> None: class RunAnalysisTool(Tool): - def __init__(self, ert: EnKFMain, notifier: ErtNotifier) -> None: + def __init__(self, config: ErtConfig, notifier: ErtNotifier) -> None: super().__init__("Run analysis", QIcon("img:formula.svg")) - self.ert = ert + self.ert_config = config self.notifier = notifier self._run_widget: Optional[RunAnalysisPanel] = None self._dialog: Optional[StatusDialog] = None @@ -100,8 +105,8 @@ def __init__(self, ert: EnKFMain, notifier: ErtNotifier) -> None: def trigger(self) -> None: if self._run_widget is None: self._run_widget = RunAnalysisPanel( - self.ert.ert_config.analysis_config.es_module, - self.ert.ert_config.model_config.num_realizations, + self.ert_config.analysis_config.es_module, + self.ert_config.model_config.num_realizations, self.notifier, ) if self._dialog is None: @@ -187,7 +192,7 @@ def _init_analyse(self, source_ensemble: Ensemble, target: str) -> None: ) self._analyse = Analyse( - self.ert, + self.ert_config, target_ensemble, source_ensemble, ) diff --git a/src/ert/gui/tools/workflows/run_workflow_widget.py b/src/ert/gui/tools/workflows/run_workflow_widget.py index 339a3f4dac0..5c258784136 100644 --- a/src/ert/gui/tools/workflows/run_workflow_widget.py +++ b/src/ert/gui/tools/workflows/run_workflow_widget.py @@ -21,6 +21,7 @@ from ert.job_queue import WorkflowRunner if TYPE_CHECKING: + from ert.config import ErtConfig from ert.gui.ertnotifier import ErtNotifier @@ -29,8 +30,8 @@ class RunWorkflowWidget(QWidget): workflowFailed = Signal() workflowKilled = Signal() - def __init__(self, ert, notifier: ErtNotifier): - self.ert = ert + def __init__(self, config: ErtConfig, notifier: ErtNotifier): + self.config = config self.storage = notifier.storage self.notifier = notifier QWidget.__init__(self) @@ -38,9 +39,7 @@ def __init__(self, ert, notifier: ErtNotifier): layout = QFormLayout() self._workflow_combo = QComboBox() - self._workflow_combo.addItems( - sorted(ert.ert_config.workflows.keys(), key=str.lower) - ) + self._workflow_combo.addItems(sorted(config.workflows.keys(), key=str.lower)) layout.addRow("Workflow", self._workflow_combo) @@ -106,7 +105,7 @@ def cancelWorkflow(self) -> None: def getCurrentWorkflowName(self) -> List[str]: index = self._workflow_combo.currentIndex() - return (sorted(self.ert.ert_config.workflows.keys(), key=str.lower))[index] + return (sorted(self.config.workflows.keys(), key=str.lower))[index] def startWorkflow(self) -> None: self._running_workflow_dialog = WorkflowDialog( @@ -121,12 +120,12 @@ def startWorkflow(self) -> None: should_raise=False, ) - workflow = self.ert.ert_config.workflows[self.getCurrentWorkflowName()] + workflow = self.config.workflows[self.getCurrentWorkflowName()] self._workflow_runner = WorkflowRunner( workflow, - self.ert, storage=self.storage, ensemble=self.source_ensemble_selector.currentData(), + ert_config=self.config, ) self._workflow_runner.run() diff --git a/src/ert/gui/tools/workflows/workflows_tool.py b/src/ert/gui/tools/workflows/workflows_tool.py index 55edb8e9bf3..ba0c22e9682 100644 --- a/src/ert/gui/tools/workflows/workflows_tool.py +++ b/src/ert/gui/tools/workflows/workflows_tool.py @@ -9,15 +9,15 @@ from ert.gui.tools.workflows import RunWorkflowWidget if TYPE_CHECKING: - from ert.enkf_main import EnKFMain + from ert.config import ErtConfig from ert.gui.ertnotifier import ErtNotifier class WorkflowsTool(Tool): - def __init__(self, ert: EnKFMain, notifier: ErtNotifier) -> None: + def __init__(self, config: ErtConfig, notifier: ErtNotifier) -> None: self.notifier = notifier - self.ert = ert - enabled = len(ert.ert_config.workflows) > 0 + self.config = config + enabled = len(config.workflows) > 0 super().__init__( "Run workflow", QIcon("img:playlist_play.svg"), @@ -25,7 +25,7 @@ def __init__(self, ert: EnKFMain, notifier: ErtNotifier) -> None: ) def trigger(self) -> None: - run_workflow_widget = RunWorkflowWidget(self.ert, self.notifier) + run_workflow_widget = RunWorkflowWidget(self.config, self.notifier) dialog = ClosableDialog("Run workflow", run_workflow_widget, self.parent()) dialog.exec_() self.notifier.emitErtChange() # workflow may have added new cases. diff --git a/src/ert/job_queue/workflow_runner.py b/src/ert/job_queue/workflow_runner.py index cf33e826bdf..900b07e5146 100644 --- a/src/ert/job_queue/workflow_runner.py +++ b/src/ert/job_queue/workflow_runner.py @@ -7,10 +7,9 @@ from typing_extensions import Self -from ert.config import ErtScript, ExternalErtScript, Workflow, WorkflowJob +from ert.config import ErtConfig, ErtScript, ExternalErtScript, Workflow, WorkflowJob if TYPE_CHECKING: - from ert.enkf_main import EnKFMain from ert.storage import Ensemble, Storage @@ -23,13 +22,12 @@ def __init__(self, workflow_job: WorkflowJob): def run( self, - ert: Optional[EnKFMain] = None, - storage: Optional[Storage] = None, - ensemble: Optional[Ensemble] = None, arguments: Optional[List[Any]] = None, + fixtures: Optional[Dict[str, Any]] = None, ) -> Any: if arguments is None: arguments = [] + fixtures = {} if fixtures is None else fixtures self.__running = True if self.job.min_args and len(arguments) < self.job.min_args: raise ValueError( @@ -44,7 +42,7 @@ def run( ) if self.job.ert_script is not None: - self.__script = self.job.ert_script(ert, storage, ensemble) + self.__script = self.job.ert_script() if self.job.stop_on_fail is not None: self.stop_on_fail = self.job.stop_on_fail elif self.__script is not None: @@ -52,8 +50,6 @@ def run( elif not self.job.internal: self.__script = ExternalErtScript( - ert, # type: ignore - storage, # type: ignore self.job.executable, # type: ignore ) @@ -63,8 +59,7 @@ def run( else: raise UserWarning("Unknown script type!") result = self.__script.initializeAndRun( # type: ignore - self.job.argument_types(), - arguments, + self.job.argument_types(), arguments, fixtures=fixtures ) self.__running = False @@ -114,14 +109,14 @@ class WorkflowRunner: def __init__( self, workflow: Workflow, - ert: Optional[EnKFMain] = None, storage: Optional[Storage] = None, ensemble: Optional[Ensemble] = None, + ert_config: Optional[ErtConfig] = None, ) -> None: self.__workflow = workflow - self._ert = ert - self._storage = storage - self._ensemble = ensemble + self.storage = storage + self.ensemble = ensemble + self.ert_config = ert_config self.__workflow_result: Optional[bool] = None self._workflow_executor = futures.ThreadPoolExecutor(max_workers=1) @@ -157,13 +152,18 @@ def run_blocking(self) -> None: # Reset status self.__status = {} self.__running = True + fixtures = { + k: getattr(self, k) + for k in ["storage", "ensemble", "ert_config"] + if getattr(self, k) + } for job, args in self.__workflow: jobrunner = WorkflowJobRunner(job) self.__current_job = jobrunner if not self.__cancelled: logger.info(f"Workflow job {jobrunner.name} starting") - jobrunner.run(self._ert, self._storage, self._ensemble, args) + jobrunner.run(args, fixtures=fixtures) self.__status[jobrunner.name] = { "stdout": jobrunner.stdoutdata(), "stderr": jobrunner.stderrdata(), diff --git a/src/ert/libres_facade.py b/src/ert/libres_facade.py index 2d4c5b96cae..7a498beb9b0 100644 --- a/src/ert/libres_facade.py +++ b/src/ert/libres_facade.py @@ -3,6 +3,7 @@ import json import logging import time +import warnings from multiprocessing.pool import ThreadPool from typing import ( TYPE_CHECKING, @@ -13,7 +14,6 @@ List, Optional, Tuple, - Union, ) import numpy as np @@ -31,7 +31,7 @@ from ert.data._measured_data import ObservationError, ResponseError from ert.load_status import LoadResult, LoadStatus -from .enkf_main import EnKFMain, ensemble_context +from .enkf_main import ensemble_context from .shared.plugins import ErtPluginContext _logger = logging.getLogger(__name__) @@ -60,16 +60,8 @@ class LibresFacade: commonly used in other project. It is part of the public interface of ert, and as such changes here should not be taken lightly.""" - def __init__(self, enkf_main: Union[EnKFMain, ErtConfig]): - # EnKFMain is more or less just a facade for the configuration at this - # point, so in the process of removing it altogether it is easier - # if we allow the facade to created with both EnKFMain and ErtConfig - if isinstance(enkf_main, EnKFMain): - self._enkf_main = enkf_main - self.config: ErtConfig = enkf_main.ert_config - else: - self._enkf_main = EnKFMain(enkf_main) - self.config = enkf_main + def __init__(self, ert_config: ErtConfig, _: Any = None): + self.config = ert_config self.update_snapshots: Dict[str, SmootherSnapshot] = {} self.update_configuration = None @@ -283,10 +275,20 @@ def run_ertscript( # type: ignore storage: Storage, ensemble: Ensemble, *args: Optional[Any], - **kwargs: Optional[Any], ) -> Any: - return ertscript(self._enkf_main, storage, ensemble=ensemble).run( - *args, **kwargs + warnings.warn( + "run_ertscript is deprecated, use the workflow runner", + DeprecationWarning, + stacklevel=1, + ) + return ertscript().initializeAndRun( + [], + argument_values=args, + fixtures={ + "ert_config": self.config, + "ensemble": ensemble, + "storage": storage, + }, ) @classmethod @@ -295,10 +297,8 @@ def from_config_file( ) -> "LibresFacade": with ErtPluginContext() as ctx: return cls( - EnKFMain( - ErtConfig.with_plugins( - forward_model_step_classes=ctx.plugin_manager.forward_model_steps - ).from_file(config_file), - read_only, - ) + ErtConfig.with_plugins( + forward_model_step_classes=ctx.plugin_manager.forward_model_steps + ).from_file(config_file), + read_only, ) diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index 111a7ed1685..e78709b4e8b 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -32,9 +32,8 @@ AnalysisErrorEvent, ) from ert.config import ErtConfig, HookRuntime, QueueSystem -from ert.enkf_main import EnKFMain, _seed_sequence, create_run_path +from ert.enkf_main import _seed_sequence, create_run_path from ert.ensemble_evaluator import ( - Ensemble, EnsembleBuilder, EnsembleEvaluator, EvaluatorServerConfig, @@ -64,8 +63,9 @@ from ert.mode_definitions import MODULE_MODE from ert.run_context import RunContext from ert.runpaths import Runpaths -from ert.storage import Storage +from ert.storage import Ensemble, Storage +from ..job_queue import WorkflowRunner from .event import ( RunModelDataEvent, RunModelErrorEvent, @@ -79,6 +79,7 @@ if TYPE_CHECKING: from ert.config import QueueConfig + from ert.ensemble_evaluator import Ensemble as EEEnsemble from ert.run_models.run_arguments import RunArgumentsType StatusEvents = Union[ @@ -167,8 +168,7 @@ def __init__( self._completed_realizations_mask: List[bool] = [] self.support_restart: bool = True self.ert_config = config - self.ert = EnKFMain(config) - self.facade = LibresFacade(self.ert) + self.facade = LibresFacade(self.ert_config) self._storage = storage self._simulation_arguments = simulation_arguments self.reset() @@ -581,7 +581,7 @@ def run_ensemble_evaluator( def _build_ensemble( self, run_context: RunContext, - ) -> "Ensemble": + ) -> EEEnsemble: builder = EnsembleBuilder().set_legacy_dependencies( self._queue_config, self._simulation_arguments.stop_long_running, @@ -641,6 +641,14 @@ def validate(self) -> None: f"({min_realization_count})" ) + def run_workflows( + self, runtime: HookRuntime, storage: Storage, ensemble: Ensemble + ) -> None: + for workflow in self.ert_config.hooked_workflows[runtime]: + WorkflowRunner( + workflow, storage, ensemble, ert_config=self.ert_config + ).run_blocking() + def _evaluate_and_postprocess( self, run_context: RunContext, @@ -654,7 +662,7 @@ def _evaluate_and_postprocess( phase_string = f"Pre processing for iteration: {iteration}" self.setPhaseName(phase_string) - self.ert.runWorkflows( + self.run_workflows( HookRuntime.PRE_SIMULATION, self._storage, run_context.ensemble ) @@ -696,7 +704,7 @@ def _evaluate_and_postprocess( phase_string = f"Post processing for iteration: {iteration}" self.setPhaseName(phase_string) - self.ert.runWorkflows( + self.run_workflows( HookRuntime.POST_SIMULATION, self._storage, run_context.ensemble ) diff --git a/src/ert/run_models/ensemble_smoother.py b/src/ert/run_models/ensemble_smoother.py index dbb3f514276..e85432eb07c 100644 --- a/src/ert/run_models/ensemble_smoother.py +++ b/src/ert/run_models/ensemble_smoother.py @@ -105,10 +105,10 @@ def run_experiment( ) self.setPhaseName("Running ES update step") - self.ert.runWorkflows( + self.run_workflows( HookRuntime.PRE_FIRST_UPDATE, self._storage, prior_context.ensemble ) - self.ert.runWorkflows( + self.run_workflows( HookRuntime.PRE_UPDATE, self._storage, prior_context.ensemble ) @@ -155,8 +155,8 @@ def run_experiment( f"Analysis of experiment failed with the following error: {e}" ) from e - self.ert.runWorkflows( - HookRuntime.POST_UPDATE, self._storage, posterior_context.ensemble + self.run_workflows( + HookRuntime.POST_UPDATE, self._storage, prior_context.ensemble ) self._evaluate_and_postprocess(posterior_context, evaluator_server_config) diff --git a/src/ert/run_models/iterated_ensemble_smoother.py b/src/ert/run_models/iterated_ensemble_smoother.py index cee090a8689..19939a70f5c 100644 --- a/src/ert/run_models/iterated_ensemble_smoother.py +++ b/src/ert/run_models/iterated_ensemble_smoother.py @@ -83,7 +83,7 @@ def analyzeStep( initial_mask: npt.NDArray[np.bool_], ) -> None: self.setPhaseName("Pre processing update...") - self.ert.runWorkflows(HookRuntime.PRE_UPDATE, self._storage, prior_storage) + self.run_workflows(HookRuntime.PRE_UPDATE, self._storage, prior_storage) try: smoother_snapshot, self.sies_smoother = iterative_smoother_update( prior_storage, @@ -106,7 +106,7 @@ def analyzeStep( ) from e self.setPhaseName("Post processing update...") - self.ert.runWorkflows(HookRuntime.POST_UPDATE, self._storage, posterior_storage) + self.run_workflows(HookRuntime.POST_UPDATE, self._storage, posterior_storage) def run_experiment( self, evaluator_server_config: EvaluatorServerConfig @@ -159,7 +159,7 @@ def run_experiment( ) self._evaluate_and_postprocess(prior_context, evaluator_server_config) - self.ert.runWorkflows( + self.run_workflows( HookRuntime.PRE_FIRST_UPDATE, self._storage, prior_context.ensemble ) for current_iter in range(1, iteration_count + 1): diff --git a/src/ert/run_models/multiple_data_assimilation.py b/src/ert/run_models/multiple_data_assimilation.py index e237935ee77..bafe74adb7d 100644 --- a/src/ert/run_models/multiple_data_assimilation.py +++ b/src/ert/run_models/multiple_data_assimilation.py @@ -149,10 +149,8 @@ def run_experiment( ) ) if is_first_iteration: - self.ert.runWorkflows( - HookRuntime.PRE_FIRST_UPDATE, self._storage, prior - ) - self.ert.runWorkflows(HookRuntime.PRE_UPDATE, self._storage, prior) + self.run_workflows(HookRuntime.PRE_FIRST_UPDATE, self._storage, prior) + self.run_workflows(HookRuntime.PRE_UPDATE, self._storage, prior) self.send_event( RunModelStatusEvent( @@ -182,8 +180,8 @@ def run_experiment( posterior_context, weight=weight, ) - self.ert.runWorkflows( - HookRuntime.POST_UPDATE, self._storage, posterior_context.ensemble + self.run_workflows( + HookRuntime.POST_UPDATE, self._storage, prior_context.ensemble ) self._evaluate_and_postprocess(posterior_context, evaluator_server_config) diff --git a/src/ert/shared/exporter.py b/src/ert/shared/exporter.py index e41a3e880b3..09c163bd317 100644 --- a/src/ert/shared/exporter.py +++ b/src/ert/shared/exporter.py @@ -1,76 +1,74 @@ +from __future__ import annotations + import logging import re from pathlib import Path -from typing import Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional import pandas as pd from ert.analysis.event import DataSection -from ert.enkf_main import EnKFMain + +if TYPE_CHECKING: + from ert.config import WorkflowJob + from ert.gui.ertnotifier import ErtNotifier from ert.job_queue import WorkflowJobRunner -from ert.libres_facade import LibresFacade logger = logging.getLogger(__name__) class Exporter: - def __init__(self, ert: EnKFMain, notifier: ErtNotifier): - self.ert = ert - self.facade = LibresFacade(ert) - self._export_job = "CSV_EXPORT2" - self._runpath_job = "EXPORT_RUNPATH" + def __init__( + self, + export_job: Optional[WorkflowJob], + runpath_job: Optional[WorkflowJob], + notifier: ErtNotifier, + runpath_file: str, + ): + self.runpath_file = runpath_file + self.export_job = export_job + self.runpath_job = runpath_job self._notifier = notifier def is_valid(self) -> bool: - export_job = self.facade.get_workflow_job(self._export_job) - runpath_job = self.facade.get_workflow_job(self._runpath_job) - - if export_job is None: - logger.error( - f"Export not available because {self._export_job} is not installed." - ) + if self.export_job is None: + logger.error("Export not available because export_job is not installed.") return False - if runpath_job is None: - logger.error( - f"Export not available because {self._runpath_job} is not installed." - ) + if self.runpath_job is None: + logger.error("Export not available because runpath_job is not installed.") return False return True def run_export(self, parameters: Dict[str, Any]) -> None: - export_job = self.facade.get_workflow_job(self._export_job) - if export_job is None: - raise UserWarning(f"Could not find {self._export_job} job") - runpath_job = self.facade.get_workflow_job(self._runpath_job) - if runpath_job is None: - raise UserWarning(f"Could not find {self._runpath_job} job") + if self.export_job is None: + raise UserWarning("Could not find export_job job") + if self.runpath_job is None: + raise UserWarning("Could not find runpath_job job") - runpath_job_runner = WorkflowJobRunner(runpath_job) + runpath_job_runner = WorkflowJobRunner(self.runpath_job) runpath_job_runner.run( - ert=self.ert, - storage=self._notifier.storage, + fixtures={"storage": self._notifier.storage}, arguments=[], ) if runpath_job_runner.hasFailed(): - raise UserWarning(f"Failed to execute {self._runpath_job}") + raise UserWarning(f"Failed to execute {self.runpath_job.name}") - export_job_runner = WorkflowJobRunner(export_job) + export_job_runner = WorkflowJobRunner(self.export_job) user_warn = export_job_runner.run( - ert=self.ert, - storage=self._notifier.storage, + fixtures={"storage": self._notifier.storage}, arguments=[ - str(self.ert.ert_config.runpath_file), + str(self.runpath_file), parameters["output_file"], parameters["time_index"], parameters["column_keys"], ], ) if export_job_runner.hasFailed(): - raise UserWarning(f"Failed to execute {self._export_job}\n{user_warn}") + raise UserWarning(f"Failed to execute {self.export_job.name}\n{user_warn}") def csv_event_to_report(name: str, data: DataSection, output_path: Path) -> None: diff --git a/src/ert/shared/hook_implementations/workflows/export_misfit_data.py b/src/ert/shared/hook_implementations/workflows/export_misfit_data.py index 4f38cb3c42f..1eb31a8024b 100644 --- a/src/ert/shared/hook_implementations/workflows/export_misfit_data.py +++ b/src/ert/shared/hook_implementations/workflows/export_misfit_data.py @@ -1,8 +1,16 @@ -from typing import Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, List + +import pandas as pd from ert import ErtScript from ert.exceptions import StorageError +if TYPE_CHECKING: + from ert.config import ErtConfig + from ert.storage import Ensemble + class ExportMisfitDataJob(ErtScript): """ @@ -15,23 +23,19 @@ class ExportMisfitDataJob(ErtScript): ((response_value - observation_data) / observation_std)**2 """ - def run(self, target_file: Optional[str] = None) -> None: - ert = self.ert() + def run( + self, ert_config: ErtConfig, ensemble: Ensemble, workflow_args: List[Any] + ) -> None: + target_file = "misfit.hdf" if not workflow_args else workflow_args[0] - if target_file is None: - target_file = "misfit.hdf" - if self.ensemble is None: - raise StorageError("No responses loaded") - - realizations = self.ensemble.get_realization_with_responses() + realizations = ensemble.get_realization_with_responses() from ert import LibresFacade - facade = LibresFacade(ert) - misfit = facade.load_all_misfit_data(self.ensemble) + facade = LibresFacade(ert_config) + misfit = facade.load_all_misfit_data(ensemble) if realizations.size == 0 or misfit.empty: raise StorageError("No responses loaded") - - misfit.columns = [val.split(":")[1] for val in misfit.columns] + misfit.columns = pd.Index([val.split(":")[1] for val in misfit.columns]) misfit = misfit.drop("TOTAL", axis=1) misfit.to_hdf(target_file, key="misfit", mode="w") diff --git a/src/ert/shared/hook_implementations/workflows/export_runpath.py b/src/ert/shared/hook_implementations/workflows/export_runpath.py index 028ed3eb92b..58c4b3125a0 100644 --- a/src/ert/shared/hook_implementations/workflows/export_runpath.py +++ b/src/ert/shared/hook_implementations/workflows/export_runpath.py @@ -1,9 +1,14 @@ -from typing import List, Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, List, Tuple from ert.config import ErtScript from ert.runpaths import Runpaths from ert.validation import rangestring_to_mask +if TYPE_CHECKING: + from ert.config import ErtConfig + class ExportRunpathJob(ErtScript): """The EXPORT_RUNPATH workflow job writes the runpath file. @@ -28,25 +33,32 @@ class ExportRunpathJob(ErtScript): file. """ - def run(self, *args: str) -> None: - _args = " ".join(args).split() # Make sure args is a list of words - config = self.ert().ert_config + def run(self, ert_config: ErtConfig, workflow_args: List[Any]) -> None: + _args = " ".join(workflow_args).split() # Make sure args is a list of words run_paths = Runpaths( - jobname_format=config.model_config.jobname_format_string, - runpath_format=config.model_config.runpath_format_string, - filename=str(config.runpath_file), - substitution_list=config.substitution_list, + jobname_format=ert_config.model_config.jobname_format_string, + runpath_format=ert_config.model_config.runpath_format_string, + filename=str(ert_config.runpath_file), + substitution_list=ert_config.substitution_list, + ) + run_paths.write_runpath_list( + *self.get_ranges( + _args, + ert_config.analysis_config.num_iterations, + ert_config.model_config.num_realizations, + ) ) - run_paths.write_runpath_list(*self.get_ranges(_args)) - def get_ranges(self, args: List[str]) -> Tuple[List[int], List[int]]: - realizations_rangestring, iterations_rangestring = self._get_rangestrings(args) + def get_ranges( + self, args: List[str], number_of_iterations: int, number_of_realizations: int + ) -> Tuple[List[int], List[int]]: + realizations_rangestring, iterations_rangestring = self._get_rangestrings( + args, number_of_realizations + ) return ( + self._list_from_rangestring(iterations_rangestring, number_of_iterations), self._list_from_rangestring( - iterations_rangestring, self.number_of_iterations - ), - self._list_from_rangestring( - realizations_rangestring, self.number_of_realizations + realizations_rangestring, number_of_realizations ), ) @@ -58,21 +70,15 @@ def _list_from_rangestring(rangestring: str, size: int) -> List[int]: mask = rangestring_to_mask(rangestring, size) return [i for i, flag in enumerate(mask) if flag] - def _get_rangestrings(self, args: List[str]) -> Tuple[str, str]: + def _get_rangestrings( + self, args: List[str], number_of_realizations: int + ) -> Tuple[str, str]: if not args: return ( - f"0-{self.number_of_realizations-1}", + f"0-{number_of_realizations-1}", "0-0", # weird default behavior, kept for backwards compatability ) if "|" not in args: raise ValueError("Expected | in EXPORT_RUNPATH arguments") delimiter = args.index("|") return " ".join(args[:delimiter]), " ".join(args[delimiter + 1 :]) - - @property - def number_of_realizations(self) -> int: - return self.ert().ert_config.model_config.num_realizations - - @property - def number_of_iterations(self) -> int: - return self.ert().ert_config.analysis_config.num_iterations diff --git a/src/ert/shared/share/ert/workflows/jobs/internal-gui/scripts/csv_export.py b/src/ert/shared/share/ert/workflows/jobs/internal-gui/scripts/csv_export.py index 1aa00601936..6ed720f6752 100644 --- a/src/ert/shared/share/ert/workflows/jobs/internal-gui/scripts/csv_export.py +++ b/src/ert/shared/share/ert/workflows/jobs/internal-gui/scripts/csv_export.py @@ -81,18 +81,21 @@ def inferIterationNumber(ensemble_name): def run( self, - output_file, - ensemble_list=None, - design_matrix_path=None, - infer_iteration=True, - drop_const_cols=False, + ert_config, + storage, + workflow_args, ): + output_file = workflow_args[0] + ensemble_list = None if len(workflow_args) < 2 else workflow_args[1] + design_matrix_path = None if len(workflow_args) < 3 else workflow_args[2] + _ = True if len(workflow_args) < 4 else workflow_args[3] + drop_const_cols = False if len(workflow_args) < 5 else workflow_args[4] ensembles = [] - facade = LibresFacade(self.ert()) + facade = LibresFacade(ert_config) if ensemble_list is not None: if ensemble_list.strip() == "*": - ensembles = self.getAllEnsembleList() + ensembles = self.getAllEnsembleList(storage) else: ensembles = ensemble_list.split(",") @@ -157,7 +160,7 @@ def run( ) return export_info - def getArguments(self, parent=None): + def getArguments(self, parent, ert_config, storage): from ert.gui.ertwidgets.customdialog import CustomDialog from ert.gui.ertwidgets.listeditbox import ListEditBox from ert.gui.ertwidgets.models.path_model import PathModel @@ -166,21 +169,18 @@ def getArguments(self, parent=None): description = "The CSV export requires some information before it starts:" dialog = CustomDialog("CSV Export", description, parent) - default_csv_output_path = self.get_context_value( - "", default="output.csv" - ) + subs_list = ert_config.substitution_list + default_csv_output_path = subs_list.get("", "output.csv") output_path_model = PathModel(default_csv_output_path) output_path_chooser = PathChooser(output_path_model) - design_matrix_default = self.get_context_value( - "", default="" - ) + design_matrix_default = subs_list.get("", "") design_matrix_path_model = PathModel( design_matrix_default, is_required=False, must_exist=True ) design_matrix_path_chooser = PathChooser(design_matrix_path_model) - list_edit = ListEditBox(self.getAllEnsembleList()) + list_edit = ListEditBox(self.getAllEnsembleList(storage)) infer_iteration_check = QCheckBox() infer_iteration_check.setChecked(True) @@ -219,14 +219,9 @@ def getArguments(self, parent=None): raise CancelPluginException("User cancelled!") - def get_context_value(self, name, default): - context = self.ert().ert_config.substitution_list - if name in context: - return context[name] - return default - - def getAllEnsembleList(self): + @staticmethod + def getAllEnsembleList(storage): all_ensemble_list = [ - ensemble.name for ensemble in self.storage.ensembles if ensemble.has_data() + ensemble.name for ensemble in storage.ensembles if ensemble.has_data() ] return all_ensemble_list diff --git a/src/ert/shared/share/ert/workflows/jobs/internal-gui/scripts/gen_data_rft_export.py b/src/ert/shared/share/ert/workflows/jobs/internal-gui/scripts/gen_data_rft_export.py index 4fbbe54897f..62bfda7163b 100644 --- a/src/ert/shared/share/ert/workflows/jobs/internal-gui/scripts/gen_data_rft_export.py +++ b/src/ert/shared/share/ert/workflows/jobs/internal-gui/scripts/gen_data_rft_export.py @@ -103,11 +103,8 @@ def inferIterationNumber(ensemble_name): def run( self, - output_file, - trajectory_path, - ensemble_list=None, - infer_iteration=True, - drop_const_cols=False, + storage, + workflow_args, ): """The run method will export the RFT's for all wells and all ensembles. @@ -119,6 +116,12 @@ def run( or $trajectory_path/$WELL_R.txt """ + output_file = workflow_args[0] + trajectory_path = workflow_args[1] + ensemble_list = None if len(workflow_args) < 3 else workflow_args[2] + _ = True if len(workflow_args) < 4 else workflow_args[3] + drop_const_cols = False if len(workflow_args) < 5 else workflow_args[4] + wells = set() ensemble_names = [] @@ -134,7 +137,7 @@ def run( ensemble_data = [] try: - ensemble = self.storage.get_ensemble_by_name(ensemble_name) + ensemble = storage.get_ensemble_by_name(ensemble_name) except KeyError as exc: raise UserWarning( f"The ensemble '{ensemble_name}' does not exist!" @@ -233,7 +236,7 @@ def run( ) return export_info - def getArguments(self, parent=None): + def getArguments(self, parent, storage): from ert.gui.ertwidgets import CustomDialog, ListEditBox, PathChooser from ert.gui.ertwidgets.models.path_model import PathModel @@ -251,7 +254,7 @@ def getArguments(self, parent=None): trajectory_chooser = PathChooser(trajectory_model) trajectory_chooser.setObjectName("trajectory_chooser") - all_ensemble_list = [ensemble.name for ensemble in self.storage.ensembles] + all_ensemble_list = [ensemble.name for ensemble in storage.ensembles] list_edit = ListEditBox(all_ensemble_list) list_edit.setObjectName("list_of_ensembles") diff --git a/src/ert/simulator/batch_simulator.py b/src/ert/simulator/batch_simulator.py index 12933168a62..e7f049ebccc 100644 --- a/src/ert/simulator/batch_simulator.py +++ b/src/ert/simulator/batch_simulator.py @@ -5,7 +5,6 @@ import numpy as np from ert.config import ErtConfig, ExtParamConfig, GenDataConfig -from ert.enkf_main import EnKFMain from .batch_simulator_context import BatchContext @@ -89,7 +88,6 @@ def callback(*args, **kwargs): raise ValueError("The first argument must be valid ErtConfig instance") self.ert_config = ert_config - self.ert = EnKFMain(self.ert_config) self.control_keys = set(controls.keys()) self.result_keys = set(results) self.callback = callback @@ -250,7 +248,7 @@ def start( itr = 0 mask = np.full(len(case_data), True, dtype=bool) sim_context = BatchContext( - self.result_keys, self.ert, ensemble, mask, itr, case_data + self.result_keys, self.ert_config, ensemble, mask, itr, case_data ) if self.callback: diff --git a/src/ert/simulator/batch_simulator_context.py b/src/ert/simulator/batch_simulator_context.py index f80cec6576e..0677b25c308 100644 --- a/src/ert/simulator/batch_simulator_context.py +++ b/src/ert/simulator/batch_simulator_context.py @@ -17,7 +17,7 @@ import numpy.typing as npt - from ert.enkf_main import EnKFMain + from ert.config import ErtConfig from ert.storage import Ensemble Status = namedtuple("Status", "waiting pending running complete failed") @@ -27,7 +27,7 @@ class BatchContext(SimulationContext): def __init__( self, result_keys: "Iterable[str]", - ert: "EnKFMain", + ert_config: "ErtConfig", fs: Ensemble, mask: npt.NDArray[np.bool_], itr: int, @@ -36,9 +36,9 @@ def __init__( """ Handle which can be used to query status and results for batch simulation. """ - super().__init__(ert, fs, mask, itr, case_data) + super().__init__(ert_config, fs, mask, itr, case_data) self.result_keys = result_keys - self.ert_config = ert.ert_config + self.ert_config = ert_config def join(self) -> None: """ diff --git a/src/ert/simulator/simulation_context.py b/src/ert/simulator/simulation_context.py index 8155992dd08..a82f1d578d9 100644 --- a/src/ert/simulator/simulation_context.py +++ b/src/ert/simulator/simulation_context.py @@ -12,7 +12,7 @@ from ert.config import HookRuntime from ert.enkf_main import create_run_path from ert.ensemble_evaluator import Realization -from ert.job_queue import JobQueue, JobStatus +from ert.job_queue import JobQueue, JobStatus, WorkflowRunner from ert.run_context import RunContext from ert.runpaths import Runpaths from ert.scheduler import Scheduler, create_driver @@ -24,7 +24,7 @@ if TYPE_CHECKING: import numpy.typing as npt - from ert.enkf_main import EnKFMain + from ert.config import ErtConfig from ert.run_arg import RunArg from ert.storage import Ensemble @@ -35,20 +35,20 @@ def _slug(entity: str) -> str: def _run_forward_model( - ert: "EnKFMain", + ert_config: "ErtConfig", job_queue: Union["JobQueue", "Scheduler"], run_context: "RunContext", ) -> None: # run simplestep - asyncio.run(_submit_and_run_jobqueue(ert, job_queue, run_context)) + asyncio.run(_submit_and_run_jobqueue(ert_config, job_queue, run_context)) async def _submit_and_run_jobqueue( - ert: "EnKFMain", + ert_config: "ErtConfig", job_queue: Union["JobQueue", "Scheduler"], run_context: "RunContext", ) -> None: - max_runtime: Optional[int] = ert.ert_config.analysis_config.max_runtime + max_runtime: Optional[int] = ert_config.analysis_config.max_runtime if max_runtime == 0: max_runtime = None for index, run_arg in enumerate(run_context): @@ -57,9 +57,9 @@ async def _submit_and_run_jobqueue( if isinstance(job_queue, JobQueue): job_queue.add_job_from_run_arg( run_arg, - ert.ert_config.queue_config.job_script, + ert_config.queue_config.job_script, max_runtime, - ert.ert_config.preferred_num_cpu, + ert_config.preferred_num_cpu, ) else: realization = Realization( @@ -68,16 +68,14 @@ async def _submit_and_run_jobqueue( active=True, max_runtime=max_runtime, run_arg=run_arg, - num_cpu=ert.ert_config.preferred_num_cpu, - job_script=ert.ert_config.queue_config.job_script, + num_cpu=ert_config.preferred_num_cpu, + job_script=ert_config.queue_config.job_script, ) job_queue.set_realization(realization) required_realizations = 0 - if ert.ert_config.analysis_config.stop_long_running: - required_realizations = ( - ert.ert_config.analysis_config.minimum_required_realizations - ) + if ert_config.analysis_config.stop_long_running: + required_realizations = ert_config.analysis_config.minimum_required_realizations with contextlib.suppress(asyncio.CancelledError): await job_queue.execute(required_realizations) @@ -85,24 +83,24 @@ async def _submit_and_run_jobqueue( class SimulationContext: def __init__( self, - ert: "EnKFMain", + ert_config: "ErtConfig", ensemble: Ensemble, mask: npt.NDArray[np.bool_], itr: int, case_data: List[Tuple[Any, Any]], ): - self._ert = ert + self._ert_config = ert_config self._mask = mask - if FeatureScheduler.is_enabled(ert.ert_config.queue_config.queue_system): - driver = create_driver(ert.ert_config.queue_config) + if FeatureScheduler.is_enabled(ert_config.queue_config.queue_system): + driver = create_driver(ert_config.queue_config) self._job_queue = Scheduler( - driver, max_running=ert.ert_config.queue_config.max_running + driver, max_running=ert_config.queue_config.max_running ) else: - self._job_queue = JobQueue(ert.ert_config.queue_config) + self._job_queue = JobQueue(ert_config.queue_config) # fill in the missing geo_id data - global_substitutions = ert.ert_config.substitution_list + global_substitutions = ert_config.substitution_list global_substitutions[""] = _slug(ensemble.name) for sim_id, (geo_id, _) in enumerate(case_data): if mask[sim_id]: @@ -110,19 +108,18 @@ def __init__( self._run_context = RunContext( ensemble=ensemble, runpaths=Runpaths( - jobname_format=ert.ert_config.model_config.jobname_format_string, - runpath_format=ert.ert_config.model_config.runpath_format_string, - filename=str(ert.ert_config.runpath_file), + jobname_format=ert_config.model_config.jobname_format_string, + runpath_format=ert_config.model_config.runpath_format_string, + filename=str(ert_config.runpath_file), substitution_list=global_substitutions, ), initial_mask=mask, iteration=itr, ) - create_run_path(self._run_context, self._ert.ert_config) - self._ert.runWorkflows( - HookRuntime.PRE_SIMULATION, None, self._run_context.ensemble - ) + create_run_path(self._run_context, ert_config) + for workflow in ert_config.hooked_workflows[HookRuntime.PRE_SIMULATION]: + WorkflowRunner(workflow, None, self._run_context.ensemble).run_blocking() self._sim_thread = self._run_simulations_simple_step() # Wait until the queue is active before we finish the creation @@ -145,7 +142,7 @@ def get_run_args(self, iens: int) -> "RunArg": def _run_simulations_simple_step(self) -> Thread: sim_thread = ErtThread( target=lambda: _run_forward_model( - self._ert, self._job_queue, self._run_context + self._ert_config, self._job_queue, self._run_context ) ) sim_thread.start() diff --git a/tests/performance_tests/test_dark_storage_performance.py b/tests/performance_tests/test_dark_storage_performance.py index 8bcc6cf1001..6f7af741998 100644 --- a/tests/performance_tests/test_dark_storage_performance.py +++ b/tests/performance_tests/test_dark_storage_performance.py @@ -7,7 +7,6 @@ from ert.config import ErtConfig from ert.dark_storage.endpoints import ensembles, experiments, records -from ert.enkf_main import EnKFMain from ert.libres_facade import LibresFacade from ert.storage import open_storage @@ -127,8 +126,7 @@ def test_direct_dark_performance( with template_config["folder"].as_cwd(): config = ErtConfig.from_file("poly.ert") - ert = EnKFMain(config) - enkf_facade = LibresFacade(ert) + enkf_facade = LibresFacade(config) storage = open_storage(enkf_facade.enspath) experiment_json = experiments.get_experiments(storage=storage) ensemble_id_default = None @@ -164,8 +162,7 @@ def test_direct_dark_performance_with_storage( with template_config["folder"].as_cwd(): config = ErtConfig.from_file("poly.ert") - ert = EnKFMain(config) - enkf_facade = LibresFacade(ert) + enkf_facade = LibresFacade(config) storage = open_storage(enkf_facade.enspath) experiment_json = experiments.get_experiments(storage=storage) ensemble_id_default = None diff --git a/tests/unit_tests/all/plugins/test_export_misfit.py b/tests/unit_tests/all/plugins/test_export_misfit.py index 7e5a28aa015..56f0a6c7758 100644 --- a/tests/unit_tests/all/plugins/test_export_misfit.py +++ b/tests/unit_tests/all/plugins/test_export_misfit.py @@ -15,9 +15,7 @@ reason="https://github.com/equinor/ert/issues/7533", ) def test_export_misfit(snake_oil_case_storage, snake_oil_default_storage, snapshot): - ExportMisfitDataJob( - snake_oil_case_storage, storage=None, ensemble=snake_oil_default_storage - ).run() + ExportMisfitDataJob().run(snake_oil_case_storage, snake_oil_default_storage, []) result = pd.read_hdf("misfit.hdf").round(10) snapshot.assert_match( result.to_csv(), @@ -27,7 +25,7 @@ def test_export_misfit(snake_oil_case_storage, snake_oil_default_storage, snapsh def test_export_misfit_no_responses_in_storage(poly_case, new_ensemble): with pytest.raises(StorageError, match="No responses loaded"): - ExportMisfitDataJob(poly_case, storage=None, ensemble=new_ensemble).run() + ExportMisfitDataJob().run(poly_case, new_ensemble, []) def test_export_misfit_data_job_is_loaded(): diff --git a/tests/unit_tests/all/plugins/test_export_runpath.py b/tests/unit_tests/all/plugins/test_export_runpath.py index 6f0a30a7447..8fc0cb112c9 100644 --- a/tests/unit_tests/all/plugins/test_export_runpath.py +++ b/tests/unit_tests/all/plugins/test_export_runpath.py @@ -3,27 +3,16 @@ import pytest -from ert.enkf_main import EnKFMain from ert.runpaths import Runpaths from ert.shared.hook_implementations.workflows.export_runpath import ExportRunpathJob from ert.shared.plugins import ErtPluginManager -from ert.storage import open_storage @pytest.fixture def snake_oil_export_runpath_job(setup_case): - ert_config = setup_case("snake_oil", "snake_oil.ert") - ert = EnKFMain(ert_config) - with open_storage(ert_config.ens_path, mode="w") as storage: - yield ExportRunpathJob(ert, storage) - - -def test_export_runpath_number_of_realizations(snake_oil_export_runpath_job): - assert snake_oil_export_runpath_job.number_of_realizations == 25 - - -def test_export_runpath_number_of_iterations(snake_oil_export_runpath_job): - assert snake_oil_export_runpath_job.number_of_iterations == 4 + setup_case("snake_oil", "snake_oil.ert") + plugin = ExportRunpathJob() + yield plugin @dataclass @@ -35,40 +24,43 @@ class WritingSetup: @pytest.fixture def writing_setup(setup_case): with patch.object(Runpaths, "write_runpath_list") as write_mock: - ert_config = setup_case("snake_oil", "snake_oil.ert") - ert = EnKFMain(ert_config) - yield WritingSetup(write_mock, ExportRunpathJob(ert, None)) + config = setup_case("snake_oil", "snake_oil.ert") + yield WritingSetup(write_mock, ExportRunpathJob()), config -def test_export_runpath_no_parameters(writing_setup): - writing_setup.export_job.run() +def test_export_runpath_empty_range(writing_setup): + writing_setup, config = writing_setup + writing_setup.export_job.run(config, []) writing_setup.write_mock.assert_called_with( [0], - list(range(writing_setup.export_job.number_of_realizations)), + list(range(25)), ) def test_export_runpath_star_parameter(writing_setup): - writing_setup.export_job.run("* | *") + writing_setup, config = writing_setup + writing_setup.export_job.run(config, ["* | *"]) writing_setup.write_mock.assert_called_with( - list(range(writing_setup.export_job.number_of_iterations)), - list(range(writing_setup.export_job.number_of_realizations)), + list(range(4)), + list(range(25)), ) def test_export_runpath_range_parameter(writing_setup): - writing_setup.export_job.run("* | 1-2") + writing_setup, config = writing_setup + writing_setup.export_job.run(config, ["* | 1-2"]) writing_setup.write_mock.assert_called_with( [1, 2], - list(range(writing_setup.export_job.number_of_realizations)), + list(range(25)), ) def test_export_runpath_comma_parameter(writing_setup): - writing_setup.export_job.run("3,4 | 1-2") + writing_setup, config = writing_setup + writing_setup.export_job.run(config, ["3,4 | 1-2"]) writing_setup.write_mock.assert_called_with( [1, 2], @@ -77,7 +69,8 @@ def test_export_runpath_comma_parameter(writing_setup): def test_export_runpath_combination_parameter(writing_setup): - writing_setup.export_job.run("1,2-3 | 1-2") + writing_setup, config = writing_setup + writing_setup.export_job.run(config, ["1,2-3 | 1-2"]) writing_setup.write_mock.assert_called_with( [1, 2], @@ -86,8 +79,9 @@ def test_export_runpath_combination_parameter(writing_setup): def test_export_runpath_bad_arguments(writing_setup): + writing_setup, config = writing_setup with pytest.raises(ValueError, match="Expected |"): - writing_setup.export_job.run("wat") + writing_setup.export_job.run(config, ["wat"]) def test_export_runpath_job_is_loaded(): diff --git a/tests/unit_tests/cli/test_cli_workflow.py b/tests/unit_tests/cli/test_cli_workflow.py index 90606630f5d..da4c6e2000c 100644 --- a/tests/unit_tests/cli/test_cli_workflow.py +++ b/tests/unit_tests/cli/test_cli_workflow.py @@ -5,7 +5,6 @@ from ert.cli.workflow import execute_workflow from ert.config import ErtConfig -from ert.enkf_main import EnKFMain from ert.shared.plugins.plugin_manager import ErtPluginContext @@ -20,7 +19,6 @@ def test_executing_workflow(storage): file_handle.write("LOAD_WORKFLOW test_wf") rc = ErtConfig.from_file(config_file) - ert = EnKFMain(rc) args = Namespace(name="test_wf") - execute_workflow(ert, storage, args.name) + execute_workflow(rc, storage, args.name) assert os.path.isfile(".ert_runpath_list") diff --git a/tests/unit_tests/cli/test_model_hook_order.py b/tests/unit_tests/cli/test_model_hook_order.py index c5627d6beae..3ea29252ee3 100644 --- a/tests/unit_tests/cli/test_model_hook_order.py +++ b/tests/unit_tests/cli/test_model_hook_order.py @@ -43,10 +43,11 @@ def test_hook_call_order_ensemble_smoother(monkeypatch): The goal of this test is to assert that the hook call order is the same across different models. """ - ert_mock = MagicMock() + run_wfs_mock = MagicMock() monkeypatch.setattr(ensemble_smoother, "sample_prior", MagicMock()) monkeypatch.setattr(ensemble_smoother, "smoother_update", MagicMock()) monkeypatch.setattr(base_run_model, "LibresFacade", MagicMock()) + monkeypatch.setattr(base_run_model.BaseRunModel, "run_workflows", run_wfs_mock) minimum_args = ESRunArguments( random_seed=None, @@ -67,14 +68,13 @@ def test_hook_call_order_ensemble_smoother(monkeypatch): MagicMock(), MagicMock(), ) - test_class.ert = ert_mock test_class.run_ensemble_evaluator = MagicMock(return_value=[0]) test_class.run_experiment(MagicMock()) expected_calls = [ call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER ] - assert ert_mock.runWorkflows.mock_calls == expected_calls + assert run_wfs_mock.mock_calls == expected_calls @pytest.mark.usefixtures("patch_base_run_model") @@ -96,11 +96,12 @@ def test_hook_call_order_es_mda(monkeypatch): stop_long_running=False, experiment_name="no-name", ) + run_wfs_mock = MagicMock() monkeypatch.setattr(multiple_data_assimilation, "sample_prior", MagicMock()) monkeypatch.setattr(multiple_data_assimilation, "smoother_update", MagicMock()) monkeypatch.setattr(base_run_model, "LibresFacade", MagicMock()) + monkeypatch.setattr(base_run_model.BaseRunModel, "run_workflows", run_wfs_mock) - ert_mock = MagicMock() ens_mock = MagicMock() ens_mock.iteration = 0 storage_mock = MagicMock() @@ -114,15 +115,13 @@ def test_hook_call_order_es_mda(monkeypatch): update_settings=MagicMock(), status_queue=MagicMock(), ) - ert_mock.runWorkflows = MagicMock() - test_class.ert = ert_mock test_class.run_ensemble_evaluator = MagicMock(return_value=[0]) test_class.run_experiment(MagicMock()) expected_calls = [ call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER ] - assert ert_mock.runWorkflows.mock_calls == expected_calls + assert run_wfs_mock.mock_calls == expected_calls @pytest.mark.usefixtures("patch_base_run_model") @@ -131,9 +130,10 @@ def test_hook_call_order_iterative_ensemble_smoother(monkeypatch): The goal of this test is to assert that the hook call order is the same across different models. """ - ert_mock = MagicMock() + run_wfs_mock = MagicMock() monkeypatch.setattr(iterated_ensemble_smoother, "sample_prior", MagicMock()) monkeypatch.setattr(base_run_model, "LibresFacade", MagicMock()) + monkeypatch.setattr(base_run_model.BaseRunModel, "run_workflows", run_wfs_mock) minimum_args = SIESRunArguments( random_seed=None, @@ -156,7 +156,6 @@ def test_hook_call_order_iterative_ensemble_smoother(monkeypatch): MagicMock(), ) test_class.run_ensemble_evaluator = MagicMock(return_value=[0]) - test_class.ert = ert_mock # Mock the return values of iterative_smoother_update # Mock the iteration property of IteratedEnsembleSmoother @@ -173,4 +172,4 @@ def test_hook_call_order_iterative_ensemble_smoother(monkeypatch): expected_calls = [ call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER ] - assert ert_mock.runWorkflows.mock_calls == expected_calls + assert run_wfs_mock.mock_calls == expected_calls diff --git a/tests/unit_tests/cli/test_run_context.py b/tests/unit_tests/cli/test_run_context.py index 252c52b9001..a06370d9997 100644 --- a/tests/unit_tests/cli/test_run_context.py +++ b/tests/unit_tests/cli/test_run_context.py @@ -36,7 +36,6 @@ def test_that_all_iterations_gets_correct_name_and_iteration_number( monkeypatch.setattr(MultipleDataAssimilation, "setPhase", MagicMock()) monkeypatch.setattr(MultipleDataAssimilation, "set_env_key", MagicMock()) monkeypatch.setattr(multiple_data_assimilation, "smoother_update", MagicMock()) - monkeypatch.setattr(base_run_model, "EnKFMain", MagicMock()) test_class = MultipleDataAssimilation( minimum_args, diff --git a/tests/unit_tests/gui/conftest.py b/tests/unit_tests/gui/conftest.py index 8382efe625e..4ea7ea0e882 100644 --- a/tests/unit_tests/gui/conftest.py +++ b/tests/unit_tests/gui/conftest.py @@ -19,7 +19,6 @@ from qtpy.QtWidgets import QApplication, QComboBox, QMessageBox, QPushButton, QWidget from ert.config import ErtConfig -from ert.enkf_main import EnKFMain from ert.ensemble_evaluator.snapshot import ( ForwardModel, RealizationSnapshot, @@ -104,7 +103,6 @@ def _open_main_window( path, ) -> Generator[Tuple[ErtMainWindow, Storage, ErtConfig], None, None]: config = ErtConfig.from_file(path / "poly.ert") - poly_case = EnKFMain(config) args_mock = Mock() args_mock.config = "poly.ert" @@ -113,7 +111,7 @@ def _open_main_window( # RuntimeError: wrapped C/C++ object of type GUILogHandler handler = GUILogHandler() with open_storage(config.ens_path, mode="w") as storage: - gui = _setup_main_window(poly_case, args_mock, handler, storage) + gui = _setup_main_window(config, args_mock, handler, storage) yield gui, storage, config gui.close() @@ -307,7 +305,7 @@ def handle_dialog(): list_model = realization_widget._real_view.model() assert ( list_model.rowCount() - == experiment_panel.ert.ert_config.model_config.num_realizations + == experiment_panel.config.model_config.num_realizations ) qtbot.mouseClick(run_dialog.done_button, Qt.LeftButton) diff --git a/tests/unit_tests/gui/plottery/test_plotting_of_snake_oil.py b/tests/unit_tests/gui/plottery/test_plotting_of_snake_oil.py index 077379b38fc..91a41b56a7a 100644 --- a/tests/unit_tests/gui/plottery/test_plotting_of_snake_oil.py +++ b/tests/unit_tests/gui/plottery/test_plotting_of_snake_oil.py @@ -4,7 +4,6 @@ import pytest from qtpy.QtWidgets import QCheckBox -from ert.enkf_main import EnKFMain from ert.gui.main import GUILogHandler, _setup_main_window from ert.gui.tools.plot.data_type_keys_widget import DataTypeKeysWidget from ert.gui.tools.plot.plot_window import ( @@ -20,11 +19,6 @@ from ert.storage import open_storage -@pytest.fixture -def enkf_main_snake_oil(snake_oil_case_storage): - yield EnKFMain(snake_oil_case_storage) - - # Use a fixture for the fligure in order for the lifetime # of the c++ gui element to not go out before mpl_image_compare @pytest.fixture( @@ -37,7 +31,7 @@ def enkf_main_snake_oil(snake_oil_case_storage): ("SNAKE_OIL_PARAM:OP1_OCTAVES", HISTOGRAM), ], ) -def plot_figure(qtbot, enkf_main_snake_oil, request): +def plot_figure(qtbot, snake_oil_case_storage, request): key = request.param[0] plot_name = request.param[1] args_mock = Mock() @@ -45,9 +39,11 @@ def plot_figure(qtbot, enkf_main_snake_oil, request): log_handler = GUILogHandler() with StorageService.init_service( - project=enkf_main_snake_oil.ert_config.ens_path, - ), open_storage(enkf_main_snake_oil.ert_config.ens_path) as storage: - gui = _setup_main_window(enkf_main_snake_oil, args_mock, log_handler, storage) + project=snake_oil_case_storage.ens_path, + ), open_storage(snake_oil_case_storage.ens_path) as storage: + gui = _setup_main_window( + snake_oil_case_storage, args_mock, log_handler, storage + ) qtbot.addWidget(gui) plot_tool = gui.tools["Create plot"] @@ -101,16 +97,18 @@ def test_that_all_snake_oil_visualisations_matches_snapshot(plot_figure): def test_that_all_plotter_filter_boxes_yield_expected_filter_results( - qtbot, enkf_main_snake_oil + qtbot, snake_oil_case_storage ): args_mock = Mock() args_mock.config = "snake_oil.ert" log_handler = GUILogHandler() with StorageService.init_service( - project=enkf_main_snake_oil.ert_config.ens_path, - ), open_storage(enkf_main_snake_oil.ert_config.ens_path) as storage: - gui = _setup_main_window(enkf_main_snake_oil, args_mock, log_handler, storage) + project=snake_oil_case_storage.ens_path, + ), open_storage(snake_oil_case_storage.ens_path) as storage: + gui = _setup_main_window( + snake_oil_case_storage, args_mock, log_handler, storage + ) gui.notifier.set_storage(storage) qtbot.addWidget(gui) diff --git a/tests/unit_tests/gui/run_analysis/test_run_analysis.py b/tests/unit_tests/gui/run_analysis/test_run_analysis.py index d2e9f88e3ed..c206f6a834f 100644 --- a/tests/unit_tests/gui/run_analysis/test_run_analysis.py +++ b/tests/unit_tests/gui/run_analysis/test_run_analysis.py @@ -27,14 +27,14 @@ class MockedQIcon(QIcon): @pytest.fixture -def ert_mock(): - ert_mock = Mock() - ert_mock.ert_config.random_seed = None - return ert_mock +def config_mock(): + config_mock = Mock() + config_mock.random_seed = None + return config_mock @pytest.fixture -def mock_tool(mock_storage, ert_mock): +def mock_tool(mock_storage, config_mock): with patch("ert.gui.tools.run_analysis.run_analysis_tool.QIcon") as rs: rs.return_value = MockedQIcon() (target, source) = mock_storage @@ -44,7 +44,7 @@ def mock_tool(mock_storage, ert_mock): run_widget.target_ensemble.return_value = target.name notifier = Mock(spec_set=ErtNotifier) notifier.storage.create_ensemble.return_value = target - tool = RunAnalysisTool(ert_mock, notifier) + tool = RunAnalysisTool(config_mock, notifier) tool._run_widget = run_widget tool._dialog = Mock(spec_set=StatusDialog) @@ -60,9 +60,9 @@ def mock_storage(storage): @pytest.mark.requires_window_manager -def test_analyse_success(mock_storage, qtbot, ert_mock): +def test_analyse_success(mock_storage, qtbot, config_mock): (target, source) = mock_storage - analyse = Analyse(ert_mock, target, source) + analyse = Analyse(config_mock, target, source) thread = QThread() with qtbot.waitSignals( [analyse.finished, thread.finished], timeout=2000, raising=True diff --git a/tests/unit_tests/gui/simulation/test_run_dialog.py b/tests/unit_tests/gui/simulation/test_run_dialog.py index 7ffbe3ba5ca..0c01e598ce3 100644 --- a/tests/unit_tests/gui/simulation/test_run_dialog.py +++ b/tests/unit_tests/gui/simulation/test_run_dialog.py @@ -11,7 +11,6 @@ import ert from ert.config import ErtConfig -from ert.enkf_main import EnKFMain from ert.ensemble_evaluator import state from ert.ensemble_evaluator.event import ( EndEvent, @@ -385,11 +384,10 @@ def test_that_run_dialog_can_be_closed_while_file_plot_is_open(qtbot: QtBot, sto args_mock.config = str(config_file) ert_config = ErtConfig.from_file(str(config_file)) - enkf_main = EnKFMain(ert_config) with StorageService.init_service( project=os.path.abspath(ert_config.ens_path), ): - gui = _setup_main_window(enkf_main, args_mock, GUILogHandler(), storage) + gui = _setup_main_window(ert_config, args_mock, GUILogHandler(), storage) qtbot.addWidget(gui) run_experiment = gui.findChild(QToolButton, name="run_experiment") @@ -551,11 +549,10 @@ def test_that_gui_runs_a_minimal_example(qtbot: QtBot, storage): args_mock.config = config_file ert_config = ErtConfig.from_file(config_file) - enkf_main = EnKFMain(ert_config) with StorageService.init_service( project=os.path.abspath(ert_config.ens_path), ): - gui = _setup_main_window(enkf_main, args_mock, GUILogHandler(), storage) + gui = _setup_main_window(ert_config, args_mock, GUILogHandler(), storage) qtbot.addWidget(gui) run_experiment = gui.findChild(QToolButton, name="run_experiment") @@ -576,7 +573,6 @@ def test_that_exception_in_base_run_model_is_handled(qtbot: QtBot, storage): args_mock.config = config_file ert_config = ErtConfig.from_file(config_file) - enkf_main = EnKFMain(ert_config) with StorageService.init_service( project=os.path.abspath(ert_config.ens_path), ), patch.object( @@ -584,7 +580,7 @@ def test_that_exception_in_base_run_model_is_handled(qtbot: QtBot, storage): "run_experiment", MagicMock(side_effect=ValueError("I failed :(")), ): - gui = _setup_main_window(enkf_main, args_mock, GUILogHandler(), storage) + gui = _setup_main_window(ert_config, args_mock, GUILogHandler(), storage) qtbot.addWidget(gui) run_experiment = gui.findChild(QToolButton, name="run_experiment") diff --git a/tests/unit_tests/gui/simulation/test_run_path_dialog.py b/tests/unit_tests/gui/simulation/test_run_path_dialog.py index fd3d77b0c76..d59b5b67637 100644 --- a/tests/unit_tests/gui/simulation/test_run_path_dialog.py +++ b/tests/unit_tests/gui/simulation/test_run_path_dialog.py @@ -8,7 +8,6 @@ from qtpy.QtWidgets import QComboBox, QMessageBox, QToolButton, QWidget from ert.config import ErtConfig -from ert.enkf_main import EnKFMain from ert.gui.main import _setup_main_window from ert.gui.main_window import ErtMainWindow from ert.gui.simulation.ensemble_experiment_panel import EnsembleExperimentPanel @@ -61,9 +60,7 @@ def test_run_path_deleted_error( with StorageService.init_service( project=os.path.abspath(snake_oil_case.ens_path), ), open_storage(snake_oil_case.ens_path, mode="w") as storage: - gui = _setup_main_window( - EnKFMain(snake_oil_case), args_mock, GUILogHandler(), storage - ) + gui = _setup_main_window(snake_oil_case, args_mock, GUILogHandler(), storage) experiment_panel = gui.findChild(ExperimentPanel) assert isinstance(experiment_panel, ExperimentPanel) @@ -112,9 +109,7 @@ def test_run_path_is_deleted(snake_oil_case_storage: ErtConfig, qtbot: QtBot): with StorageService.init_service( project=os.path.abspath(snake_oil_case.ens_path), ), open_storage(snake_oil_case.ens_path, mode="w") as storage: - gui = _setup_main_window( - EnKFMain(snake_oil_case), args_mock, GUILogHandler(), storage - ) + gui = _setup_main_window(snake_oil_case, args_mock, GUILogHandler(), storage) experiment_panel = gui.findChild(ExperimentPanel) assert isinstance(experiment_panel, ExperimentPanel) @@ -161,9 +156,7 @@ def test_run_path_is_not_deleted(snake_oil_case_storage: ErtConfig, qtbot: QtBot with StorageService.init_service( project=os.path.abspath(snake_oil_case.ens_path), ), open_storage(snake_oil_case.ens_path, mode="w") as storage: - gui = _setup_main_window( - EnKFMain(snake_oil_case), args_mock, GUILogHandler(), storage - ) + gui = _setup_main_window(snake_oil_case, args_mock, GUILogHandler(), storage) experiment_panel = gui.findChild(ExperimentPanel) assert isinstance(experiment_panel, ExperimentPanel) diff --git a/tests/unit_tests/gui/test_full_manual_update_workflow.py b/tests/unit_tests/gui/test_full_manual_update_workflow.py index 6d1e820c0f6..877e9371536 100644 --- a/tests/unit_tests/gui/test_full_manual_update_workflow.py +++ b/tests/unit_tests/gui/test_full_manual_update_workflow.py @@ -108,7 +108,7 @@ def handle_dialog(): assert not all(active_reals) assert active_reals == rangestring_to_mask( experiment_panel.get_experiment_arguments().realizations, - analysis_tool.ert.ert_config.model_config.num_realizations, + analysis_tool.ert_config.model_config.num_realizations, ) # Click start simulation and agree to the message run_experiment = get_child(experiment_panel, QWidget, name="run_experiment") diff --git a/tests/unit_tests/gui/test_main_window.py b/tests/unit_tests/gui/test_main_window.py index e05a3398e22..480066f3044 100644 --- a/tests/unit_tests/gui/test_main_window.py +++ b/tests/unit_tests/gui/test_main_window.py @@ -24,7 +24,6 @@ import ert.gui from ert.config import ErtConfig -from ert.enkf_main import EnKFMain from ert.gui.ertwidgets.analysismodulevariablespanel import AnalysisModuleVariablesPanel from ert.gui.ertwidgets.create_experiment_dialog import CreateExperimentDialog from ert.gui.ertwidgets.customdialog import CustomDialog @@ -175,11 +174,10 @@ def test_that_run_dialog_can_be_closed_after_used_to_open_plots(qtbot, storage): args_mock.config = str(config_file) ert_config = ErtConfig.from_file(str(config_file)) - enkf_main = EnKFMain(ert_config) with StorageService.init_service( project=os.path.abspath(ert_config.ens_path), ): - gui = _setup_main_window(enkf_main, args_mock, GUILogHandler(), storage) + gui = _setup_main_window(ert_config, args_mock, GUILogHandler(), storage) qtbot.addWidget(gui) simulation_mode = get_child(gui, QComboBox, name="experiment_type") run_experiment = get_child(gui, QToolButton, name="run_experiment") @@ -665,11 +663,10 @@ def test_that_gui_plotter_works_when_no_data(qtbot, storage, monkeypatch): args_mock = Mock() args_mock.config = config_file ert_config = ErtConfig.from_file(config_file) - enkf_main = EnKFMain(ert_config) with StorageService.init_service( project=os.path.abspath(ert_config.ens_path), ): - gui = _setup_main_window(enkf_main, args_mock, GUILogHandler(), storage) + gui = _setup_main_window(ert_config, args_mock, GUILogHandler(), storage) qtbot.addWidget(gui) gui.tools["Create plot"].trigger() plot_window = wait_for_child(gui, qtbot, PlotWindow) diff --git a/tests/unit_tests/gui/test_restart_ensemble_experiment.py b/tests/unit_tests/gui/test_restart_ensemble_experiment.py index 0023834e830..bc6ba77a921 100644 --- a/tests/unit_tests/gui/test_restart_ensemble_experiment.py +++ b/tests/unit_tests/gui/test_restart_ensemble_experiment.py @@ -77,8 +77,7 @@ def _evaluate(coeffs, x): assert isinstance(realization_widget, RealizationWidget) list_model = realization_widget._real_view.model() assert ( - list_model.rowCount() - == experiment_panel.ert.ert_config.model_config.num_realizations + list_model.rowCount() == experiment_panel.config.model_config.num_realizations ) # Check we have failed realizations diff --git a/tests/unit_tests/gui/test_restart_no_responses_and_parameters.py b/tests/unit_tests/gui/test_restart_no_responses_and_parameters.py index cf7a3921ecb..bc23f1f060f 100644 --- a/tests/unit_tests/gui/test_restart_no_responses_and_parameters.py +++ b/tests/unit_tests/gui/test_restart_no_responses_and_parameters.py @@ -12,7 +12,6 @@ ) from ert.config import ErtConfig -from ert.enkf_main import EnKFMain from ert.gui.main import _setup_main_window from ert.gui.main_window import ErtMainWindow from ert.gui.simulation.evaluate_ensemble_panel import EvaluateEnsemblePanel @@ -66,7 +65,6 @@ def _open_main_window( fh.writelines(config) config = ErtConfig.from_file(path / "config.ert") - poly_case = EnKFMain(config) args_mock = Mock() args_mock.config = "config.ert" @@ -75,7 +73,7 @@ def _open_main_window( # RuntimeError: wrapped C/C++ object of type GUILogHandler handler = GUILogHandler() with open_storage(config.ens_path, mode="w") as storage: - gui = _setup_main_window(poly_case, args_mock, handler, storage) + gui = _setup_main_window(config, args_mock, handler, storage) yield gui, storage, config gui.close() diff --git a/tests/unit_tests/gui/test_rft_export_plugin.py b/tests/unit_tests/gui/test_rft_export_plugin.py index 078cbc92fda..0da604cae38 100644 --- a/tests/unit_tests/gui/test_rft_export_plugin.py +++ b/tests/unit_tests/gui/test_rft_export_plugin.py @@ -8,7 +8,6 @@ from qtpy.QtWidgets import QMessageBox from ert.config import ErtConfig -from ert.enkf_main import EnKFMain from ert.gui.ertwidgets.customdialog import CustomDialog from ert.gui.ertwidgets.listeditbox import ListEditBox from ert.gui.ertwidgets.pathchooser import PathChooser @@ -85,8 +84,7 @@ def test_rft_csv_export_plugin_exports_rft_data( with StorageService.init_service( project=os.path.abspath(ert_config.ens_path), ), open_storage(ert_config.ens_path, mode="w") as storage: - enkf_main = EnKFMain(ert_config) - gui = _setup_main_window(enkf_main, args, GUILogHandler(), storage) + gui = _setup_main_window(ert_config, args, GUILogHandler(), storage) qtbot.addWidget(gui) add_experiment_manually(qtbot, gui) diff --git a/tests/unit_tests/gui/tools/test_workflow_tool.py b/tests/unit_tests/gui/tools/test_workflow_tool.py index cb65a7469fd..2e3ba7c894c 100644 --- a/tests/unit_tests/gui/tools/test_workflow_tool.py +++ b/tests/unit_tests/gui/tools/test_workflow_tool.py @@ -8,7 +8,6 @@ from qtpy.QtCore import Qt, QTimer from ert.config import ErtConfig -from ert.enkf_main import EnKFMain from ert.gui.ertwidgets.closabledialog import ClosableDialog from ert.gui.main import _setup_main_window from ert.gui.main_window import ErtMainWindow @@ -38,7 +37,6 @@ def _open_main_window( config = ErtConfig.with_plugins( ctx.plugin_manager.forward_model_steps ).from_file(path / "config.ert") - enkf_main = EnKFMain(config) args_mock = Mock() args_mock.config = "config.ert" @@ -47,7 +45,7 @@ def _open_main_window( # RuntimeError: wrapped C/C++ object of type GUILogHandler handler = GUILogHandler() with open_storage(config.ens_path, mode="w") as storage: - gui = _setup_main_window(enkf_main, args_mock, handler, storage) + gui = _setup_main_window(config, args_mock, handler, storage) yield gui, storage, config gui.close() diff --git a/tests/unit_tests/job_queue/test_ert_plugin.py b/tests/unit_tests/job_queue/test_ert_plugin.py index b89750fdb2e..134535685ff 100644 --- a/tests/unit_tests/job_queue/test_ert_plugin.py +++ b/tests/unit_tests/job_queue/test_ert_plugin.py @@ -1,3 +1,6 @@ +import logging +from unittest.mock import MagicMock + import pytest from ert.config import CancelPluginException, ErtPlugin @@ -38,7 +41,7 @@ def getArguments(self, parent=None): def test_simple_ert_plugin(): - simple_plugin = SimplePlugin("ert", storage=None) + simple_plugin = SimplePlugin() arguments = simple_plugin.getArguments() @@ -49,7 +52,7 @@ def test_simple_ert_plugin(): def test_full_ert_plugin(): - plugin = FullPlugin("ert", storage=None) + plugin = FullPlugin() assert plugin.getName() == "FullPlugin" assert plugin.getDescription() == "Fully described!" @@ -60,7 +63,101 @@ def test_full_ert_plugin(): def test_cancel_plugin(): - plugin = CanceledPlugin("ert", storage=None) + plugin = CanceledPlugin() with pytest.raises(CancelPluginException): plugin.getArguments() + + +def test_plugin_with_fixtures(): + class FixturePlugin(ErtPlugin): + def run(self, ert_script): + return ert_script + + plugin = FixturePlugin() + fixture_mock = MagicMock() + assert plugin.initializeAndRun([], [], {"ert_script": fixture_mock}) == fixture_mock + + +def test_plugin_with_missing_arguments(caplog): + class FixturePlugin(ErtPlugin): + def run(self, arg_1, ert_script, fixture_2, arg_2="something"): + pass + + plugin = FixturePlugin() + fixture_mock = MagicMock() + fixture2_mock = MagicMock() + with caplog.at_level(logging.WARNING): + plugin.initializeAndRun( + [], [1, 2], {"ert_script": fixture_mock, "fixture_2": fixture2_mock} + ) + + assert plugin.hasFailed() + log = "\n".join(caplog.messages) + assert "FixturePlugin misconfigured" in log + assert "['arg_1', 'arg_2'] not found in fixtures" in log + + +def test_plugin_with_fixtures_and_enough_arguments(): + class FixturePlugin(ErtPlugin): + def run(self, workflow_args, ert_script): + return workflow_args, ert_script + + plugin = FixturePlugin() + fixture_mock = MagicMock() + assert plugin.initializeAndRun([], [1, 2, 3], {"ert_script": fixture_mock}) == ( + ["1", "2", "3"], + fixture_mock, + ) + + +def test_plugin_with_default_arguments(capsys): + class FixturePlugin(ErtPlugin): + def run(self, ert_script=None): + return ert_script + + plugin = FixturePlugin() + fixture_mock = MagicMock() + assert ( + plugin.initializeAndRun([], [1, 2], {"ert_script": fixture_mock}) + == fixture_mock + ) + + +def test_plugin_with_args(): + class FixturePlugin(ErtPlugin): + def run(self, *args): + return args + + plugin = FixturePlugin() + fixture_mock = MagicMock() + assert plugin.initializeAndRun([], [1, 2], {"ert_script": fixture_mock}) == ( + "1", + "2", + ) + + +def test_plugin_with_args_and_kwargs(): + class FixturePlugin(ErtPlugin): + def run(self, *args, **kwargs): + return args + + plugin = FixturePlugin() + fixture_mock = MagicMock() + assert plugin.initializeAndRun([], [1, 2], {"ert_script": fixture_mock}) == ( + "1", + "2", + ) + + +def test_deprecated_properties(): + class FixturePlugin(ErtPlugin): + def run(self): + pass + + plugin = FixturePlugin() + ert_mock = MagicMock() + ensemble_mock = MagicMock() + plugin.initializeAndRun([], [], {"ert_config": ert_mock, "ensemble": ensemble_mock}) + with pytest.deprecated_call(): + assert (plugin.ert(), plugin.ensemble) == (ert_mock, ensemble_mock) diff --git a/tests/unit_tests/job_queue/test_ert_script.py b/tests/unit_tests/job_queue/test_ert_script.py index ca892ff967a..5f5f447add7 100644 --- a/tests/unit_tests/job_queue/test_ert_script.py +++ b/tests/unit_tests/job_queue/test_ert_script.py @@ -7,14 +7,9 @@ # ruff: noqa: PLR6301 -class ReturnErtScript(ErtScript): - def run(self): - return self.ert() - - class AddScript(ErtScript): - def run(self, arg1, arg2): - return arg1 + arg2 + def run(self, *arg): + return arg[0] + arg[1] class NoneScript(ErtScript): @@ -27,21 +22,15 @@ def run(self): raise UserWarning("Custom user warning") -def test_ert_script_return_ert(): - script = ReturnErtScript("ert", storage=None) - result = script.initializeAndRun([], []) - assert result == "ert" - - def test_failing_ert_script_provide_user_warning(): - script = FailingScript("ert", storage=None) + script = FailingScript() result = script.initializeAndRun([], []) assert script.hasFailed() assert result == "Custom user warning" def test_ert_script_add(): - script = AddScript("ert", storage=None) + script = AddScript() result = script.initializeAndRun([int, int], ["5", "4"]) @@ -66,7 +55,7 @@ def test_ert_script_from_file(): script_object = ErtScript.loadScriptFromFile("subtract_script.py") - script = script_object("ert", storage=None) + script = script_object() result = script.initializeAndRun([int, int], ["1", "2"]) assert result == -1 @@ -80,6 +69,6 @@ def test_ert_script_from_file(): def test_none_ert_script(): # Check if None is not converted to string "None" - script = NoneScript("ert", storage=None) + script = NoneScript() script.initializeAndRun([str], [None]) diff --git a/tests/unit_tests/job_queue/test_workflow_job.py b/tests/unit_tests/job_queue/test_workflow_job.py index 5e96176abee..5beec89aa15 100644 --- a/tests/unit_tests/job_queue/test_workflow_job.py +++ b/tests/unit_tests/job_queue/test_workflow_job.py @@ -33,15 +33,13 @@ def test_arguments(): assert job.max_args == 2 assert job.argument_types() == [float, float] - assert WorkflowJobRunner(job).run(None, None, None, [1, 2.5]) + assert WorkflowJobRunner(job).run([1, 2.5]) with pytest.raises(ValueError, match="requires at least 2 arguments"): - WorkflowJobRunner(job).run(None, None, None, [1]) + WorkflowJobRunner(job).run([1]) with pytest.raises(ValueError, match="can only have 2 arguments"): - WorkflowJobRunner(job).run( - None, None, None, ["x %d %f %d %s", 1, 2.5, True, "y", "nada"] - ) + WorkflowJobRunner(job).run(["x %d %f %d %s", 1, 2.5, True, "y", "nada"]) @pytest.mark.usefixtures("use_tmpdir") @@ -57,7 +55,7 @@ def test_run_external_job(): argTypes = job.argument_types() assert argTypes == [str, str] runner = WorkflowJobRunner(job) - assert runner.run(None, None, None, ["test", "text"]) is None + assert runner.run(["test", "text"]) is None assert runner.stdoutdata() == "Hello World\n" with open("test", "r", encoding="utf-8") as f: @@ -76,7 +74,7 @@ def test_error_handling_external_job(): assert not job.internal job.argument_types() runner = WorkflowJobRunner(job) - assert runner.run(None, None, None, []) is None + assert runner.run([]) is None assert runner.stderrdata().startswith("Traceback") @@ -89,7 +87,7 @@ def test_run_internal_script(): config_file="subtract_script_job", ) - result = WorkflowJobRunner(job).run(None, None, None, ["1", "2"]) + result = WorkflowJobRunner(job).run(["1", "2"]) assert result == -1 diff --git a/tests/unit_tests/job_queue/test_workflow_runner.py b/tests/unit_tests/job_queue/test_workflow_runner.py index ae228d9d8be..4f5ad9fa8bb 100644 --- a/tests/unit_tests/job_queue/test_workflow_runner.py +++ b/tests/unit_tests/job_queue/test_workflow_runner.py @@ -63,7 +63,7 @@ def test_workflow_thread_cancel_external(): assert len(workflow) == 3 - workflow_runner = WorkflowRunner(workflow, ert=None) + workflow_runner = WorkflowRunner(workflow) assert not workflow_runner.isRunning() @@ -94,7 +94,7 @@ def test_workflow_failed_job(): ) assert len(workflow) == 2 - workflow_runner = WorkflowRunner(workflow, ert=None) + workflow_runner = WorkflowRunner(workflow) assert not workflow_runner.isRunning() with patch.object( @@ -123,7 +123,7 @@ def test_workflow_success(): assert len(workflow) == 2 - workflow_runner = WorkflowRunner(workflow, ert=None) + workflow_runner = WorkflowRunner(workflow) assert not workflow_runner.isRunning() with workflow_runner: diff --git a/tests/unit_tests/job_queue/workflow_common.py b/tests/unit_tests/job_queue/workflow_common.py index 9c64932ffea..6683221ef17 100644 --- a/tests/unit_tests/job_queue/workflow_common.py +++ b/tests/unit_tests/job_queue/workflow_common.py @@ -46,8 +46,8 @@ def createErtScriptsJob(): f.write("from ert import ErtScript\n") f.write("\n") f.write("class SubtractScript(ErtScript):\n") - f.write(" def run(self, arg1, arg2):\n") - f.write(" return arg1 - arg2\n") + f.write(" def run(self, *argv):\n") + f.write(" return argv[0] - argv[1]\n") with open("subtract_script_job", "w", encoding="utf-8") as f: f.write("INTERNAL True\n") @@ -68,7 +68,8 @@ def createWaitJob(): # noqa: PLR0915 f.write(" with open(filename, 'w') as f:\n") f.write(" f.write(content)\n") f.write("\n") - f.write(" def run(self, number, wait_time):\n") + f.write(" def run(self, *argv):\n") + f.write(" number, wait_time = argv\n") f.write(" self.dump('wait_started_%d' % number, 'text')\n") f.write(" start = time.time()\n") f.write(" diff = 0\n") diff --git a/tests/unit_tests/simulator/test_simulation_context.py b/tests/unit_tests/simulator/test_simulation_context.py index ac56f54140a..af66ee2b266 100644 --- a/tests/unit_tests/simulator/test_simulation_context.py +++ b/tests/unit_tests/simulator/test_simulation_context.py @@ -1,6 +1,5 @@ import pytest -from ert.enkf_main import EnKFMain from ert.simulator import SimulationContext from tests.utils import wait_until @@ -8,7 +7,6 @@ @pytest.mark.usefixtures("using_scheduler") def test_simulation_context(setup_case, storage): ert_config = setup_case("batch_sim", "sleepy_time.ert") - ert = EnKFMain(ert_config) size = 4 even_mask = [True, False] * (size // 2) @@ -27,8 +25,8 @@ def test_simulation_context(setup_case, storage): ) case_data = [(geo_id, {}) for geo_id in range(size)] - even_ctx = SimulationContext(ert, even_half, even_mask, 0, case_data) - odd_ctx = SimulationContext(ert, odd_half, odd_mask, 0, case_data) + even_ctx = SimulationContext(ert_config, even_half, even_mask, 0, case_data) + odd_ctx = SimulationContext(ert_config, odd_half, odd_mask, 0, case_data) for iens in range(size): if iens % 2 == 0: