Skip to content

Commit

Permalink
BUG: accept masked arrays as values to Cube()
Browse files Browse the repository at this point in the history
This also contains more checks on `values` input, e.g. if dimensions are corrects
  • Loading branch information
jcrivenaes committed Feb 14, 2024
1 parent f414fe6 commit 83d4437
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 18 deletions.
71 changes: 53 additions & 18 deletions src/xtgeo/cube/cube1.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def _reset(
self._zflip = zflip # currently not in use
self._rotation = rotation

self._values = None
self.values = values # "values" is intentional over "_values"; cf. values()
# input values can be "list-like" or scalar
self._values = self._ensure_correct_values(values)

if ilines is None:
self._ilines = ilines or np.array(range(1, self._ncol + 1), dtype=np.int32)
Expand Down Expand Up @@ -491,7 +491,7 @@ def values(self):

@values.setter
def values(self, values):
self._ensure_correct_values(values)
self._values = self._ensure_correct_values(values)

# =========================================================================
# Describe
Expand Down Expand Up @@ -1101,33 +1101,68 @@ def scan_segy_header(sfile: str, outfile: FileLike = None):
print(line.rstrip("\r\n"))
os.remove(outfile)

def _ensure_correct_values(self, values):
def _ensure_correct_values(
self,
values: None | bool | float | list | tuple | np.ndarray | np.ma.MaskedArray,
) -> np.ndarray:
"""Ensures that values is a 3D numpy (ncol, nrow, nlay), C order.
Args:
values (array-like or scalar): Values to process.
Return:
Nothing, self._values will be updated inplace
values: Values to process.
"""
if values is None or values is False:
self._ensure_correct_values(0.0)
return
return_array = None
if values is None or isinstance(values, bool):
return_array = self._ensure_correct_values(0.0)

if isinstance(values, numbers.Number):
self._values = np.zeros(self.dimensions, dtype=np.float32) + values
self._values = self._values.astype(np.float32) # ensure 32 bit floats
elif isinstance(values, numbers.Number):
array = np.zeros(self.dimensions, dtype=np.float32) + values
return_array = array.astype(np.float32) # ensure 32 bit floats

elif isinstance(values, np.ndarray):
# if the input is a maskedarray; need to convert and fill with zero
if isinstance(values, np.ma.MaskedArray):
warnings.warn(
"Input values is a masked numpy array, and masked nodes "
"will be set to zero in the cube instance.",
UserWarning,
)
values = np.ma.filled(values, fill_value=0)

exp_len = np.prod(self.dimensions)
if (
values.size != exp_len
or values.ndim not in (1, 3)
or values.shape != self.dimensions
):
raise ValueError(
"Input is of wrong shape or dimensions: "
f"{values.shape}, expected {self.dimensions}"
"or ({exp_len},)"
)

values = values.reshape(self.dimensions).astype(np.float32)

if not values.data.c_contiguous:
if not values.flags.c_contiguous:
values = np.ascontiguousarray(values)
self._values = values
return_array = values

elif isinstance(values, (list, tuple)):
self._values = np.array(values, dtype=np.float32).reshape(self.dimensions)
exp_len = int(np.prod(self.dimensions))
if len(values) != exp_len:
raise ValueError(
"The length of the input list or tuple is incorrect"
f"Input length is {len(values)} while expected length is {exp_len}"
)

return_array = np.array(values, dtype=np.float32).reshape(self.dimensions)

else:
raise RuntimeError("Cannot process _ensure_correct_values")
raise ValueError(
f"Cannot process _ensure_correct_values with input values: {values}"
)

if return_array is not None:
return return_array

raise RuntimeError("Unexpected error, return values are None")
42 changes: 42 additions & 0 deletions tests/test_cube/test_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,48 @@ def test_create():
assert xdim == 5, "NX from numpy shape "


@pytest.mark.parametrize(
"input, behaviour",
[
(np.ones((2, 3, 1)), None),
(np.ma.ones((2, 3, 1)), UserWarning),
(np.ones((2, 3, 99)), ValueError),
(np.ones((3, 2, 1)), ValueError),
(np.ones((3, 2)), ValueError),
(np.ones((1, 6)), ValueError),
([1, 2, 3, 4, 5, 6], None),
([1, 2, 3, 4, 5, 6, 7], ValueError),
(99, None),
("a", ValueError),
],
ids=[
"np array",
"masked np array (warn)",
"np array right dims but wrong size (err)",
"np array right dims but flipped row col (err)",
"np array wrong dims as 2D ex 1 (err)",
"np array wrong dims as 2D ex 2 (err)",
"list",
"list_wrong_length (err)",
"scalar",
"letter (err)",
],
)
def test_create_cube_with_values(input, behaviour):
"""Create cube with various input values, both correct and incorrect formats."""

if behaviour is None:
Cube(ncol=2, nrow=3, nlay=1, xinc=1, yinc=1, zinc=1, values=input)

elif behaviour is UserWarning:
with pytest.warns(behaviour):
Cube(ncol=2, nrow=3, nlay=1, xinc=1, yinc=1, zinc=1, values=input)

elif behaviour is ValueError:
with pytest.raises(behaviour):
Cube(ncol=2, nrow=3, nlay=1, xinc=1, yinc=1, zinc=1, values=input)


def test_import_wrong_format(tmp_path):
(tmp_path / "test.EGRID").write_text("hello")
with pytest.raises(ValueError, match="File format"):
Expand Down

0 comments on commit 83d4437

Please sign in to comment.