From cdeb7fe5da5251705f0701f88ba2126bf04ea780 Mon Sep 17 00:00:00 2001 From: mferrera Date: Thu, 23 Nov 2023 09:27:08 +0100 Subject: [PATCH] CLN: Add types to _roff_parameter --- src/xtgeo/grid3d/_roff_parameter.py | 82 +++++++++++++++++------------ 1 file changed, 48 insertions(+), 34 deletions(-) diff --git a/src/xtgeo/grid3d/_roff_parameter.py b/src/xtgeo/grid3d/_roff_parameter.py index f65c771e1..61e6d19c4 100644 --- a/src/xtgeo/grid3d/_roff_parameter.py +++ b/src/xtgeo/grid3d/_roff_parameter.py @@ -1,14 +1,20 @@ from __future__ import annotations import warnings -from collections import OrderedDict, defaultdict +from collections import defaultdict from dataclasses import dataclass +from typing import TYPE_CHECKING, Any import numpy as np import roffio from xtgeo.common.constants import UNDEF_INT_LIMIT, UNDEF_LIMIT +if TYPE_CHECKING: + from xtgeo.common.types import FileLike + + from .grid_property import GridProperty + @dataclass class RoffParameter: @@ -44,7 +50,7 @@ class RoffParameter: code_names: list[str] | None = None code_values: np.ndarray | None = None - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, RoffParameter): return False return ( @@ -56,7 +62,7 @@ def __eq__(self, other): and self.same_codes(other) ) - def same_codes(self, other): + def same_codes(self, other: RoffParameter) -> bool: """ Args: other (RoffParameter): Any roff parameter @@ -78,7 +84,7 @@ def same_codes(self, other): ) @property - def undefined_value(self): + def undefined_value(self) -> int | float: """ Returns: The undefined value for the type of values in the @@ -90,9 +96,10 @@ def undefined_value(self): return -999 if np.issubdtype(self.values.dtype, np.floating): return -999.0 + raise ValueError(f"Parameter values of unsupported type {type(self.values)}") @property - def is_discrete(self): + def is_discrete(self) -> bool: """ Returns: True if the RoffParameter is a discrete type @@ -101,7 +108,7 @@ def is_discrete(self): self.values.dtype, np.integer ) - def xtgeo_codes(self): + def xtgeo_codes(self) -> dict[int, str]: """ Returns: The discrete codes of the parameter in the format of @@ -112,7 +119,7 @@ def xtgeo_codes(self): else: return dict() - def xtgeo_values(self): + def xtgeo_values(self) -> np.ndarray: """ Args: The value to use for undefined. Defaults to that defined by @@ -120,21 +127,17 @@ def xtgeo_values(self): Returns: The values in the format of xtgeo grid property """ - vals = self.values - if isinstance(vals, bytes): - vals = np.ndarray(len(vals), np.uint8, vals) - vals = vals.copy() - vals = np.flip(vals.reshape((self.nx, self.ny, self.nz)), -1) - - if self.is_discrete: - vals = vals.astype(np.int32) + if isinstance(self.values, bytes): + vals: np.ndarray = np.ndarray(len(self.values), np.uint8, self.values) else: - vals = vals.astype(np.float64) + vals = self.values.copy() + vals = np.flip(vals.reshape((self.nx, self.ny, self.nz)), -1) + vals = vals.astype(np.int32) if self.is_discrete else vals.astype(np.float64) return np.ma.masked_values(vals, self.undefined_value) @staticmethod - def from_xtgeo_grid_property(xtgeo_grid_property): + def from_xtgeo_grid_property(xtgeo_grid_property: GridProperty) -> RoffParameter: """ Args: xtgeo_grid_property (xtgeo.GridProperty): Any xtgeo.GridProperty @@ -165,27 +168,31 @@ def from_xtgeo_grid_property(xtgeo_grid_property): values = values.astype(np.float32).filled(-999.0) return RoffParameter( - *xtgeo_grid_property.dimensions, + nx=xtgeo_grid_property.ncol, + ny=xtgeo_grid_property.nrow, + nz=xtgeo_grid_property.nlay, name=xtgeo_grid_property.name, values=np.asarray(np.flip(values, -1).ravel()), code_names=code_names, code_values=code_values, ) - def to_file(self, filelike, roff_format=roffio.Format.BINARY): + def to_file( + self, + filelike: FileLike, + roff_format: roffio.Format = roffio.Format.BINARY, + ) -> None: """ Writes the RoffParameter to a roff file Args: filelike (str or byte stream): The file to write to. roff_format (roffio.Format): The format to write the file in. """ - data = OrderedDict( - { - "filedata": {"filetype": "parameter"}, - "dimensions": {"nX": self.nx, "nY": self.ny, "nZ": self.nz}, - "parameter": {"name": self.name}, - } - ) + data: dict[str, dict[str, Any]] = { + "filedata": {"filetype": "parameter"}, + "dimensions": {"nX": self.nx, "nY": self.ny, "nZ": self.nz}, + "parameter": {"name": self.name}, + } if self.code_names is not None: data["parameter"]["codeNames"] = list(self.code_names) if self.code_values is not None: @@ -196,7 +203,7 @@ def to_file(self, filelike, roff_format=roffio.Format.BINARY): roffio.write(filelike, data, roff_format=roff_format) @staticmethod - def from_file(filelike, name=None): + def from_file(filelike: FileLike, name: str | None = None) -> RoffParameter: """ Read a RoffParameter from a roff file Args: @@ -207,7 +214,7 @@ def from_file(filelike, name=None): The RoffGrid in the roff file. """ - def should_skip_parameter(tag, key): + def should_skip_parameter(tag: str, key: str) -> bool: if tag == "parameter" and key[0] == "name": if name is None or key[1] == name: return False @@ -280,11 +287,18 @@ def should_skip_parameter(tag, key): f" have filetype parameter or grid, found {filetype}" ) + roff: dict[str, Any] = {} + for tag, tag_keys in translate_kws.items(): + for key, translated in tag_keys.items(): + if found[tag][key] is not None: + roff[translated] = found[tag][key] + return RoffParameter( - **{ - translated: found[tag][key] - for tag, tag_keys in translate_kws.items() - for key, translated in tag_keys.items() - if found[tag][key] is not None - } + nx=roff["nx"], + ny=roff["ny"], + nz=roff["nz"], + name=roff["name"], + values=roff["values"], + code_names=roff.get("code_names", None), + code_values=roff.get("code_values", None), )