diff --git a/src/fmu/dataio/_definitions.py b/src/fmu/dataio/_definitions.py index 9dcc7484f..bfbd36fca 100644 --- a/src/fmu/dataio/_definitions.py +++ b/src/fmu/dataio/_definitions.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from enum import Enum, unique -from typing import Final +from typing import Final, Type SCHEMA: Final = ( "https://main-fmu-schemas-prod.radix.equinor.com/schemas/0.8.0/fmu_results.json" @@ -78,36 +78,31 @@ class ValidFormats: @unique -class FmuContext(Enum): - """Use a Enum class for fmu_context entries.""" +class FmuContext(str, Enum): + """ + Use a Enum class for fmu_context entries. + The different entries will impact where data is exported: REALIZATION = "To realization-N/iter_M/share" CASE = "To casename/share, but will also work on project disk" CASE_SYMLINK_REALIZATION = "To case/share, with symlinks on realizations level" PREPROCESSED = "To share/preprocessed; from interactive runs but re-used later" NON_FMU = "Not ran in a FMU setting, e.g. interactive RMS" - @classmethod - def has_key(cls, key: str) -> bool: - return key.upper() in cls._member_names_ + """ + + REALIZATION = "realization" + CASE = "case" + CASE_SYMLINK_REALIZATION = "case_symlink_realization" + PREPROCESSED = "preprocessed" + NON_FMU = "non-fmu" @classmethod - def list_valid(cls) -> dict: - return {member.name: member.value for member in cls} + def list_valid_values(cls) -> list[str]: + return [m.value for m in cls] @classmethod - def get(cls, key: FmuContext | str) -> FmuContext: - """Get the enum member with a case-insensitive key.""" - if isinstance(key, cls): - key_upper = key.name - elif isinstance(key, str): - key_upper = key.upper() - else: - raise ValidationError("The input must be a str or FmuContext instance") - - if not cls.has_key(key_upper): - raise ValidationError( - f"Invalid key <{key_upper}>. Valid keys: {cls.list_valid().keys()}" - ) - - return cls[key_upper] + def _missing_(cls: Type[FmuContext], value: object) -> None: + raise ValueError( + f"Invalid FmuContext {value=}. Valid entries are {cls.list_valid_values()}" + ) diff --git a/src/fmu/dataio/dataio.py b/src/fmu/dataio/dataio.py index a6ec39713..13ebe9b25 100644 --- a/src/fmu/dataio/dataio.py +++ b/src/fmu/dataio/dataio.py @@ -469,7 +469,6 @@ def __post_init__(self) -> None: logger.debug("Global config is %s", prettyprint_dict(self.config)) self._fmurun = FmuEnv.ENSEMBLE_ID.value is not None - self.fmu_context = FmuContext.get(self.fmu_context) # set defaults for mutable keys self.vertical_domain = {"depth": "msl"} @@ -546,17 +545,16 @@ def _validate_content_key(self) -> None: def _validate_fmucontext_key(self) -> None: """Validate the given 'fmu_context' input.""" if isinstance(self.fmu_context, str): - self.fmu_context = FmuContext.get(self.fmu_context) - + self.fmu_context = FmuContext(self.fmu_context.lower()) # fmu_context is only allowed to be preprocessed if not in a fmu run if not self._fmurun and self.fmu_context != FmuContext.PREPROCESSED: logger.warning( "Requested fmu_context is <%s> but since this is detected as a non " "FMU run, the actual context is force set to <%s>", self.fmu_context, - FmuContext.get("non_fmu"), + FmuContext.NON_FMU, ) - self.fmu_context = FmuContext.get("non_fmu") + self.fmu_context = FmuContext.NON_FMU def _update_fmt_flag(self) -> None: # treat special handling of "xtgeo" in format name: @@ -676,9 +674,10 @@ def _check_process_object(self, obj: types.Inferrable) -> None: self._object = obj def _get_fmu_provider(self) -> FmuProvider: + assert isinstance(self.fmu_context, FmuContext) return FmuProvider( model=self.config.get("model", None), - fmu_context=FmuContext.get(self.fmu_context), + fmu_context=self.fmu_context, casepath_proposed=self.casepath or "", include_ertjobs=self.include_ertjobs, forced_realization=self.realization, diff --git a/src/fmu/dataio/datastructure/_internal/internal.py b/src/fmu/dataio/datastructure/_internal/internal.py index 5660c4e61..761f66610 100644 --- a/src/fmu/dataio/datastructure/_internal/internal.py +++ b/src/fmu/dataio/datastructure/_internal/internal.py @@ -12,7 +12,7 @@ from textwrap import dedent from typing import List, Literal, Optional, Union -from fmu.dataio._definitions import SCHEMA, SOURCE, VERSION +from fmu.dataio._definitions import SCHEMA, SOURCE, VERSION, FmuContext from fmu.dataio.datastructure.configuration.global_configuration import ( Model as GlobalConfigurationModel, ) @@ -170,13 +170,7 @@ class PreprocessedInfo(BaseModel): class Context(BaseModel): - stage: Literal[ - "realization", - "case", - "case_symlink_realization", - "preprocessed", - "non_fmu", - ] + stage: FmuContext # Remove the two models below when content is required as input. diff --git a/tests/test_units/test_enum_classes.py b/tests/test_units/test_enum_classes.py index 17cc295ac..c66883d89 100644 --- a/tests/test_units/test_enum_classes.py +++ b/tests/test_units/test_enum_classes.py @@ -6,17 +6,16 @@ def test_fmu_context_validation() -> None: """Test the FmuContext enum class.""" - rel = FmuContext.get("realization") + rel = FmuContext("realization") assert rel.name == "REALIZATION" - with pytest.raises(KeyError, match="Invalid key"): - FmuContext.get("invalid_context") + with pytest.raises(ValueError, match="Invalid FmuContext value='invalid_context'"): + FmuContext("invalid_context") - valid_types = FmuContext.list_valid() - assert list(valid_types.keys()) == [ - "REALIZATION", - "CASE", - "CASE_SYMLINK_REALIZATION", - "PREPROCESSED", - "NON_FMU", + assert FmuContext.list_valid_values() == [ + "realization", + "case", + "case_symlink_realization", + "preprocessed", + "non-fmu", ]