Skip to content

Commit

Permalink
Remove use of deprecated types
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Dec 10, 2024
1 parent 68d01c7 commit bb10596
Show file tree
Hide file tree
Showing 39 changed files with 393 additions and 459 deletions.
99 changes: 45 additions & 54 deletions src/_ert/events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Annotated, Any, Dict, Final, Literal, Union
from typing import Annotated, Any, Final, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, TypeAdapter

Expand Down Expand Up @@ -39,12 +39,12 @@ class Id:
ENSEMBLE_SUCCEEDED: Final = "ensemble.succeeded"
ENSEMBLE_CANCELLED: Final = "ensemble.cancelled"
ENSEMBLE_FAILED: Final = "ensemble.failed"
ENSEMBLE_TYPES = Union[
ENSEMBLE_STARTED_TYPE,
ENSEMBLE_FAILED_TYPE,
ENSEMBLE_SUCCEEDED_TYPE,
ENSEMBLE_CANCELLED_TYPE,
]
ENSEMBLE_TYPES = (
ENSEMBLE_STARTED_TYPE
| ENSEMBLE_FAILED_TYPE
| ENSEMBLE_SUCCEEDED_TYPE
| ENSEMBLE_CANCELLED_TYPE
)

EE_SNAPSHOT_TYPE = Literal["ee.snapshot"]
EE_SNAPSHOT_UPDATE_TYPE = Literal["ee.snapshot_update"]
Expand All @@ -64,47 +64,47 @@ class BaseEvent(BaseModel):


class ForwardModelStepBaseEvent(BaseEvent):
ensemble: Union[str, None] = None
ensemble: Optional[str] = None
real: str
fm_step: str


class ForwardModelStepStart(ForwardModelStepBaseEvent):
event_type: Id.FORWARD_MODEL_STEP_START_TYPE = Id.FORWARD_MODEL_STEP_START
std_out: Union[str, None] = None
std_err: Union[str, None] = None
std_out: Optional[str] = None
std_err: Optional[str] = None


class ForwardModelStepRunning(ForwardModelStepBaseEvent):
event_type: Id.FORWARD_MODEL_STEP_RUNNING_TYPE = Id.FORWARD_MODEL_STEP_RUNNING
max_memory_usage: Union[int, None] = None
current_memory_usage: Union[int, None] = None
max_memory_usage: Optional[int] = None
current_memory_usage: Optional[int] = None
cpu_seconds: float = 0.0


class ForwardModelStepSuccess(ForwardModelStepBaseEvent):
event_type: Id.FORWARD_MODEL_STEP_SUCCESS_TYPE = Id.FORWARD_MODEL_STEP_SUCCESS
current_memory_usage: Union[int, None] = None
current_memory_usage: Optional[int] = None


class ForwardModelStepFailure(ForwardModelStepBaseEvent):
event_type: Id.FORWARD_MODEL_STEP_FAILURE_TYPE = Id.FORWARD_MODEL_STEP_FAILURE
error_msg: str
exit_code: Union[int, None] = None
exit_code: Optional[int] = None


class ForwardModelStepChecksum(BaseEvent):
event_type: Id.FORWARD_MODEL_STEP_CHECKSUM_TYPE = Id.FORWARD_MODEL_STEP_CHECKSUM
ensemble: Union[str, None] = None
ensemble: Optional[str] = None
real: str
checksums: Dict[str, Dict[str, Any]]
checksums: dict[str, dict[str, Any]]


class RealizationBaseEvent(BaseEvent):
real: str
ensemble: Union[str, None] = None
queue_event_type: Union[str, None] = None
exec_hosts: Union[str, None] = None
ensemble: Optional[str] = None
queue_event_type: Optional[str] = None
exec_hosts: Optional[str] = None


class RealizationPending(RealizationBaseEvent):
Expand All @@ -121,7 +121,7 @@ class RealizationSuccess(RealizationBaseEvent):

class RealizationFailed(RealizationBaseEvent):
event_type: Id.REALIZATION_FAILURE_TYPE = Id.REALIZATION_FAILURE
message: Union[str, None] = None # Only used for JobState.FAILED
message: Optional[str] = None # Only used for JobState.FAILED


class RealizationUnknown(RealizationBaseEvent):
Expand All @@ -137,7 +137,7 @@ class RealizationTimeout(RealizationBaseEvent):


class EnsembleBaseEvent(BaseEvent):
ensemble: Union[str, None] = None
ensemble: Optional[str] = None


class EnsembleStarted(EnsembleBaseEvent):
Expand Down Expand Up @@ -168,7 +168,7 @@ class EESnapshotUpdate(EnsembleBaseEvent):

class EETerminated(BaseEvent):
event_type: Id.EE_TERMINATED_TYPE = Id.EE_TERMINATED
ensemble: Union[str, None] = None
ensemble: Optional[str] = None


class EEUserCancel(BaseEvent):
Expand All @@ -181,39 +181,30 @@ class EEUserDone(BaseEvent):
monitor: str


FMEvent = Union[
ForwardModelStepStart,
ForwardModelStepRunning,
ForwardModelStepSuccess,
ForwardModelStepFailure,
]
FMEvent = (
ForwardModelStepStart
| ForwardModelStepRunning
| ForwardModelStepSuccess
| ForwardModelStepFailure
)

RealizationEvent = Union[
RealizationPending,
RealizationRunning,
RealizationSuccess,
RealizationFailed,
RealizationTimeout,
RealizationUnknown,
RealizationWaiting,
]
RealizationEvent = (
RealizationPending
| RealizationRunning
| RealizationSuccess
| RealizationFailed
| RealizationTimeout
| RealizationUnknown
| RealizationWaiting
)

EnsembleEvent = Union[
EnsembleStarted, EnsembleSucceeded, EnsembleFailed, EnsembleCancelled
]
EnsembleEvent = EnsembleStarted | EnsembleSucceeded | EnsembleFailed | EnsembleCancelled

EEEvent = Union[EESnapshot, EESnapshotUpdate, EETerminated, EEUserCancel, EEUserDone]
EEEvent = EESnapshot | EESnapshotUpdate | EETerminated | EEUserCancel | EEUserDone

Event = Union[
FMEvent, ForwardModelStepChecksum, RealizationEvent, EEEvent, EnsembleEvent
]
Event = FMEvent | ForwardModelStepChecksum | RealizationEvent | EEEvent | EnsembleEvent

DispatchEvent = Union[
FMEvent,
ForwardModelStepChecksum,
RealizationEvent,
EnsembleEvent,
]
DispatchEvent = FMEvent | ForwardModelStepChecksum | RealizationEvent | EnsembleEvent

_DISPATCH_EVENTS_ANNOTATION = Annotated[
DispatchEvent, Field(discriminator="event_type")
Expand All @@ -226,21 +217,21 @@ class EEUserDone(BaseEvent):
EventAdapter: TypeAdapter[Event] = TypeAdapter(_ALL_EVENTS_ANNOTATION)


def dispatch_event_from_json(raw_msg: Union[str, bytes]) -> DispatchEvent:
def dispatch_event_from_json(raw_msg: str | bytes) -> DispatchEvent:
return DispatchEventAdapter.validate_json(raw_msg)


def event_from_json(raw_msg: Union[str, bytes]) -> Event:
def event_from_json(raw_msg: str | bytes) -> Event:
return EventAdapter.validate_json(raw_msg)


def event_from_dict(dict_msg: Dict[str, Any]) -> Event:
def event_from_dict(dict_msg: dict[str, Any]) -> Event:
return EventAdapter.validate_python(dict_msg)


def event_to_json(event: Event) -> str:
return event.model_dump_json()


def event_to_dict(event: Event) -> Dict[str, Any]:
def event_to_dict(event: Event) -> dict[str, Any]:
return event.model_dump()
3 changes: 1 addition & 2 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import asyncio
import logging
import ssl
from typing import Any, AnyStr, Optional, Union
from typing import Any, AnyStr, Optional, Self, Union

from typing_extensions import Self
from websockets.asyncio.client import ClientConnection, connect
from websockets.datastructures import Headers
from websockets.exceptions import (
Expand Down
3 changes: 2 additions & 1 deletion src/_ert/forward_model_runner/reporting/message.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import dataclasses
from datetime import datetime as dt
from typing import TYPE_CHECKING, Dict, Literal, Optional, TypedDict
from typing import TYPE_CHECKING, Dict, Literal, Optional

import psutil
from typing_extensions import TypedDict

if TYPE_CHECKING:
from _ert.forward_model_runner.forward_model_step import ForwardModelStep
Expand Down
40 changes: 17 additions & 23 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,19 @@
Callable,
Generic,
Iterable,
List,
Optional,
Self,
Sequence,
Tuple,
TypeVar,
)

import iterative_ensemble_smoother as ies
import numpy as np
import polars
import psutil
from iterative_ensemble_smoother.experimental import (
AdaptiveESMDA,
)
from typing_extensions import Self
from iterative_ensemble_smoother.experimental import AdaptiveESMDA

from ert.config import (
GenKwConfig,
)
from ert.config import GenKwConfig

from ..config.analysis_config import ObservationGroups, UpdateSettings
from ..config.analysis_module import ESSettings, IESSettings
Expand Down Expand Up @@ -143,8 +137,8 @@ def _load_param_ensemble_array(


def _expand_wildcards(
input_list: npt.NDArray[np.str_], patterns: List[str]
) -> List[str]:
input_list: npt.NDArray[np.str_], patterns: list[str]
) -> list[str]:
"""
Returns a sorted list of unique strings from `input_list` that match any of the specified wildcard patterns.
Expand All @@ -171,14 +165,14 @@ def _load_observations_and_responses(
global_std_scaling: float,
iens_active_index: npt.NDArray[np.int_],
selected_observations: Iterable[str],
auto_scale_observations: Optional[List[ObservationGroups]],
auto_scale_observations: Optional[list[ObservationGroups]],
progress_callback: Callable[[AnalysisEvent], None],
) -> Tuple[
) -> tuple[
npt.NDArray[np.float64],
Tuple[
tuple[
npt.NDArray[np.float64],
npt.NDArray[np.float64],
List[ObservationAndResponseSnapshot],
list[ObservationAndResponseSnapshot],
],
]:
observations_and_responses = ensemble.get_observations_and_responses(
Expand Down Expand Up @@ -331,7 +325,7 @@ def _load_observations_and_responses(

def _split_by_batchsize(
arr: npt.NDArray[np.int_], batch_size: int
) -> List[npt.NDArray[np.int_]]:
) -> list[npt.NDArray[np.int_]]:
"""
Splits an array into sub-arrays of a specified batch size.
Expand Down Expand Up @@ -415,8 +409,8 @@ def _copy_unupdated_parameters(
This is necessary because users can choose not to update parameters but may still want to analyse them.
Parameters:
all_parameter_groups (List[str]): A list of all parameter groups.
updated_parameter_groups (List[str]): A list of parameter groups that have already been updated.
all_parameter_groups (list[str]): A list of all parameter groups.
updated_parameter_groups (list[str]): A list of parameter groups that have already been updated.
iens_active_index (npt.NDArray[np.int_]): An array of indices for the active realizations in the
target ensemble.
source_ensemble (Ensemble): The file system of the source ensemble, from which parameters are copied.
Expand Down Expand Up @@ -451,7 +445,7 @@ def analysis_ES(
source_ensemble: Ensemble,
target_ensemble: Ensemble,
progress_callback: Callable[[AnalysisEvent], None],
auto_scale_observations: Optional[List[ObservationGroups]],
auto_scale_observations: Optional[list[ObservationGroups]],
) -> None:
iens_active_index = np.flatnonzero(ens_mask)

Expand Down Expand Up @@ -530,7 +524,7 @@ def adaptive_localization_progress_callback(

def correlation_callback(
cross_correlations_of_batch: npt.NDArray[np.float64],
cross_correlations_accumulator: List[npt.NDArray[np.float64]],
cross_correlations_accumulator: list[npt.NDArray[np.float64]],
) -> None:
cross_correlations_accumulator.append(cross_correlations_of_batch)

Expand All @@ -551,7 +545,7 @@ def correlation_callback(
progress_callback(AnalysisStatusEvent(msg=log_msg))

start = time.time()
cross_correlations: List[npt.NDArray[np.float64]] = []
cross_correlations: list[npt.NDArray[np.float64]] = []
for param_batch_idx in batches:
X_local = param_ensemble_array[param_batch_idx, :]
if isinstance(config_node, GenKwConfig):
Expand Down Expand Up @@ -631,7 +625,7 @@ def analysis_IES(
target_ensemble: Ensemble,
sies_smoother: Optional[ies.SIES],
progress_callback: Callable[[AnalysisEvent], None],
auto_scale_observations: List[ObservationGroups],
auto_scale_observations: list[ObservationGroups],
sies_step_length: Callable[[int], float],
initial_mask: npt.NDArray[np.bool_],
) -> ies.SIES:
Expand Down Expand Up @@ -830,7 +824,7 @@ def iterative_smoother_update(
rng: Optional[np.random.Generator] = None,
progress_callback: Optional[Callable[[AnalysisEvent], None]] = None,
global_scaling: float = 1.0,
) -> Tuple[SmootherSnapshot, ies.SIES]:
) -> tuple[SmootherSnapshot, ies.SIES]:
if not progress_callback:
progress_callback = noop_progress_callback
if rng is None:
Expand Down
7 changes: 3 additions & 4 deletions src/ert/config/analysis_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import logging
import math
from typing import Optional, Union
from typing import Annotated, Literal

from pydantic import BaseModel, BeforeValidator, ConfigDict, Field
from typing_extensions import Annotated, Literal

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -68,7 +67,7 @@ class ESSettings(BaseSettings):
] = "exact"
localization: Annotated[bool, Field(title="Adaptive localization")] = False
localization_correlation_threshold: Annotated[
Optional[float],
float | None,
Field(
ge=0.0,
le=1.0,
Expand Down Expand Up @@ -111,4 +110,4 @@ class IESSettings(BaseSettings):
] = DEFAULT_IES_DEC_STEPLENGTH


AnalysisModule = Union[ESSettings, IESSettings]
AnalysisModule = ESSettings | IESSettings
2 changes: 1 addition & 1 deletion src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Dict,
List,
Optional,
Self,
Sequence,
Tuple,
Type,
Expand All @@ -27,7 +28,6 @@
from pydantic import ValidationError as PydanticValidationError
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

from ert.plugins import ErtPluginManager
from ert.substitutions import Substitutions
Expand Down
Loading

0 comments on commit bb10596

Please sign in to comment.