Skip to content

Commit

Permalink
Remove EnKFMain
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Apr 8, 2024
1 parent 3fdff88 commit 745cda8
Show file tree
Hide file tree
Showing 52 changed files with 374 additions and 410 deletions.
4 changes: 1 addition & 3 deletions src/ert/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from ert.cli.monitor import Monitor
from ert.cli.workflow import execute_workflow
from ert.config import ErtConfig, QueueSystem
from ert.enkf_main import EnKFMain
from ert.ensemble_evaluator import EvaluatorServerConfig, EvaluatorTracker
from ert.namespace import Namespace
from ert.storage import open_storage
Expand Down Expand Up @@ -48,7 +47,6 @@ def run_cli(args: Namespace, _: Any = None) -> None:
for job in ert_config.forward_model_list:
logger.info("Config contains forward model job %s", job.name)

ert = EnKFMain(ert_config)
if not ert_config.observations and args.mode not in [
ENSEMBLE_EXPERIMENT_MODE,
TEST_RUN_MODE,
Expand All @@ -73,7 +71,7 @@ def run_cli(args: Namespace, _: Any = None) -> None:
storage = open_storage(ert_config.ens_path, "w")

if args.mode == WORKFLOW_MODE:
execute_workflow(ert, storage, args.name)
execute_workflow(ert_config, storage, args.name)
return

try:
Expand Down
10 changes: 6 additions & 4 deletions src/ert/cli/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@
from ert.job_queue import WorkflowRunner

if TYPE_CHECKING:
from ert.enkf_main import EnKFMain
from ert.config import ErtConfig
from ert.storage import Storage


def execute_workflow(ert: EnKFMain, storage: Storage, workflow_name: str) -> None:
def execute_workflow(
ert_config: ErtConfig, storage: Storage, workflow_name: str
) -> None:
logger = logging.getLogger(__name__)
try:
workflow = ert.ert_config.workflows[workflow_name]
workflow = ert_config.workflows[workflow_name]
except KeyError:
msg = "Workflow {} is not in the list of available workflows"
logger.error(msg.format(workflow_name))
return
runner = WorkflowRunner(workflow, ert, storage)
runner = WorkflowRunner(workflow, storage, ert_config=ert_config)
runner.run_blocking()
if not all(v["completed"] for v in runner.workflowReport().values()):
logger.error(f"Workflow {workflow_name} failed!")
4 changes: 3 additions & 1 deletion src/ert/config/ert_plugin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from abc import ABC
from typing import Any, List

Expand All @@ -9,7 +11,7 @@ class CancelPluginException(Exception):


class ErtPlugin(ErtScript, ABC):
def getArguments(self, parent: Any = None) -> List[Any]: # noqa: PLR6301
def getArguments(self, args: List[Any]) -> List[Any]: # noqa: PLR6301
return []

def getName(self) -> str:
Expand Down
38 changes: 10 additions & 28 deletions src/ert/config/ert_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
import traceback
from abc import abstractmethod
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Type

if TYPE_CHECKING:
from ert.enkf_main import EnKFMain
from ert.storage import Ensemble, Storage
from typing import Any, Callable, Dict, List, Optional, Type

logger = logging.getLogger(__name__)

Expand All @@ -21,14 +17,7 @@ class ErtScript:

def __init__(
self,
ert: EnKFMain,
storage: Storage,
ensemble: Optional[Ensemble] = None,
) -> None:
self.__ert = ert
self.__storage = storage
self.__ensemble = ensemble

self.__is_cancelled = False
self.__failed = False
self._stdoutdata = ""
Expand All @@ -50,21 +39,9 @@ def stderrdata(self) -> str:
self._stderrdata = self._stderrdata.decode()
return self._stderrdata

def ert(self) -> "EnKFMain":
def ert(self) -> None:
logger.info(f"Accessing EnKFMain from workflow: {self.__class__.__name__}")
return self.__ert

@property
def storage(self) -> Storage:
return self.__storage

@property
def ensemble(self) -> Optional[Ensemble]:
return self.__ensemble

@ensemble.setter
def ensemble(self, ensemble: Ensemble) -> None:
self.__ensemble = ensemble
raise NotImplementedError("The ert() function has been removed")

def isCancelled(self) -> bool:
return self.__is_cancelled
Expand All @@ -85,7 +62,9 @@ def initializeAndRun(
self,
argument_types: List[Type[Any]],
argument_values: List[str],
fixtures: Optional[Dict[str, Any]] = None,
) -> Any:
fixtures = {} if fixtures is None else fixtures
arguments = []
for index, arg_value in enumerate(argument_values):
arg_type = argument_types[index] if index < len(argument_types) else str
Expand All @@ -96,6 +75,9 @@ def initializeAndRun(
arguments.append(None)

try:
for i, val in enumerate(inspect.signature(self.run).parameters):
if val in fixtures:
arguments.insert(i, fixtures[val])
return self.run(*arguments)
except AttributeError as e:
error_msg = str(e)
Expand Down Expand Up @@ -130,7 +112,7 @@ def output_stack_trace(self, error: str = "") -> None:
@staticmethod
def loadScriptFromFile(
path: str,
) -> Callable[["EnKFMain", "Storage"], "ErtScript"]:
) -> Callable[[], "ErtScript"]:
module_name = f"ErtScriptModule_{ErtScript.__module_count}"
ErtScript.__module_count += 1

Expand All @@ -151,7 +133,7 @@ def loadScriptFromFile(
@staticmethod
def __findErtScriptImplementations(
module: ModuleType,
) -> Callable[["EnKFMain", "Storage"], "ErtScript"]:
) -> Callable[[], "ErtScript"]:
result = []
for _, member in inspect.getmembers(
module,
Expand Down
10 changes: 3 additions & 7 deletions src/ert/config/external_ert_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,14 @@
import codecs
import sys
from subprocess import PIPE, Popen
from typing import TYPE_CHECKING, Any, Optional
from typing import Any, Optional

from .ert_script import ErtScript

if TYPE_CHECKING:
from ert.enkf_main import EnKFMain
from ert.storage import Storage


class ExternalErtScript(ErtScript):
def __init__(self, ert: EnKFMain, storage: Storage, executable: str):
super().__init__(ert, storage, None)
def __init__(self, executable: str):
super().__init__()

self.__executable = executable
self.__job: Optional[Popen[bytes]] = None
Expand Down
3 changes: 1 addition & 2 deletions src/ert/dark_storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
"""
Dark Storage is an API towards data provided by the legacy EnKFMain object and
the `storage/` directory.
Dark Storage is an API towards data provided the `storage/` directory.
"""
23 changes: 2 additions & 21 deletions src/ert/enkf_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@
from numpy.random import SeedSequence

from .config import ParameterConfig
from .job_queue import WorkflowRunner
from .run_context import RunContext
from .runpaths import Runpaths
from .substitution_list import SubstitutionList

if TYPE_CHECKING:
import numpy.typing as npt

from .config import ErtConfig, HookRuntime
from .storage import Ensemble, Storage
from .config import ErtConfig
from .storage import Ensemble

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -125,24 +124,6 @@ def _seed_sequence(seed: Optional[int]) -> int:
return int_seed


class EnKFMain:
def __init__(self, config: "ErtConfig", read_only: bool = False) -> None:
self.ert_config = config
self.update_configuration = None

def __repr__(self) -> str:
return f"EnKFMain(size: {self.ert_config.model_config.num_realizations}, config: {self.ert_config})"

def runWorkflows(
self,
runtime: HookRuntime,
storage: Optional[Storage] = None,
ensemble: Optional[Ensemble] = None,
) -> None:
for workflow in self.ert_config.hooked_workflows[runtime]:
WorkflowRunner(workflow, self, storage, ensemble).run_blocking()


def sample_prior(
ensemble: Ensemble,
active_realizations: Iterable[int],
Expand Down
10 changes: 6 additions & 4 deletions src/ert/gui/ertwidgets/summarypanel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Tuple

from qtpy.QtCore import Qt
Expand All @@ -14,7 +16,7 @@
from ert.gui.ertwidgets.models.ertsummary import ErtSummary

if TYPE_CHECKING:
from ert.enkf_main import EnKFMain
from ert.config import ErtConfig


class SummaryTemplate:
Expand Down Expand Up @@ -55,8 +57,8 @@ def getText(self):


class SummaryPanel(QFrame):
def __init__(self, ert: "EnKFMain"):
self.ert = ert
def __init__(self, config: ErtConfig):
self.config = config
QFrame.__init__(self)

self.setMinimumWidth(250)
Expand All @@ -77,7 +79,7 @@ def __init__(self, ert: "EnKFMain"):
self.updateSummary()

def updateSummary(self):
summary = ErtSummary(self.ert.ert_config)
summary = ErtSummary(self.config)

forward_model_list = summary.getForwardModels()
text = SummaryTemplate(f"Jobs ({len(forward_model_list):,})")
Expand Down
30 changes: 16 additions & 14 deletions src/ert/gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from qtpy.QtWidgets import QApplication

from ert.config import ConfigValidationError, ConfigWarning, ErtConfig
from ert.enkf_main import EnKFMain
from ert.gui.ertwidgets import SummaryPanel
from ert.gui.main_window import ErtMainWindow
from ert.gui.simulation import SimulationPanel
Expand Down Expand Up @@ -98,7 +97,6 @@ def _start_initial_gui_window(
args.config = os.path.basename(args.config)
ert_config = ErtConfig.from_file(args.config)
local_storage_set_ert_config(ert_config)
ert = EnKFMain(ert_config)
except ConfigValidationError as error:
config_warnings = [
w.message.info
Expand Down Expand Up @@ -153,7 +151,9 @@ def _start_initial_gui_window(
for msg in config_warnings:
logger.info("Warning shown in gui '%s'", msg)
storage = open_storage(ert_config.ens_path, mode="w")
_main_window = _setup_main_window(ert, args, log_handler, storage, plugin_manager)
_main_window = _setup_main_window(
ert_config, args, log_handler, storage, plugin_manager
)
if deprecations or config_warnings:

def continue_action():
Expand Down Expand Up @@ -216,39 +216,41 @@ def _clicked_about_button(about_dialog):


def _setup_main_window(
ert: EnKFMain,
config: ErtConfig,
args: Namespace,
log_handler: GUILogHandler,
storage: Storage,
plugin_manager: Optional[ErtPluginManager] = None,
) -> ErtMainWindow:
# window reference must be kept until app.exec returns:
facade = LibresFacade(ert)
facade = LibresFacade(config)
config_file = args.config
config = ert.ert_config
window = ErtMainWindow(config_file, plugin_manager)
window.notifier.set_storage(storage)
window.setWidget(SimulationPanel(ert, window.notifier, config_file))
window.setWidget(
SimulationPanel(
config, window.notifier, config_file, facade.get_ensemble_size()
)
)
plugin_handler = PluginHandler(
ert,
window.notifier,
[wfj for wfj in ert.ert_config.workflow_jobs.values() if wfj.is_plugin()],
[wfj for wfj in config.workflow_jobs.values() if wfj.is_plugin()],
window,
)

window.addDock(
"Configuration summary", SummaryPanel(ert), area=Qt.BottomDockWidgetArea
"Configuration summary", SummaryPanel(config), area=Qt.BottomDockWidgetArea
)
window.addTool(PlotTool(config_file, window))
window.addTool(ExportTool(ert, window.notifier))
window.addTool(WorkflowsTool(ert, window.notifier))
window.addTool(ExportTool(config, window.notifier))
window.addTool(WorkflowsTool(config, window.notifier))
window.addTool(
ManageExperimentsTool(
config, window.notifier, config.model_config.num_realizations
)
)
window.addTool(PluginsTool(plugin_handler, window.notifier))
window.addTool(RunAnalysisTool(ert, window.notifier))
window.addTool(PluginsTool(plugin_handler, window.notifier, config))
window.addTool(RunAnalysisTool(config, window.notifier))
window.addTool(LoadResultsTool(facade, window.notifier))
event_viewer = EventViewerTool(log_handler)
window.addTool(event_viewer)
Expand Down
Loading

0 comments on commit 745cda8

Please sign in to comment.