Skip to content

Commit

Permalink
CLN: Add typing to roff_grid
Browse files Browse the repository at this point in the history
  • Loading branch information
JB Lovland authored and janbjorge committed Nov 1, 2023
1 parent 8ddc8b3 commit 69d8caa
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ examples/*
!examples/*.py
!examples/run_examples.sh
_tmp_*
.dmypy.json

# Translations
*.mo
Expand Down
73 changes: 49 additions & 24 deletions src/xtgeo/grid3d/_roff_grid.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
from __future__ import annotations

import pathlib
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from typing import Optional
from typing import (
IO,
TYPE_CHECKING,
Any,
Dict,
MutableMapping,
Optional,
Sequence,
Tuple,
Union,
)

import numpy as np
import roffio

import xtgeo.cxtgeo._cxtgeo as _cxtgeo

if TYPE_CHECKING:
from xtgeo.grid3d import Grid


@dataclass
class RoffGrid:
Expand Down Expand Up @@ -106,15 +122,15 @@ class RoffGrid:
yscale: float = 1.0
zscale: float = -1.0

def __post_init__(self):
def __post_init__(self) -> None:
if self.active is None:
self.active = np.ones(self.nx * self.ny * self.nz, dtype=np.bool_)
if self.split_enz is None:
self.split_enz = np.ones(
self.nx * self.ny * self.nz, dtype=np.uint8
).tobytes()

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if not isinstance(other, RoffGrid):
return False
return (
Expand All @@ -135,13 +151,13 @@ def __eq__(self, other):
)

@property
def num_nodes(self):
def num_nodes(self) -> int:
"""
The number of nodes in the grid, ie. the size of split_enz.
"""
return (self.nx + 1) * (self.ny + 1) * (self.nz + 1)

def _create_lookup(self):
def _create_lookup(self) -> None:
if not hasattr(self, "_lookup"):
n = self.num_nodes
self._lookup = np.zeros(n + 1, dtype=np.int32)
Expand All @@ -151,7 +167,7 @@ def _create_lookup(self):
else:
self._lookup[i + 1] = 1 + self._lookup[i]

def z_value(self, node):
def z_value(self, node: Tuple[int, int, int]) -> np.ndarray:
"""
Gives the 8 z values for any given node for
adjacent cells in the order:
Expand Down Expand Up @@ -213,7 +229,7 @@ def z_value(self, node):
else:
raise ValueError("Only split types 1, 2, 4 and 8 are supported!")

def xtgeo_coord(self):
def xtgeo_coord(self) -> np.ndarray:
"""
Returns:
The coordinates of nodes in the format of xtgeo.Grid.coordsv
Expand All @@ -226,23 +242,25 @@ def xtgeo_coord(self):
coordsv *= scale
return coordsv.reshape((self.nx + 1, self.ny + 1, 6)).astype(np.float64)

def xtgeo_actnum(self):
def xtgeo_actnum(self) -> np.ndarray:
"""
Returns:
The active field in the format of xtgeo.Grid.actnumsv
"""
assert self.active is not None
actnum = self.active.reshape((self.nx, self.ny, self.nz))
actnum = np.flip(actnum, -1)
return actnum.astype(np.int32)

def xtgeo_zcorn(self):
def xtgeo_zcorn(self) -> np.ndarray:
"""
Returns:
The z values for nodes in the format of xtgeo.Grid.zcornsv
"""
zcornsv = np.zeros(
(self.nx + 1) * (self.ny + 1) * (self.nz + 1) * 4, dtype=np.float32
)
assert self.split_enz is not None
retval = _cxtgeo.grd3d_roff2xtgeo_splitenz(
int(self.nz + 1),
float(self.zoffset),
Expand Down Expand Up @@ -275,7 +293,7 @@ def xtgeo_zcorn(self):
else:
raise ValueError(f"Unknown error {retval} occurred")

def xtgeo_subgrids(self):
def xtgeo_subgrids(self) -> Optional[OrderedDict[str, range]]:
"""
Returns:
The z values for nodes in the format of xtgeo.Grid.zcornsv
Expand All @@ -290,7 +308,9 @@ def xtgeo_subgrids(self):
return result

@staticmethod
def _from_xtgeo_subgrids(xtgeo_subgrids):
def _from_xtgeo_subgrids(
xtgeo_subgrids: MutableMapping[str, Union[range, Sequence]]
) -> np.ndarray:
"""
Args:
A xtgeo.Grid._subgrids dictionary
Expand All @@ -300,7 +320,7 @@ def _from_xtgeo_subgrids(xtgeo_subgrids):
if xtgeo_subgrids is None:
return None
subgrids = []
for key, value in xtgeo_subgrids.items():
for _, value in xtgeo_subgrids.items():
if isinstance(value, range):
subgrids.append(value.stop - value.start)
elif value != list(range(value[0], value[-1] + 1)):
Expand All @@ -312,7 +332,7 @@ def _from_xtgeo_subgrids(xtgeo_subgrids):
return np.array(subgrids, dtype=np.int32)

@staticmethod
def from_xtgeo_grid(xtgeo_grid):
def from_xtgeo_grid(xtgeo_grid: Grid) -> RoffGrid:
"""
Args:
An xtgeo.Grid
Expand All @@ -334,13 +354,17 @@ def from_xtgeo_grid(xtgeo_grid):

return RoffGrid(nx, ny, nz, corner_lines, zvals, split_enz, active, subgrids)

def to_file(self, filelike, roff_format=roffio.Format.BINARY):
def to_file(
self,
filelike: Union[str, pathlib.Path, IO],
roff_format: roffio.Format = roffio.Format.BINARY,
) -> None:
"""
Writes the RoffGrid to a roff file
Args:
filelike (str or byte stream): The file to write to.
"""
data = {
data: Dict[str, Dict] = {
"filedata": {"filetype": "grid"},
"dimensions": {"nX": self.nx, "nY": self.ny, "nZ": self.nz},
"translate": {
Expand All @@ -364,7 +388,7 @@ def to_file(self, filelike, roff_format=roffio.Format.BINARY):
roffio.write(filelike, data, roff_format=roff_format)

@staticmethod
def from_file(filelike):
def from_file(filelike: Union[str, pathlib.Path, IO]) -> RoffGrid:
"""
Read a RoffGrid from a roff file
Args:
Expand Down Expand Up @@ -434,11 +458,12 @@ def from_file(filelike):
f"File {filelike} did not have filetype set to grid, found {filetype}"
)

return RoffGrid(
**{
translated: found[tag][key]
for tag, tag_keys in translate_kws.items()
for key, translated in tag_keys.items()
if found[tag][key] is not None
}
)
# TODO(JB): This needs more refactoring, tricky to track key-values when
# added to a dict like this. One option is to use a TypedDict.
kwarg = {
translated: found[tag][key]
for tag, tag_keys in translate_kws.items()
for key, translated in tag_keys.items()
if found[tag][key] is not None
}
return RoffGrid(**kwarg) # type: ignore

0 comments on commit 69d8caa

Please sign in to comment.