From 9b16042e3a5202ebbefb16f742a5d249a8ee0b5f Mon Sep 17 00:00:00 2001 From: mferrera Date: Fri, 22 Mar 2024 08:54:27 +0100 Subject: [PATCH] CLN: Separate spec and bbox derivation This further breaks down the object data provider into a more refined API and into more maintainable and testable blocks. An added benefit is more type sanity. --- src/fmu/dataio/_utils.py | 5 + src/fmu/dataio/providers/_objectdata.py | 512 ++++-------------- src/fmu/dataio/providers/_objectdata_base.py | 17 +- src/fmu/dataio/providers/_objectdata_xtgeo.py | 358 ++++++++++++ .../test_objectdataprovider_class.py | 8 +- tests/test_units/test_table.py | 2 +- 6 files changed, 478 insertions(+), 424 deletions(-) create mode 100644 src/fmu/dataio/providers/_objectdata_xtgeo.py diff --git a/src/fmu/dataio/_utils.py b/src/fmu/dataio/_utils.py index 1bd8f8c4e..240a82ed2 100644 --- a/src/fmu/dataio/_utils.py +++ b/src/fmu/dataio/_utils.py @@ -13,6 +13,7 @@ from pathlib import Path from typing import Any, Final, Literal +import numpy as np import pandas as pd import xtgeo import yaml @@ -25,6 +26,10 @@ logger: Final = null_logger(__name__) +def npfloat_to_float(v: Any) -> Any: + return float(v) if isinstance(v, (np.float64, np.float32)) else v + + def detect_inside_rms() -> bool: """Detect if 'truly' inside RMS GUI, where predefined variable project exist. 
diff --git a/src/fmu/dataio/providers/_objectdata.py b/src/fmu/dataio/providers/_objectdata.py index 3077f7958..b4e038fc1 100644 --- a/src/fmu/dataio/providers/_objectdata.py +++ b/src/fmu/dataio/providers/_objectdata.py @@ -89,21 +89,29 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Final -import numpy as np import pandas as pd import xtgeo from fmu.dataio._definitions import STANDARD_TABLE_INDEX_COLUMNS, ValidFormats from fmu.dataio._logging import null_logger -from fmu.dataio.datastructure.meta import meta, specification +from fmu.dataio.datastructure.meta import specification from ._objectdata_base import ( DerivedObjectDescriptor, ObjectDataProvider, - SpecificationAndBoundingBox, +) +from ._objectdata_xtgeo import ( + CPGridDataProvider, + CPGridPropertyDataProvider, + CubeDataProvider, + PointsDataProvider, + PolygonsDataProvider, + RegularSurfaceDataProvider, ) if TYPE_CHECKING: + import pyarrow + from fmu.dataio.dataio import ExportData from fmu.dataio.types import Inferrable @@ -163,330 +171,27 @@ def objectdata_provider_factory( raise NotImplementedError("This data type is not (yet) supported: ", type(obj)) -def npfloat_to_float(v: Any) -> Any: - return float(v) if isinstance(v, (np.float64, np.float32)) else v - - -@dataclass -class RegularSurfaceDataProvider(ObjectDataProvider): - def _derive_spec_and_bbox(self) -> SpecificationAndBoundingBox: - """Process/collect the data.spec and data.bbox for RegularSurface""" - logger.info("Derive bbox and specs for RegularSurface") - regsurf: xtgeo.RegularSurface = self.obj - required = regsurf.metadata.required - - return SpecificationAndBoundingBox( - spec=specification.SurfaceSpecification( - ncol=npfloat_to_float(required["ncol"]), - nrow=npfloat_to_float(required["nrow"]), - xori=npfloat_to_float(required["xori"]), - yori=npfloat_to_float(required["yori"]), - xinc=npfloat_to_float(required["xinc"]), - yinc=npfloat_to_float(required["yinc"]), - 
yflip=npfloat_to_float(required["yflip"]), - rotation=npfloat_to_float(required["rotation"]), - undef=1.0e30, - ).model_dump( - mode="json", - exclude_none=True, - ), - bbox=meta.content.BoundingBox3D( - xmin=float(regsurf.xmin), - xmax=float(regsurf.xmax), - ymin=float(regsurf.ymin), - ymax=float(regsurf.ymax), - zmin=float(regsurf.values.min()), - zmax=float(regsurf.values.max()), - ).model_dump( - mode="json", - exclude_none=True, - ), - ) - - def _derive_objectdata(self) -> DerivedObjectDescriptor: - spec, bbox = self._derive_spec_and_bbox() - return DerivedObjectDescriptor( - subtype="RegularSurface", - classname="surface", - layout="regular", - efolder="maps", - fmt=(fmt := self.dataio.surface_fformat), - spec=spec, - bbox=bbox, - extension=self._validate_get_ext( - fmt, "RegularSurface", ValidFormats().surface - ), - table_index=None, - ) - - @dataclass -class PolygonsDataProvider(ObjectDataProvider): - def _derive_spec_and_bbox(self) -> SpecificationAndBoundingBox: - """Process/collect the data.spec and data.bbox for Polygons""" - logger.info("Derive bbox and specs for Polygons") - poly: xtgeo.Polygons = self.obj - xmin, xmax, ymin, ymax, zmin, zmax = poly.get_boundary() - - return SpecificationAndBoundingBox( - spec=specification.PolygonsSpecification( - npolys=np.unique(poly.get_dataframe(copy=False)[poly.pname].values).size - ).model_dump( - mode="json", - exclude_none=True, - ), - bbox=meta.content.BoundingBox3D( - xmin=float(xmin), - xmax=float(xmax), - ymin=float(ymin), - ymax=float(ymax), - zmin=float(zmin), - zmax=float(zmax), - ).model_dump( - mode="json", - exclude_none=True, - ), - ) - - def _derive_objectdata(self) -> DerivedObjectDescriptor: - spec, bbox = self._derive_spec_and_bbox() - return DerivedObjectDescriptor( - subtype="Polygons", - classname="polygons", - layout="unset", - efolder="polygons", - fmt=(fmt := self.dataio.polygons_fformat), - extension=self._validate_get_ext(fmt, "Polygons", ValidFormats().polygons), - spec=spec, - 
bbox=bbox, - table_index=None, - ) - - -@dataclass -class PointsDataProvider(ObjectDataProvider): - def _derive_spec_and_bbox(self) -> SpecificationAndBoundingBox: - """Process/collect the data.spec and data.bbox for Points""" - logger.info("Derive bbox and specs for Points") - pnts: xtgeo.Points = self.obj - df: pd.DataFrame = pnts.get_dataframe(copy=False) - - return SpecificationAndBoundingBox( - spec=specification.PointSpecification( - attributes=list(df.columns[3:]) if len(df.columns) > 3 else None, - size=int(df.size), - ).model_dump( - mode="json", - exclude_none=True, - ), - bbox=meta.content.BoundingBox3D( - xmin=float(df[pnts.xname].min()), - xmax=float(df[pnts.xname].max()), - ymax=float(df[pnts.yname].min()), - ymin=float(df[pnts.yname].max()), - zmin=float(df[pnts.zname].min()), - zmax=float(df[pnts.zname].max()), - ).model_dump( - mode="json", - exclude_none=True, - ), - ) - - def _derive_objectdata(self) -> DerivedObjectDescriptor: - spec, bbox = self._derive_spec_and_bbox() - return DerivedObjectDescriptor( - subtype="Points", - classname="points", - layout="unset", - efolder="points", - fmt=(fmt := self.dataio.points_fformat), - extension=self._validate_get_ext(fmt, "Points", ValidFormats().points), - spec=spec, - bbox=bbox, - table_index=None, - ) - - -@dataclass -class CubeDataProvider(ObjectDataProvider): - def _derive_spec_and_bbox(self) -> SpecificationAndBoundingBox: - """Process/collect the data.spec and data.bbox Cube""" - logger.info("Derive bbox and specs for Cube") - cube: xtgeo.Cube = self.obj - required = cube.metadata.required - - # current xtgeo is missing xmin, xmax etc attributes for cube, so need - # to compute (simplify when xtgeo has this): - xmin, ymin = 1.0e23, 1.0e23 - xmax, ymax = -xmin, -ymin - - for corner in ((1, 1), (1, cube.nrow), (cube.ncol, 1), (cube.ncol, cube.nrow)): - xco, yco = cube.get_xy_value_from_ij(*corner) - xmin = min(xmin, xco) - xmax = max(xmax, xco) - ymin = min(ymin, yco) - ymax = max(ymax, yco) - - 
return SpecificationAndBoundingBox( - spec=specification.CubeSpecification( - ncol=npfloat_to_float(required["ncol"]), - nrow=npfloat_to_float(required["nrow"]), - nlay=npfloat_to_float(required["nlay"]), - xori=npfloat_to_float(required["xori"]), - yori=npfloat_to_float(required["yori"]), - zori=npfloat_to_float(required["zori"]), - xinc=npfloat_to_float(required["xinc"]), - yinc=npfloat_to_float(required["yinc"]), - zinc=npfloat_to_float(required["zinc"]), - yflip=npfloat_to_float(required["yflip"]), - zflip=npfloat_to_float(required["zflip"]), - rotation=npfloat_to_float(required["rotation"]), - undef=npfloat_to_float(required["undef"]), - ).model_dump( - mode="json", - exclude_none=True, - ), - bbox=meta.content.BoundingBox3D( - xmin=float(xmin), - xmax=float(xmax), - ymin=float(ymin), - ymax=float(ymax), - zmin=float(cube.zori), - zmax=float(cube.zori + cube.zinc * (cube.nlay - 1)), - ).model_dump( - mode="json", - exclude_none=True, - ), - ) - - def _derive_objectdata(self) -> DerivedObjectDescriptor: - spec, bbox = self._derive_spec_and_bbox() - return DerivedObjectDescriptor( - subtype="RegularCube", - classname="cube", - layout="regular", - efolder="cubes", - fmt=(fmt := self.dataio.cube_fformat), - extension=self._validate_get_ext(fmt, "RegularCube", ValidFormats().cube), - spec=spec, - bbox=bbox, - table_index=None, - ) - - -@dataclass -class CPGridDataProvider(ObjectDataProvider): - def _derive_spec_and_bbox(self) -> SpecificationAndBoundingBox: - """Process/collect the data.spec and data.bbox CornerPoint Grid geometry""" - logger.info("Derive bbox and specs for Gride (geometry)") - grid: xtgeo.Grid = self.obj - required = grid.metadata.required - - geox: dict = grid.get_geometrics( - cellcenter=False, - allcells=True, - return_dict=True, - ) - - return SpecificationAndBoundingBox( - spec=specification.CPGridSpecification( - ncol=npfloat_to_float(required["ncol"]), - nrow=npfloat_to_float(required["nrow"]), - nlay=npfloat_to_float(required["nlay"]), - 
xshift=npfloat_to_float(required["xshift"]), - yshift=npfloat_to_float(required["yshift"]), - zshift=npfloat_to_float(required["zshift"]), - xscale=npfloat_to_float(required["xscale"]), - yscale=npfloat_to_float(required["yscale"]), - zscale=npfloat_to_float(required["zscale"]), - ).model_dump( - mode="json", - exclude_none=True, - ), - bbox=meta.content.BoundingBox3D( - xmin=round(float(geox["xmin"]), 4), - xmax=round(float(geox["xmax"]), 4), - ymin=round(float(geox["ymin"]), 4), - ymax=round(float(geox["ymax"]), 4), - zmin=round(float(geox["zmin"]), 4), - zmax=round(float(geox["zmax"]), 4), - ).model_dump( - mode="json", - exclude_none=True, - ), - ) - - def _derive_objectdata(self) -> DerivedObjectDescriptor: - spec, bbox = self._derive_spec_and_bbox() - return DerivedObjectDescriptor( - subtype="CPGrid", - classname="cpgrid", - layout="cornerpoint", - efolder="grids", - fmt=(fmt := self.dataio.grid_fformat), - extension=self._validate_get_ext(fmt, "CPGrid", ValidFormats().grid), - spec=spec, - bbox=bbox, - table_index=None, - ) - - -@dataclass -class CPGridPropertyDataProvider(ObjectDataProvider): - def _derive_spec_and_bbox(self) -> SpecificationAndBoundingBox: - """Process/collect the data.spec and data.bbox GridProperty""" - logger.info("Derive bbox and specs for GridProperty") - gridprop: xtgeo.GridProperty = self.obj - - return SpecificationAndBoundingBox( - spec=specification.CPGridPropertySpecification( - nrow=gridprop.nrow, - ncol=gridprop.ncol, - nlay=gridprop.nlay, - ).model_dump( - mode="json", - exclude_none=True, - ), - bbox={}, - ) - - def _derive_objectdata(self) -> DerivedObjectDescriptor: - spec, bbox = self._derive_spec_and_bbox() - return DerivedObjectDescriptor( - subtype="CPGridProperty", - classname="cpgrid_property", - layout="cornerpoint", - efolder="grids", - fmt=(fmt := self.dataio.grid_fformat), - extension=self._validate_get_ext( - fmt, "CPGridProperty", ValidFormats().grid - ), - spec=spec, - bbox=bbox, - table_index=None, - ) 
+class DataFrameDataProvider(ObjectDataProvider): + obj: pd.DataFrame + def _check_index(self, index: list[str]) -> None: + """Check the table index. + Args: + index (list): list of column names -@dataclass -class DataFrameDataProvider(ObjectDataProvider): - def _get_columns(self) -> list[str]: - """Get the columns from table""" - if isinstance(self.obj, pd.DataFrame): - logger.debug("pandas") - columns = list(self.obj.columns) - else: - logger.debug("arrow") - from pyarrow import Table + Raises: + KeyError: if index contains names that are not in self + """ - assert isinstance(self.obj, Table) - columns = self.obj.column_names - logger.debug("Available columns in table %s ", columns) - return columns + not_founds = (item for item in index if item not in list(self.obj.columns)) + for not_found in not_founds: + raise KeyError(f"{not_found} is not in table") def _derive_index(self) -> list[str]: """Derive table index""" # This could in the future also return context - columns = self._get_columns() + columns = list(self.obj.columns) index = [] if self.dataio.table_index is None: @@ -506,36 +211,25 @@ def _derive_index(self) -> list[str]: self._check_index(index) return index - def _check_index(self, index: list[str]) -> None: - """Check the table index. 
- Args: - index (list): list of column names - - Raises: - KeyError: if index contains names that are not in self - """ + def get_spec(self) -> dict[str, Any]: + """Derive data.spec for pd.DataFrame.""" + logger.info("Get spec for pd.DataFrame (tables)") - not_founds = (item for item in index if item not in self._get_columns()) - for not_found in not_founds: - raise KeyError(f"{not_found} is not in table") - - def _derive_spec_and_bbox(self) -> SpecificationAndBoundingBox: - """Process/collect the data items for DataFrame.""" - logger.info("Process data metadata for DataFrame (tables)") - assert isinstance(self.obj, pd.DataFrame) - return SpecificationAndBoundingBox( - spec=specification.TableSpecification( - columns=list(self.obj.columns), - size=int(self.obj.size), - ).model_dump( - mode="json", - exclude_none=True, - ), - bbox={}, + return specification.TableSpecification( + columns=list(self.obj.columns), + size=int(self.obj.size), + ).model_dump( + mode="json", + exclude_none=True, ) - def _derive_objectdata(self) -> DerivedObjectDescriptor: - spec, bbox = self._derive_spec_and_bbox() + def get_bbox(self) -> dict[str, Any]: + """Derive data.bbox for pd.DataFrame.""" + logger.info("Get bbox for pd.DataFrame (tables)") + return {} + + def get_objectdata(self) -> DerivedObjectDescriptor: + """Derive object data for pd.DataFrame.""" return DerivedObjectDescriptor( subtype="DataFrame", classname="table", @@ -543,21 +237,28 @@ def _derive_objectdata(self) -> DerivedObjectDescriptor: efolder="tables", fmt=(fmt := self.dataio.table_fformat), extension=self._validate_get_ext(fmt, "DataFrame", ValidFormats().table), - spec=spec, - bbox=bbox, + spec=self.get_spec(), + bbox=self.get_bbox(), table_index=self._derive_index(), ) @dataclass class DictionaryDataProvider(ObjectDataProvider): - def _derive_spec_and_bbox(self) -> SpecificationAndBoundingBox: - """Process/collect the data items for dictionary.""" - logger.info("Process data metadata for dictionary") - return 
SpecificationAndBoundingBox({}, {}) + obj: dict - def _derive_objectdata(self) -> DerivedObjectDescriptor: - spec, bbox = self._derive_spec_and_bbox() + def get_spec(self) -> dict[str, Any]: + """Derive data.spec for dict.""" + logger.info("Get spec for dictionary") + return {} + + def get_bbox(self) -> dict[str, Any]: + """Derive data.bbox for dict.""" + logger.info("Get bbox for dictionary") + return {} + + def get_objectdata(self) -> DerivedObjectDescriptor: + """Derive object data for dict.""" return DerivedObjectDescriptor( subtype="JSON", classname="dictionary", @@ -565,31 +266,32 @@ def _derive_objectdata(self) -> DerivedObjectDescriptor: efolder="dictionaries", fmt=(fmt := self.dataio.dict_fformat), extension=self._validate_get_ext(fmt, "JSON", ValidFormats().dictionary), - spec=spec, - bbox=bbox, + spec=self.get_spec(), + bbox=self.get_bbox(), table_index=None, ) class ArrowTableDataProvider(ObjectDataProvider): - def _get_columns(self) -> list[str]: - """Get the columns from table""" - if isinstance(self.obj, pd.DataFrame): - logger.debug("pandas") - columns = list(self.obj.columns) - else: - logger.debug("arrow") - from pyarrow import Table + obj: pyarrow.Table - assert isinstance(self.obj, Table) - columns = self.obj.column_names - logger.debug("Available columns in table %s ", columns) - return columns + def _check_index(self, index: list[str]) -> None: + """Check the table index. 
+ Args: + index (list): list of column names + + Raises: + KeyError: if index contains names that are not in self + """ + + not_founds = (item for item in index if item not in self.obj.column_names) + for not_found in not_founds: + raise KeyError(f"{not_found} is not in table") def _derive_index(self) -> list[str]: """Derive table index""" # This could in the future also return context - columns = self._get_columns() + columns = self.obj.column_names index = [] if self.dataio.table_index is None: @@ -609,38 +311,25 @@ def _derive_index(self) -> list[str]: self._check_index(index) return index - def _check_index(self, index: list[str]) -> None: - """Check the table index. - Args: - index (list): list of column names - - Raises: - KeyError: if index contains names that are not in self - """ + def get_spec(self) -> dict[str, Any]: + """Derive data.spec for pyarrow.Table.""" + logger.info("Get spec for pyarrow (tables)") - not_founds = (item for item in index if item not in self._get_columns()) - for not_found in not_founds: - raise KeyError(f"{not_found} is not in table") - - def _derive_spec_and_bbox(self) -> SpecificationAndBoundingBox: - """Process/collect the data items for Arrow table.""" - logger.info("Process data metadata for arrow (tables)") - from pyarrow import Table - - assert isinstance(self.obj, Table) - return SpecificationAndBoundingBox( - spec=specification.TableSpecification( - columns=list(self.obj.column_names), - size=self.obj.num_columns * self.obj.num_rows, - ).model_dump( - mode="json", - exclude_none=True, - ), - bbox={}, + return specification.TableSpecification( + columns=list(self.obj.column_names), + size=self.obj.num_columns * self.obj.num_rows, + ).model_dump( + mode="json", + exclude_none=True, ) - def _derive_objectdata(self) -> DerivedObjectDescriptor: - spec, bbox = self._derive_spec_and_bbox() + def get_bbox(self) -> dict[str, Any]: + """Derive data.bbox for pyarrow.Table.""" + logger.info("Get bbox for pyarrow (tables)") + return 
{} + + def get_objectdata(self) -> DerivedObjectDescriptor: + """Derive object data from pyarrow.Table.""" return DerivedObjectDescriptor( table_index=self._derive_index(), subtype="ArrowTable", @@ -649,8 +338,8 @@ def _derive_objectdata(self) -> DerivedObjectDescriptor: efolder="tables", fmt=(fmt := self.dataio.arrow_fformat), extension=self._validate_get_ext(fmt, "ArrowTable", ValidFormats().table), - spec=spec, - bbox=bbox, + spec=self.get_spec(), + bbox=self.get_bbox(), ) @@ -660,15 +349,18 @@ class ExistingDataProvider(ObjectDataProvider): object data from existing metadata, by calling _derive_from_existing, and return before calling them.""" - def _derive_spec_and_bbox(self) -> SpecificationAndBoundingBox: - """Process/collect the data items for dictionary.""" - logger.info("Process data metadata for dictionary") - return SpecificationAndBoundingBox( - self.meta_existing["spec"], self.meta_existing["bbox"] - ) + obj: Any + + def get_spec(self) -> dict[str, Any]: + """Derive data.spec from existing metadata.""" + return self.meta_existing["spec"] + + def get_bbox(self) -> dict[str, Any]: + """Derive data.bbox from existing metadata.""" + return self.meta_existing["bbox"] - def _derive_objectdata(self) -> DerivedObjectDescriptor: - spec, bbox = self._derive_spec_and_bbox() + def get_objectdata(self) -> DerivedObjectDescriptor: + """Derive object data for existing metadata.""" return DerivedObjectDescriptor( subtype=self.meta_existing["subtype"], classname=self.meta_existing["class"], @@ -676,7 +368,7 @@ def _derive_objectdata(self) -> DerivedObjectDescriptor: efolder=self.efolder, fmt=self.meta_existing["format"], extension=self.extension, - spec=spec, - bbox=bbox, + spec=self.get_spec(), + bbox=self.get_bbox(), table_index=None, ) diff --git a/src/fmu/dataio/providers/_objectdata_base.py b/src/fmu/dataio/providers/_objectdata_base.py index dd33e9cf8..0057392d4 100644 --- a/src/fmu/dataio/providers/_objectdata_base.py +++ 
b/src/fmu/dataio/providers/_objectdata_base.py @@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, Dict, Final, Literal, NamedTuple, Optional, TypeVar +from typing import Any, Dict, Final, Literal, Optional, TypeVar from warnings import warn from fmu.dataio import dataio, types @@ -67,11 +67,6 @@ class DerivedObjectDescriptor: table_index: Optional[list[str]] -class SpecificationAndBoundingBox(NamedTuple): - spec: Dict[str, Any] - bbox: Dict[str, Any] - - @dataclass class TimedataValueLabel: value: str @@ -318,7 +313,7 @@ def derive_metadata(self) -> None: return namedstratigraphy = self._derive_name_stratigraphy() - objres = self._derive_objectdata() + objres = self.get_objectdata() if self.dataio.forcefolder and not self.dataio.forcefolder.startswith("/"): msg = ( f"The standard folder name is overrided from {objres.efolder} to " @@ -385,9 +380,13 @@ def derive_metadata(self) -> None: logger.info("Derive all metadata for data object... 
DONE") @abstractmethod - def _derive_spec_and_bbox(self) -> SpecificationAndBoundingBox: + def get_spec(self) -> dict[str, Any]: + raise NotImplementedError + + @abstractmethod + def get_bbox(self) -> dict[str, Any]: raise NotImplementedError @abstractmethod - def _derive_objectdata(self) -> DerivedObjectDescriptor: + def get_objectdata(self) -> DerivedObjectDescriptor: raise NotImplementedError diff --git a/src/fmu/dataio/providers/_objectdata_xtgeo.py b/src/fmu/dataio/providers/_objectdata_xtgeo.py new file mode 100644 index 000000000..8d156e993 --- /dev/null +++ b/src/fmu/dataio/providers/_objectdata_xtgeo.py @@ -0,0 +1,358 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Final + +import numpy as np +import pandas as pd +import xtgeo + +from fmu.dataio._definitions import ValidFormats +from fmu.dataio._logging import null_logger +from fmu.dataio._utils import npfloat_to_float +from fmu.dataio.datastructure.meta import meta, specification + +from ._objectdata_base import ( + DerivedObjectDescriptor, + ObjectDataProvider, +) + +if TYPE_CHECKING: + import pandas as pd + +logger: Final = null_logger(__name__) + + +@dataclass +class RegularSurfaceDataProvider(ObjectDataProvider): + obj: xtgeo.RegularSurface + + def get_spec(self) -> dict[str, Any]: + """Derive data.spec for xtgeo.RegularSurface.""" + logger.info("Get spec for RegularSurface") + + required = self.obj.metadata.required + return specification.SurfaceSpecification( + ncol=npfloat_to_float(required["ncol"]), + nrow=npfloat_to_float(required["nrow"]), + xori=npfloat_to_float(required["xori"]), + yori=npfloat_to_float(required["yori"]), + xinc=npfloat_to_float(required["xinc"]), + yinc=npfloat_to_float(required["yinc"]), + yflip=npfloat_to_float(required["yflip"]), + rotation=npfloat_to_float(required["rotation"]), + undef=1.0e30, + ).model_dump( + mode="json", + exclude_none=True, + ) + + def get_bbox(self) -> dict[str, Any]: + """Derive 
data.bbox for xtgeo.RegularSurface.""" + logger.info("Get bbox for RegularSurface") + + return meta.content.BoundingBox3D( + xmin=float(self.obj.xmin), + xmax=float(self.obj.xmax), + ymin=float(self.obj.ymin), + ymax=float(self.obj.ymax), + zmin=float(self.obj.values.min()), + zmax=float(self.obj.values.max()), + ).model_dump( + mode="json", + exclude_none=True, + ) + + def get_objectdata(self) -> DerivedObjectDescriptor: + """Derive object data for xtgeo.RegularSurface.""" + return DerivedObjectDescriptor( + subtype="RegularSurface", + classname="surface", + layout="regular", + efolder="maps", + fmt=(fmt := self.dataio.surface_fformat), + spec=self.get_spec(), + bbox=self.get_bbox(), + extension=self._validate_get_ext( + fmt, "RegularSurface", ValidFormats().surface + ), + table_index=None, + ) + + +@dataclass +class PolygonsDataProvider(ObjectDataProvider): + obj: xtgeo.Polygons + + def get_spec(self) -> dict[str, Any]: + """Derive data.spec for xtgeo.Polygons.""" + logger.info("Get spec for Polygons") + + return specification.PolygonsSpecification( + npolys=np.unique( + self.obj.get_dataframe(copy=False)[self.obj.pname].values + ).size + ).model_dump( + mode="json", + exclude_none=True, + ) + + def get_bbox(self) -> dict[str, Any]: + """Derive data.bbox for xtgeo.Polygons""" + logger.info("Get bbox for Polygons") + + xmin, xmax, ymin, ymax, zmin, zmax = self.obj.get_boundary() + return meta.content.BoundingBox3D( + xmin=float(xmin), + xmax=float(xmax), + ymin=float(ymin), + ymax=float(ymax), + zmin=float(zmin), + zmax=float(zmax), + ).model_dump( + mode="json", + exclude_none=True, + ) + + def get_objectdata(self) -> DerivedObjectDescriptor: + """Derive object data for xtgeo.Polygons.""" + return DerivedObjectDescriptor( + subtype="Polygons", + classname="polygons", + layout="unset", + efolder="polygons", + fmt=(fmt := self.dataio.polygons_fformat), + extension=self._validate_get_ext(fmt, "Polygons", ValidFormats().polygons), + spec=self.get_spec(), + 
+            bbox=self.get_bbox(),
+            table_index=None,
+        )
+
+
+@dataclass
+class PointsDataProvider(ObjectDataProvider):
+    obj: xtgeo.Points
+
+    @property
+    def obj_dataframe(self) -> pd.DataFrame:
+        """Returns a dataframe of the referenced xtgeo.Points object."""
+        return self.obj.get_dataframe(copy=False)
+
+    def get_spec(self) -> dict[str, Any]:
+        """Derive data.spec for xtgeo.Points."""
+        logger.info("Get spec for Points")
+
+        df = self.obj_dataframe
+        return specification.PointSpecification(
+            attributes=list(df.columns[3:]) if len(df.columns) > 3 else None,
+            size=int(df.size),
+        ).model_dump(
+            mode="json",
+            exclude_none=True,
+        )
+
+    def get_bbox(self) -> dict[str, Any]:
+        """Derive data.bbox for xtgeo.Points."""
+        logger.info("Get bbox for Points")
+
+        df = self.obj_dataframe
+        return meta.content.BoundingBox3D(
+            xmin=float(df[self.obj.xname].min()),
+            xmax=float(df[self.obj.xname].max()),
+            ymin=float(df[self.obj.yname].min()),
+            ymax=float(df[self.obj.yname].max()),
+            zmin=float(df[self.obj.zname].min()),
+            zmax=float(df[self.obj.zname].max()),
+        ).model_dump(
+            mode="json",
+            exclude_none=True,
+        )
+
+    def get_objectdata(self) -> DerivedObjectDescriptor:
+        """Derive object data for xtgeo.Points."""
+        return DerivedObjectDescriptor(
+            subtype="Points",
+            classname="points",
+            layout="unset",
+            efolder="points",
+            fmt=(fmt := self.dataio.points_fformat),
+            extension=self._validate_get_ext(fmt, "Points", ValidFormats().points),
+            spec=self.get_spec(),
+            bbox=self.get_bbox(),
+            table_index=None,
+        )
+
+
+@dataclass
+class CubeDataProvider(ObjectDataProvider):
+    obj: xtgeo.Cube
+
+    def get_spec(self) -> dict[str, Any]:
+        """Derive data.spec for xtgeo.Cube."""
+        logger.info("Get spec for Cube")
+
+        required = self.obj.metadata.required
+        return specification.CubeSpecification(
+            ncol=npfloat_to_float(required["ncol"]),
+            nrow=npfloat_to_float(required["nrow"]),
+            nlay=npfloat_to_float(required["nlay"]),
+            xori=npfloat_to_float(required["xori"]),
yori=npfloat_to_float(required["yori"]), + zori=npfloat_to_float(required["zori"]), + xinc=npfloat_to_float(required["xinc"]), + yinc=npfloat_to_float(required["yinc"]), + zinc=npfloat_to_float(required["zinc"]), + yflip=npfloat_to_float(required["yflip"]), + zflip=npfloat_to_float(required["zflip"]), + rotation=npfloat_to_float(required["rotation"]), + undef=npfloat_to_float(required["undef"]), + ).model_dump( + mode="json", + exclude_none=True, + ) + + def get_bbox(self) -> dict[str, Any]: + """Derive data.bbox for xtgeo.Cube.""" + logger.info("Get bbox for Cube") + + # current xtgeo is missing xmin, xmax etc attributes for cube, so need + # to compute (simplify when xtgeo has this): + xmin, ymin = 1.0e23, 1.0e23 + xmax, ymax = -xmin, -ymin + + for corner in ( + (1, 1), + (1, self.obj.nrow), + (self.obj.ncol, 1), + (self.obj.ncol, self.obj.nrow), + ): + xco, yco = self.obj.get_xy_value_from_ij(*corner) + xmin = min(xmin, xco) + xmax = max(xmax, xco) + ymin = min(ymin, yco) + ymax = max(ymax, yco) + + return meta.content.BoundingBox3D( + xmin=float(xmin), + xmax=float(xmax), + ymin=float(ymin), + ymax=float(ymax), + zmin=float(self.obj.zori), + zmax=float(self.obj.zori + self.obj.zinc * (self.obj.nlay - 1)), + ).model_dump( + mode="json", + exclude_none=True, + ) + + def get_objectdata(self) -> DerivedObjectDescriptor: + """Derive object data for xtgeo.Cube.""" + return DerivedObjectDescriptor( + subtype="RegularCube", + classname="cube", + layout="regular", + efolder="cubes", + fmt=(fmt := self.dataio.cube_fformat), + extension=self._validate_get_ext(fmt, "RegularCube", ValidFormats().cube), + spec=self.get_spec(), + bbox=self.get_bbox(), + table_index=None, + ) + + +@dataclass +class CPGridDataProvider(ObjectDataProvider): + obj: xtgeo.Grid + + def get_spec(self) -> dict[str, Any]: + """Derive data.spec for xtgeo.Grid.""" + logger.info("Get spec for Grid geometry") + + required = self.obj.metadata.required + return specification.CPGridSpecification( + 
ncol=npfloat_to_float(required["ncol"]), + nrow=npfloat_to_float(required["nrow"]), + nlay=npfloat_to_float(required["nlay"]), + xshift=npfloat_to_float(required["xshift"]), + yshift=npfloat_to_float(required["yshift"]), + zshift=npfloat_to_float(required["zshift"]), + xscale=npfloat_to_float(required["xscale"]), + yscale=npfloat_to_float(required["yscale"]), + zscale=npfloat_to_float(required["zscale"]), + ).model_dump( + mode="json", + exclude_none=True, + ) + + def get_bbox(self) -> dict[str, Any]: + """Derive data.bbox for xtgeo.Grid.""" + logger.info("Get bbox for Grid geometry") + + geox = self.obj.get_geometrics( + cellcenter=False, + allcells=True, + return_dict=True, + ) + return meta.content.BoundingBox3D( + xmin=round(float(geox["xmin"]), 4), + xmax=round(float(geox["xmax"]), 4), + ymin=round(float(geox["ymin"]), 4), + ymax=round(float(geox["ymax"]), 4), + zmin=round(float(geox["zmin"]), 4), + zmax=round(float(geox["zmax"]), 4), + ).model_dump( + mode="json", + exclude_none=True, + ) + + def get_objectdata(self) -> DerivedObjectDescriptor: + """Derive object data for xtgeo.Grid.""" + return DerivedObjectDescriptor( + subtype="CPGrid", + classname="cpgrid", + layout="cornerpoint", + efolder="grids", + fmt=(fmt := self.dataio.grid_fformat), + extension=self._validate_get_ext(fmt, "CPGrid", ValidFormats().grid), + spec=self.get_spec(), + bbox=self.get_bbox(), + table_index=None, + ) + + +@dataclass +class CPGridPropertyDataProvider(ObjectDataProvider): + obj: xtgeo.GridProperty + + def get_spec(self) -> dict[str, Any]: + """Derive data.spec for xtgeo.GridProperty.""" + logger.info("Get spec for GridProperty") + + return specification.CPGridPropertySpecification( + nrow=self.obj.nrow, + ncol=self.obj.ncol, + nlay=self.obj.nlay, + ).model_dump( + mode="json", + exclude_none=True, + ) + + def get_bbox(self) -> dict[str, Any]: + """Derive data.bbox for xtgeo.GridProperty.""" + logger.info("Get bbox for GridProperty") + return {} + + def get_objectdata(self) -> 
DerivedObjectDescriptor: + """Derive object data for xtgeo.GridProperty.""" + return DerivedObjectDescriptor( + subtype="CPGridProperty", + classname="cpgrid_property", + layout="cornerpoint", + efolder="grids", + fmt=(fmt := self.dataio.grid_fformat), + extension=self._validate_get_ext( + fmt, "CPGridProperty", ValidFormats().grid + ), + spec=self.get_spec(), + bbox=self.get_bbox(), + table_index=None, + ) diff --git a/tests/test_units/test_objectdataprovider_class.py b/tests/test_units/test_objectdataprovider_class.py index d6fbd00e5..c7fb71890 100644 --- a/tests/test_units/test_objectdataprovider_class.py +++ b/tests/test_units/test_objectdataprovider_class.py @@ -55,9 +55,9 @@ def test_objectdata_regularsurface_validate_extension_shall_fail(regsurf, edatao def test_objectdata_regularsurface_spec_bbox(regsurf, edataobj1): """Derive specs and bbox for RegularSurface object.""" - specs, bbox = objectdata_provider_factory( - regsurf, edataobj1 - )._derive_spec_and_bbox() + objdata = objectdata_provider_factory(regsurf, edataobj1) + specs = objdata.get_spec() + bbox = objdata.get_bbox() assert specs["ncol"] == regsurf.ncol assert bbox["xmin"] == 0.0 @@ -67,7 +67,7 @@ def test_objectdata_regularsurface_spec_bbox(regsurf, edataobj1): def test_objectdata_regularsurface_derive_objectdata(regsurf, edataobj1): """Derive other properties.""" - res = objectdata_provider_factory(regsurf, edataobj1)._derive_objectdata() + res = objectdata_provider_factory(regsurf, edataobj1).get_objectdata() assert res.subtype == "RegularSurface" assert res.classname == "surface" diff --git a/tests/test_units/test_table.py b/tests/test_units/test_table.py index 6dba2514d..1c646c739 100644 --- a/tests/test_units/test_table.py +++ b/tests/test_units/test_table.py @@ -138,7 +138,7 @@ def test_table_index_real_summary(edataobj3, drogon_summary): drogon_summary (pd.Dataframe): dataframe with summary data from sumo """ objdata = objectdata_provider_factory(drogon_summary, edataobj3) - res = 
objdata._derive_objectdata() + res = objdata.get_objectdata() assert res.table_index == ["DATE"], "Incorrect table index "