diff --git a/src/ert/config/__init__.py b/src/ert/config/__init__.py index f20c2a6ce6b..ec01d6db8b3 100644 --- a/src/ert/config/__init__.py +++ b/src/ert/config/__init__.py @@ -1,5 +1,6 @@ from .analysis_config import AnalysisConfig from .analysis_module import AnalysisModule, ESSettings, IESSettings +from .capture_validation import capture_validation from .enkf_observation_implementation_type import EnkfObservationImplementationType from .ensemble_config import EnsembleConfig from .ert_config import ErtConfig @@ -81,6 +82,7 @@ "WarningInfo", "Workflow", "WorkflowJob", + "capture_validation", "field_transform", "lint_file", ] diff --git a/src/ert/config/capture_validation.py b/src/ert/config/capture_validation.py new file mode 100644 index 00000000000..b57bd1d1700 --- /dev/null +++ b/src/ert/config/capture_validation.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import logging +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Iterator, cast +from warnings import catch_warnings + +from .parsing import ConfigValidationError, ConfigWarning, ErrorInfo, WarningInfo + + +@dataclass +class ValidationMessages: + warnings: list[WarningInfo] = field(default_factory=list) + deprecations: list[WarningInfo] = field(default_factory=list) + errors: list[ErrorInfo] = field(default_factory=list) + + +@contextmanager +def capture_validation() -> Iterator[ValidationMessages]: + logger = logging.getLogger(__name__) + validations = ValidationMessages() + with catch_warnings(record=True) as all_warnings: + try: + yield validations + except ConfigValidationError as err: + validations.errors += err.errors + + for wm in all_warnings: + if issubclass(wm.category, ConfigWarning): + warning = cast(ConfigWarning, wm.message) + if warning.info.is_deprecation: + validations.deprecations.append(warning.info) + else: + validations.warnings.append(warning.info) + else: + logger.warning(str(wm.message)) diff --git a/src/ert/gui/main.py b/src/ert/gui/main.py index c139c27c23e..650af23fd27 100755 --- a/src/ert/gui/main.py +++ b/src/ert/gui/main.py @@ -3,10 +3,9 @@ import logging import os import sys -import warnings import webbrowser from signal import SIG_DFL, SIGINT, signal -from typing import Optional, Tuple, cast +from typing import Optional, Tuple if sys.version_info >= (3, 9): from importlib.resources import files @@ -19,7 +18,11 @@ from qtpy.QtGui import QIcon from qtpy.QtWidgets import QApplication, QWidget -from ert.config import ConfigValidationError, ConfigWarning, ErrorInfo, ErtConfig +from ert.config import ( + ErrorInfo, + ErtConfig, + capture_validation, +) from ert.gui.main_window import ErtMainWindow from ert.gui.simulation import ExperimentPanel from ert.gui.tools.event_viewer import ( @@ -87,74 +90,40 @@ def _start_initial_gui_window( # Create logger inside function to make sure all handlers have been added to # the root-logger. logger = logging.getLogger(__name__) - error_messages = [] - config_warnings = [] - deprecations = [] ert_config = None - with warnings.catch_warnings(record=True) as all_warnings: + with capture_validation() as validation_messages: + ert_dir = os.path.abspath(os.path.dirname(args.config)) + os.chdir(ert_dir) + # Changing current working directory means we need to update + # the config file to be the base name of the original config + args.config = os.path.basename(args.config) + + ert_config = ErtConfig.with_plugins().from_file(args.config) + + local_storage_set_ert_config(ert_config) + if ert_config is not None: try: - ert_dir = os.path.abspath(os.path.dirname(args.config)) - os.chdir(ert_dir) - # Changing current working directory means we need to update - # the config file to be the base name of the original config - args.config = os.path.basename(args.config) - - ert_config = ErtConfig.with_plugins().from_file(args.config) - - local_storage_set_ert_config(ert_config) - except ConfigValidationError as error: - config_warnings = [ - cast(ConfigWarning, w.message).info - for w in all_warnings - if w.category == ConfigWarning - and not cast(ConfigWarning, w.message).info.is_deprecation - ] - deprecations = [ - cast(ConfigWarning, w.message).info - for w in all_warnings - if w.category == ConfigWarning - and cast(ConfigWarning, w.message).info.is_deprecation - ] - error_messages += error.errors - if ert_config is not None: - try: - storage = open_storage(ert_config.ens_path, mode="w") - except ErtStorageException as err: - error_messages.append( - ErrorInfo(f"Error opening storage in ENSPATH: {err}").set_context( - ert_config.ens_path - ) + storage = open_storage(ert_config.ens_path, mode="w") + except ErtStorageException as err: + validation_messages.errors.append( + ErrorInfo(f"Error opening storage in ENSPATH: {err}").set_context( + ert_config.ens_path ) - if error_messages: - logger.info(f"Error in config file shown in gui: {error_messages}") - return ( - Suggestor( - error_messages, - config_warnings, - deprecations, - None, - ( - plugin_manager.get_help_links() - if plugin_manager is not None - else {} - ), - ), - None, ) + if validation_messages.errors: + logger.info(f"Error in config file shown in gui: {validation_messages.errors}") + return ( + Suggestor( + validation_messages.errors, + validation_messages.warnings, + validation_messages.deprecations, + None, + (plugin_manager.get_help_links() if plugin_manager is not None else {}), + ), + None, + ) assert ert_config is not None - config_warnings = [ - cast(ConfigWarning, w.message).info - for w in all_warnings - if w.category == ConfigWarning - and not cast(ConfigWarning, w.message).info.is_deprecation - ] - deprecations = [ - cast(ConfigWarning, w.message).info - for w in all_warnings - if w.category == ConfigWarning - and cast(ConfigWarning, w.message).info.is_deprecation - ] counter_fm_steps = Counter(fms.name for fms in ert_config.forward_model_steps) for fm_step_name, count in counter_fm_steps.items(): @@ -162,17 +131,16 @@ def _start_initial_gui_window( f"Config contains forward model step {fm_step_name} {count} time(s)", ) - for wm in all_warnings: - if wm.category != ConfigWarning: - logger.warning(str(wm.message)) - for msg in deprecations: + for msg in validation_messages.deprecations: logger.info(f"Suggestion shown in gui '{msg}'") - for msg in config_warnings: + for msg in validation_messages.warnings: logger.info(f"Warning shown in gui '{msg}'") + _main_window = _setup_main_window( ert_config, args, log_handler, storage, plugin_manager ) - if deprecations or config_warnings: + + if validation_messages.warnings or validation_messages.deprecations: def continue_action() -> None: _main_window.show() @@ -181,9 +149,9 @@ def continue_action() -> None: _main_window.adjustSize() suggestor = Suggestor( - error_messages, - config_warnings, - deprecations, + validation_messages.errors, + validation_messages.warnings, + validation_messages.deprecations, continue_action, plugin_manager.get_help_links() if plugin_manager is not None else {}, ) diff --git a/tests/ert/unit_tests/config/test_capture_validation.py b/tests/ert/unit_tests/config/test_capture_validation.py new file mode 100644 index 00000000000..1e6c71225ec --- /dev/null +++ b/tests/ert/unit_tests/config/test_capture_validation.py @@ -0,0 +1,34 @@ +from ert.config import ( + ConfigValidationError, + ConfigWarning, + ForwardModelStepWarning, + capture_validation, +) + + +def test_capture_validation_captures_warnings(): + with capture_validation() as validation_messages: + ConfigWarning.warn("Message") + + assert validation_messages.warnings[0].message == "Message" + + +def test_capture_validation_captures_deprecations(): + with capture_validation() as validation_messages: + ConfigWarning.deprecation_warn("Message") + + assert validation_messages.deprecations[0].message == "Message" + + +def test_capture_validation_captures_validation_errors(): + with capture_validation() as validation_messages: + raise ConfigValidationError("Message") + + assert validation_messages.errors[0].message == "Message" + + +def test_capture_validation_captures_plugin_warnings(): + with capture_validation() as validation_messages: + ForwardModelStepWarning.warn("Message") + + assert validation_messages.warnings[0].message == "Message"