From dbeaad0fc9c1ebd3b13e02b82e2cbadcdcf90ad9 Mon Sep 17 00:00:00 2001 From: mferrera Date: Tue, 11 Jun 2024 09:41:32 +0200 Subject: [PATCH] MAINT: Make format and extension provider properties --- src/fmu/dataio/_definitions.py | 84 +++---- src/fmu/dataio/aggregation.py | 5 +- src/fmu/dataio/providers/objectdata/_base.py | 79 +++---- .../dataio/providers/objectdata/_faultroom.py | 11 +- .../dataio/providers/objectdata/_provider.py | 15 +- .../dataio/providers/objectdata/_tables.py | 34 ++- src/fmu/dataio/providers/objectdata/_xtgeo.py | 216 ++++++++++-------- src/fmu/dataio/types.py | 12 - .../test_units/test_filedataprovider_class.py | 5 +- .../test_objectdataprovider_class.py | 9 +- 10 files changed, 245 insertions(+), 225 deletions(-) diff --git a/src/fmu/dataio/_definitions.py b/src/fmu/dataio/_definitions.py index 9063f527f..f3bf35eab 100644 --- a/src/fmu/dataio/_definitions.py +++ b/src/fmu/dataio/_definitions.py @@ -2,7 +2,6 @@ from __future__ import annotations -from dataclasses import dataclass, field from enum import Enum, unique from typing import Final, Type @@ -21,52 +20,43 @@ class ConfigurationError(ValueError): pass -@dataclass -class ValidFormats: - surface: dict[str, str] = field( - default_factory=lambda: { - "irap_binary": ".gri", - } - ) - grid: dict[str, str] = field( - default_factory=lambda: { - "hdf": ".hdf", - "roff": ".roff", - } - ) - cube: dict[str, str] = field( - default_factory=lambda: { - "segy": ".segy", - } - ) - table: dict[str, str] = field( - default_factory=lambda: { - "hdf": ".hdf", - "csv": ".csv", - "parquet": ".parquet", - } - ) - polygons: dict[str, str] = field( - default_factory=lambda: { - "hdf": ".hdf", - "csv": ".csv", # columns will be X Y Z, ID - "csv|xtgeo": ".csv", # use default xtgeo columns: X_UTME, ... POLY_ID - "irap_ascii": ".pol", - } - ) - points: dict[str, str] = field( - default_factory=lambda: { - "hdf": ".hdf", - "csv": ".csv", # columns will be X Y Z - "csv|xtgeo": ".csv", # use default xtgeo columns: X_UTME, Y_UTMN, Z_TVDSS - "irap_ascii": ".poi", - } - ) - dictionary: dict[str, str] = field( - default_factory=lambda: { - "json": ".json", - } - ) +class ValidFormats(Enum): + surface = { + "irap_binary": ".gri", + } + + grid = { + "hdf": ".hdf", + "roff": ".roff", + } + + cube = { + "segy": ".segy", + } + + table = { + "hdf": ".hdf", + "csv": ".csv", + "parquet": ".parquet", + } + + polygons = { + "hdf": ".hdf", + "csv": ".csv", # columns will be X Y Z, ID + "csv|xtgeo": ".csv", # use default xtgeo columns: X_UTME, ... POLY_ID + "irap_ascii": ".pol", + } + + points = { + "hdf": ".hdf", + "csv": ".csv", # columns will be X Y Z + "csv|xtgeo": ".csv", # use default xtgeo columns: X_UTME, Y_UTMN, Z_TVDSS + "irap_ascii": ".poi", + } + + dictionary = { + "json": ".json", + } STANDARD_TABLE_INDEX_COLUMNS: Final[dict[str, list[str]]] = { diff --git a/src/fmu/dataio/aggregation.py b/src/fmu/dataio/aggregation.py index 73cc7999e..491bd9f0b 100644 --- a/src/fmu/dataio/aggregation.py +++ b/src/fmu/dataio/aggregation.py @@ -244,8 +244,7 @@ def _generate_aggrd_metadata( } etemp = dataio.ExportData(config=config, name=self.name) - objectdata_provider = objectdata_provider_factory(obj=obj, dataio=etemp) - objdata = objectdata_provider.get_objectdata() + objdata = objectdata_provider_factory(obj=obj, dataio=etemp) template["tracklog"] = [generate_meta_tracklog()[0].model_dump(mode="json")] template["file"] = { @@ -262,7 +261,7 @@ def _generate_aggrd_metadata( template["data"]["name"] = self.name if self.tagname: template["data"]["tagname"] = self.tagname - if bbox := objectdata_provider.get_bbox(): + if bbox := objdata.get_bbox(): template["data"]["bbox"] = bbox.model_dump(mode="json", exclude_none=True) self._metadata = template diff --git a/src/fmu/dataio/providers/objectdata/_base.py b/src/fmu/dataio/providers/objectdata/_base.py index 924c26966..aa6dbedd8 100644 --- a/src/fmu/dataio/providers/objectdata/_base.py +++ b/src/fmu/dataio/providers/objectdata/_base.py @@ -4,10 +4,10 @@ from copy import deepcopy from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Final, TypeVar +from typing import TYPE_CHECKING, Final from warnings import warn -from fmu.dataio._definitions import ConfigurationError +from fmu.dataio._definitions import ConfigurationError, ValidFormats from fmu.dataio._logging import null_logger from fmu.dataio._utils import generate_description from fmu.dataio.datastructure._internal.internal import AllowedContent, UnsetAnyContent @@ -24,20 +24,15 @@ from fmu.dataio.datastructure.meta.content import BoundingBox2D, BoundingBox3D from fmu.dataio.datastructure.meta.enums import FMUClassEnum from fmu.dataio.datastructure.meta.specification import AnySpecification - from fmu.dataio.types import Efolder, Inferrable, Layout, Subtype + from fmu.dataio.types import Efolder, Inferrable, Layout logger: Final = null_logger(__name__) -V = TypeVar("V") - @dataclass class DerivedObjectDescriptor: - subtype: Subtype layout: Layout efolder: Efolder | str - fmt: str - extension: str table_index: list[str] | None @@ -76,8 +71,6 @@ class ObjectDataProvider(Provider): _metadata: dict = field(default_factory=dict) name: str = field(default="") efolder: str = field(default="") - extension: str = field(default="") - fmt: str = field(default="") time0: datetime | None = field(default=None) time1: datetime | None = field(default=None) @@ -87,8 +80,6 @@ def __post_init__(self) -> None: obj_data = self.get_objectdata() self.name = named_stratigraphy.name - self.extension = obj_data.extension - self.fmt = obj_data.fmt self.efolder = obj_data.efolder if self.dataio.forcefolder: @@ -149,6 +140,40 @@ def __post_init__(self) -> None: self._metadata["description"] = generate_description(self.dataio.description) logger.info("Derive all metadata for data object... DONE") + @property + @abstractmethod + def classname(self) -> FMUClassEnum: + raise NotImplementedError + + @property + @abstractmethod + def extension(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def fmt(self) -> str: + raise NotImplementedError + + @abstractmethod + def get_bbox(self) -> BoundingBox2D | BoundingBox3D | None: + raise NotImplementedError + + @abstractmethod + def get_spec(self) -> AnySpecification | None: + raise NotImplementedError + + @abstractmethod + def get_objectdata(self) -> DerivedObjectDescriptor: + raise NotImplementedError + + def get_metadata(self) -> AnyContent | UnsetAnyContent: + return ( + UnsetAnyContent.model_validate(self._metadata) + if self._metadata["content"] == "unset" + else AnyContent.model_validate(self._metadata) + ) + def _get_validated_content(self, content: str | dict | None) -> AllowedContent: """Check content and return a validated model.""" logger.info("Evaluate content") @@ -254,37 +279,13 @@ def _get_timedata(self) -> Time | None: return Time(t0=start, t1=stop) - @property - @abstractmethod - def classname(self) -> FMUClassEnum: - raise NotImplementedError - - @abstractmethod - def get_spec(self) -> AnySpecification | None: - raise NotImplementedError - - @abstractmethod - def get_bbox(self) -> BoundingBox2D | BoundingBox3D | None: - raise NotImplementedError - - @abstractmethod - def get_objectdata(self) -> DerivedObjectDescriptor: - raise NotImplementedError - - def get_metadata(self) -> AnyContent | UnsetAnyContent: - return ( - UnsetAnyContent.model_validate(self._metadata) - if self._metadata["content"] == "unset" - else AnyContent.model_validate(self._metadata) - ) - @staticmethod - def _validate_get_ext(fmt: str, subtype: str, validator: dict[str, V]) -> V: + def _validate_get_ext(fmt: str, validator: ValidFormats) -> str: """Validate that fmt (file format) matches data and return legal extension.""" try: - return validator[fmt] + return validator.value[fmt] except KeyError: raise ConfigurationError( f"The file format {fmt} is not supported. ", - f"Valid {subtype} formats are: {list(validator.keys())}", + f"Valid formats are: {list(validator.value.keys())}", ) diff --git a/src/fmu/dataio/providers/objectdata/_faultroom.py b/src/fmu/dataio/providers/objectdata/_faultroom.py index 3f928126d..a55b1732d 100644 --- a/src/fmu/dataio/providers/objectdata/_faultroom.py +++ b/src/fmu/dataio/providers/objectdata/_faultroom.py @@ -26,6 +26,14 @@ class FaultRoomSurfaceProvider(ObjectDataProvider): def classname(self) -> FMUClassEnum: return FMUClassEnum.surface + @property + def extension(self) -> str: + return self._validate_get_ext(self.fmt, ValidFormats.dictionary) + + @property + def fmt(self) -> str: + return self.dataio.dict_fformat + def get_bbox(self) -> BoundingBox3D: """Derive data.bbox for FaultRoomSurface.""" logger.info("Get bbox for FaultRoomSurface") @@ -53,10 +61,7 @@ def get_spec(self) -> FaultRoomSurfaceSpecification: def get_objectdata(self) -> DerivedObjectDescriptor: """Derive object data for FaultRoomSurface""" return DerivedObjectDescriptor( - subtype="JSON", layout="faultroom_triangulated", efolder="maps", - fmt=(fmt := self.dataio.dict_fformat), - extension=self._validate_get_ext(fmt, "JSON", ValidFormats().dictionary), table_index=None, ) diff --git a/src/fmu/dataio/providers/objectdata/_provider.py b/src/fmu/dataio/providers/objectdata/_provider.py index 049281025..ac0b8b292 100644 --- a/src/fmu/dataio/providers/objectdata/_provider.py +++ b/src/fmu/dataio/providers/objectdata/_provider.py @@ -168,19 +168,24 @@ class DictionaryDataProvider(ObjectDataProvider): def classname(self) -> FMUClassEnum: return FMUClassEnum.dictionary - def get_spec(self) -> None: - """Derive data.spec for dict.""" + @property + def extension(self) -> str: + return self._validate_get_ext(self.fmt, ValidFormats.dictionary) + + @property + def fmt(self) -> str: + return self.dataio.dict_fformat def get_bbox(self) -> None: """Derive data.bbox for dict.""" + def get_spec(self) -> None: + """Derive data.spec for dict.""" + def get_objectdata(self) -> DerivedObjectDescriptor: """Derive object data for dict.""" return DerivedObjectDescriptor( - subtype="JSON", layout="dictionary", efolder="dictionaries", - fmt=(fmt := self.dataio.dict_fformat), - extension=self._validate_get_ext(fmt, "JSON", ValidFormats().dictionary), table_index=None, ) diff --git a/src/fmu/dataio/providers/objectdata/_tables.py b/src/fmu/dataio/providers/objectdata/_tables.py index 70205d93e..5b6a30940 100644 --- a/src/fmu/dataio/providers/objectdata/_tables.py +++ b/src/fmu/dataio/providers/objectdata/_tables.py @@ -64,6 +64,17 @@ class DataFrameDataProvider(ObjectDataProvider): def classname(self) -> FMUClassEnum: return FMUClassEnum.table + @property + def extension(self) -> str: + return self._validate_get_ext(self.fmt, ValidFormats.table) + + @property + def fmt(self) -> str: + return self.dataio.table_fformat + + def get_bbox(self) -> None: + """Derive data.bbox for pd.DataFrame.""" + def get_spec(self) -> TableSpecification: """Derive data.spec for pd.DataFrame.""" logger.info("Get spec for pd.DataFrame (tables)") @@ -72,18 +83,12 @@ def get_spec(self) -> TableSpecification: size=int(self.obj.size), ) - def get_bbox(self) -> None: - """Derive data.bbox for pd.DataFrame.""" - def get_objectdata(self) -> DerivedObjectDescriptor: """Derive object data for pd.DataFrame.""" table_index = _derive_index(self.dataio.table_index, list(self.obj.columns)) return DerivedObjectDescriptor( - subtype="DataFrame", layout="table", efolder="tables", - fmt=(fmt := self.dataio.table_fformat), - extension=self._validate_get_ext(fmt, "DataFrame", ValidFormats().table), table_index=table_index, ) @@ -96,6 +101,17 @@ class ArrowTableDataProvider(ObjectDataProvider): def classname(self) -> FMUClassEnum: return FMUClassEnum.table + @property + def extension(self) -> str: + return self._validate_get_ext(self.fmt, ValidFormats.table) + + @property + def fmt(self) -> str: + return self.dataio.arrow_fformat + + def get_bbox(self) -> None: + """Derive data.bbox for pyarrow.Table.""" + def get_spec(self) -> TableSpecification: """Derive data.spec for pyarrow.Table.""" logger.info("Get spec for pyarrow (tables)") @@ -104,17 +120,11 @@ def get_spec(self) -> TableSpecification: size=self.obj.num_columns * self.obj.num_rows, ) - def get_bbox(self) -> None: - """Derive data.bbox for pyarrow.Table.""" - def get_objectdata(self) -> DerivedObjectDescriptor: """Derive object data from pyarrow.Table.""" table_index = _derive_index(self.dataio.table_index, self.obj.column_names) return DerivedObjectDescriptor( - subtype="ArrowTable", layout="table", efolder="tables", - fmt=(fmt := self.dataio.arrow_fformat), - extension=self._validate_get_ext(fmt, "ArrowTable", ValidFormats().table), table_index=table_index, ) diff --git a/src/fmu/dataio/providers/objectdata/_xtgeo.py b/src/fmu/dataio/providers/objectdata/_xtgeo.py index 20cfdecb5..34afba29d 100644 --- a/src/fmu/dataio/providers/objectdata/_xtgeo.py +++ b/src/fmu/dataio/providers/objectdata/_xtgeo.py @@ -40,22 +40,13 @@ class RegularSurfaceDataProvider(ObjectDataProvider): def classname(self) -> FMUClassEnum: return FMUClassEnum.surface - def get_spec(self) -> SurfaceSpecification: - """Derive data.spec for xtgeo.RegularSurface.""" - logger.info("Get spec for RegularSurface") + @property + def extension(self) -> str: + return self._validate_get_ext(self.fmt, ValidFormats.surface) - required = self.obj.metadata.required - return SurfaceSpecification( - ncol=npfloat_to_float(required["ncol"]), - nrow=npfloat_to_float(required["nrow"]), - xori=npfloat_to_float(required["xori"]), - yori=npfloat_to_float(required["yori"]), - xinc=npfloat_to_float(required["xinc"]), - yinc=npfloat_to_float(required["yinc"]), - yflip=npfloat_to_float(required["yflip"]), - rotation=npfloat_to_float(required["rotation"]), - undef=1.0e30, - ) + @property + def fmt(self) -> str: + return self.dataio.surface_fformat def get_bbox(self) -> BoundingBox2D | BoundingBox3D: """ @@ -82,16 +73,28 @@ def get_bbox(self) -> BoundingBox2D | BoundingBox3D: ymax=float(self.obj.ymax), ) + def get_spec(self) -> SurfaceSpecification: + """Derive data.spec for xtgeo.RegularSurface.""" + logger.info("Get spec for RegularSurface") + + required = self.obj.metadata.required + return SurfaceSpecification( + ncol=npfloat_to_float(required["ncol"]), + nrow=npfloat_to_float(required["nrow"]), + xori=npfloat_to_float(required["xori"]), + yori=npfloat_to_float(required["yori"]), + xinc=npfloat_to_float(required["xinc"]), + yinc=npfloat_to_float(required["yinc"]), + yflip=npfloat_to_float(required["yflip"]), + rotation=npfloat_to_float(required["rotation"]), + undef=1.0e30, + ) + def get_objectdata(self) -> DerivedObjectDescriptor: """Derive object data for xtgeo.RegularSurface.""" return DerivedObjectDescriptor( - subtype="RegularSurface", layout="regular", efolder="maps", - fmt=(fmt := self.dataio.surface_fformat), - extension=self._validate_get_ext( - fmt, "RegularSurface", ValidFormats().surface - ), table_index=None, ) @@ -104,15 +107,13 @@ class PolygonsDataProvider(ObjectDataProvider): def classname(self) -> FMUClassEnum: return FMUClassEnum.polygons - def get_spec(self) -> PolygonsSpecification: - """Derive data.spec for xtgeo.Polygons.""" - logger.info("Get spec for Polygons") + @property + def extension(self) -> str: + return self._validate_get_ext(self.fmt, ValidFormats.polygons) - return PolygonsSpecification( - npolys=np.unique( - self.obj.get_dataframe(copy=False)[self.obj.pname].values - ).size - ) + @property + def fmt(self) -> str: + return self.dataio.polygons_fformat def get_bbox(self) -> BoundingBox3D: """Derive data.bbox for xtgeo.Polygons""" @@ -128,14 +129,21 @@ def get_bbox(self) -> BoundingBox3D: zmax=float(zmax), ) + def get_spec(self) -> PolygonsSpecification: + """Derive data.spec for xtgeo.Polygons.""" + logger.info("Get spec for Polygons") + + return PolygonsSpecification( + npolys=np.unique( + self.obj.get_dataframe(copy=False)[self.obj.pname].values + ).size + ) + def get_objectdata(self) -> DerivedObjectDescriptor: """Derive object data for xtgeo.Polygons.""" return DerivedObjectDescriptor( - subtype="Polygons", layout="unset", efolder="polygons", - fmt=(fmt := self.dataio.polygons_fformat), - extension=self._validate_get_ext(fmt, "Polygons", ValidFormats().polygons), table_index=None, ) @@ -144,24 +152,22 @@ def get_objectdata(self) -> DerivedObjectDescriptor: class PointsDataProvider(ObjectDataProvider): obj: xtgeo.Points - @property - def obj_dataframe(self) -> pd.DataFrame: - """Returns a dataframe of the referenced xtgeo.Points object.""" - return self.obj.get_dataframe(copy=False) - @property def classname(self) -> FMUClassEnum: return FMUClassEnum.points - def get_spec(self) -> PointSpecification: - """Derive data.spec for xtgeo.Points.""" - logger.info("Get spec for Points") + @property + def extension(self) -> str: + return self._validate_get_ext(self.fmt, ValidFormats.points) - df = self.obj_dataframe - return PointSpecification( - attributes=list(df.columns[3:]) if len(df.columns) > 3 else None, - size=int(df.size), - ) + @property + def fmt(self) -> str: + return self.dataio.points_fformat + + @property + def obj_dataframe(self) -> pd.DataFrame: + """Returns a dataframe of the referenced xtgeo.Points object.""" + return self.obj.get_dataframe(copy=False) def get_bbox(self) -> BoundingBox3D: """Derive data.bbox for xtgeo.Points.""" @@ -177,14 +183,21 @@ def get_bbox(self) -> BoundingBox3D: zmax=float(df[self.obj.zname].max()), ) + def get_spec(self) -> PointSpecification: + """Derive data.spec for xtgeo.Points.""" + logger.info("Get spec for Points") + + df = self.obj_dataframe + return PointSpecification( + attributes=list(df.columns[3:]) if len(df.columns) > 3 else None, + size=int(df.size), + ) + def get_objectdata(self) -> DerivedObjectDescriptor: """Derive object data for xtgeo.Points.""" return DerivedObjectDescriptor( - subtype="Points", layout="unset", efolder="points", - fmt=(fmt := self.dataio.points_fformat), - extension=self._validate_get_ext(fmt, "Points", ValidFormats().points), table_index=None, ) @@ -197,26 +210,13 @@ class CubeDataProvider(ObjectDataProvider): def classname(self) -> FMUClassEnum: return FMUClassEnum.cube - def get_spec(self) -> CubeSpecification: - """Derive data.spec for xtgeo.Cube.""" - logger.info("Get spec for Cube") + @property + def extension(self) -> str: + return self._validate_get_ext(self.fmt, ValidFormats.cube) - required = self.obj.metadata.required - return CubeSpecification( - ncol=npfloat_to_float(required["ncol"]), - nrow=npfloat_to_float(required["nrow"]), - nlay=npfloat_to_float(required["nlay"]), - xori=npfloat_to_float(required["xori"]), - yori=npfloat_to_float(required["yori"]), - zori=npfloat_to_float(required["zori"]), - xinc=npfloat_to_float(required["xinc"]), - yinc=npfloat_to_float(required["yinc"]), - zinc=npfloat_to_float(required["zinc"]), - yflip=npfloat_to_float(required["yflip"]), - zflip=npfloat_to_float(required["zflip"]), - rotation=npfloat_to_float(required["rotation"]), - undef=npfloat_to_float(required["undef"]), - ) + @property + def fmt(self) -> str: + return self.dataio.cube_fformat def get_bbox(self) -> BoundingBox3D: """Derive data.bbox for xtgeo.Cube.""" @@ -248,14 +248,32 @@ def get_bbox(self) -> BoundingBox3D: zmax=float(self.obj.zori + self.obj.zinc * (self.obj.nlay - 1)), ) + def get_spec(self) -> CubeSpecification: + """Derive data.spec for xtgeo.Cube.""" + logger.info("Get spec for Cube") + + required = self.obj.metadata.required + return CubeSpecification( + ncol=npfloat_to_float(required["ncol"]), + nrow=npfloat_to_float(required["nrow"]), + nlay=npfloat_to_float(required["nlay"]), + xori=npfloat_to_float(required["xori"]), + yori=npfloat_to_float(required["yori"]), + zori=npfloat_to_float(required["zori"]), + xinc=npfloat_to_float(required["xinc"]), + yinc=npfloat_to_float(required["yinc"]), + zinc=npfloat_to_float(required["zinc"]), + yflip=npfloat_to_float(required["yflip"]), + zflip=npfloat_to_float(required["zflip"]), + rotation=npfloat_to_float(required["rotation"]), + undef=npfloat_to_float(required["undef"]), + ) + def get_objectdata(self) -> DerivedObjectDescriptor: """Derive object data for xtgeo.Cube.""" return DerivedObjectDescriptor( - subtype="RegularCube", layout="regular", efolder="cubes", - fmt=(fmt := self.dataio.cube_fformat), - extension=self._validate_get_ext(fmt, "RegularCube", ValidFormats().cube), table_index=None, ) @@ -268,22 +286,13 @@ class CPGridDataProvider(ObjectDataProvider): def classname(self) -> FMUClassEnum: return FMUClassEnum.cpgrid - def get_spec(self) -> CPGridSpecification: - """Derive data.spec for xtgeo.Grid.""" - logger.info("Get spec for Grid geometry") + @property + def extension(self) -> str: + return self._validate_get_ext(self.fmt, ValidFormats.grid) - required = self.obj.metadata.required - return CPGridSpecification( - ncol=npfloat_to_float(required["ncol"]), - nrow=npfloat_to_float(required["nrow"]), - nlay=npfloat_to_float(required["nlay"]), - xshift=npfloat_to_float(required["xshift"]), - yshift=npfloat_to_float(required["yshift"]), - zshift=npfloat_to_float(required["zshift"]), - xscale=npfloat_to_float(required["xscale"]), - yscale=npfloat_to_float(required["yscale"]), - zscale=npfloat_to_float(required["zscale"]), - ) + @property + def fmt(self) -> str: + return self.dataio.grid_fformat def get_bbox(self) -> BoundingBox3D: """Derive data.bbox for xtgeo.Grid.""" @@ -303,14 +312,28 @@ def get_bbox(self) -> BoundingBox3D: zmax=round(float(geox["zmax"]), 4), ) + def get_spec(self) -> CPGridSpecification: + """Derive data.spec for xtgeo.Grid.""" + logger.info("Get spec for Grid geometry") + + required = self.obj.metadata.required + return CPGridSpecification( + ncol=npfloat_to_float(required["ncol"]), + nrow=npfloat_to_float(required["nrow"]), + nlay=npfloat_to_float(required["nlay"]), + xshift=npfloat_to_float(required["xshift"]), + yshift=npfloat_to_float(required["yshift"]), + zshift=npfloat_to_float(required["zshift"]), + xscale=npfloat_to_float(required["xscale"]), + yscale=npfloat_to_float(required["yscale"]), + zscale=npfloat_to_float(required["zscale"]), + ) + def get_objectdata(self) -> DerivedObjectDescriptor: """Derive object data for xtgeo.Grid.""" return DerivedObjectDescriptor( - subtype="CPGrid", layout="cornerpoint", efolder="grids", - fmt=(fmt := self.dataio.grid_fformat), - extension=self._validate_get_ext(fmt, "CPGrid", ValidFormats().grid), table_index=None, ) @@ -323,6 +346,17 @@ class CPGridPropertyDataProvider(ObjectDataProvider): def classname(self) -> FMUClassEnum: return FMUClassEnum.cpgrid_property + @property + def extension(self) -> str: + return self._validate_get_ext(self.fmt, ValidFormats.grid) + + @property + def fmt(self) -> str: + return self.dataio.grid_fformat + + def get_bbox(self) -> None: + """Derive data.bbox for xtgeo.GridProperty.""" + def get_spec(self) -> CPGridPropertySpecification: """Derive data.spec for xtgeo.GridProperty.""" logger.info("Get spec for GridProperty") @@ -333,18 +367,10 @@ def get_spec(self) -> CPGridPropertySpecification: nlay=self.obj.nlay, ) - def get_bbox(self) -> None: - """Derive data.bbox for xtgeo.GridProperty.""" - def get_objectdata(self) -> DerivedObjectDescriptor: """Derive object data for xtgeo.GridProperty.""" return DerivedObjectDescriptor( - subtype="CPGridProperty", layout="cornerpoint", efolder="grids", - fmt=(fmt := self.dataio.grid_fformat), - extension=self._validate_get_ext( - fmt, "CPGridProperty", ValidFormats().grid - ), table_index=None, ) diff --git a/src/fmu/dataio/types.py b/src/fmu/dataio/types.py index ffed03189..34997b96c 100644 --- a/src/fmu/dataio/types.py +++ b/src/fmu/dataio/types.py @@ -65,18 +65,6 @@ class PolygonsProxy(Polygons): ... "dictionaries", ] -Subtype: TypeAlias = Literal[ - "RegularSurface", - "Polygons", - "Points", - "RegularCube", - "CPGrid", - "CPGridProperty", - "DataFrame", - "JSON", - "ArrowTable", -] - Layout: TypeAlias = Literal[ "regular", "unset", diff --git a/tests/test_units/test_filedataprovider_class.py b/tests/test_units/test_filedataprovider_class.py index eff443012..b2bd2cde4 100644 --- a/tests/test_units/test_filedataprovider_class.py +++ b/tests/test_units/test_filedataprovider_class.py @@ -215,7 +215,6 @@ def test_filedata_provider(regsurf, tmp_path): objdata = objectdata_provider_factory(regsurf, cfg) objdata.name = "name" objdata.efolder = "efolder" - objdata.extension = ".ext" t1 = "19000101" t2 = "20240101" objdata.time0 = datetime.strptime(t1, "%Y%m%d") @@ -227,9 +226,9 @@ def test_filedata_provider(regsurf, tmp_path): assert isinstance(filemeta, meta.File) assert ( str(filemeta.relative_path) - == f"share/results/efolder/parent--name--tag--{t2}_{t1}.ext" + == f"share/results/efolder/parent--name--tag--{t2}_{t1}.gri" ) - absdata = tmp_path / f"share/results/efolder/parent--name--tag--{t2}_{t1}.ext" + absdata = tmp_path / f"share/results/efolder/parent--name--tag--{t2}_{t1}.gri" assert filemeta.absolute_path == absdata diff --git a/tests/test_units/test_objectdataprovider_class.py b/tests/test_units/test_objectdataprovider_class.py index a95307279..28442a630 100644 --- a/tests/test_units/test_objectdataprovider_class.py +++ b/tests/test_units/test_objectdataprovider_class.py @@ -46,7 +46,7 @@ def test_objectdata_regularsurface_validate_extension(regsurf, edataobj1): """Test a valid extension for RegularSurface object.""" ext = objectdata_provider_factory(regsurf, edataobj1)._validate_get_ext( - "irap_binary", "RegularSurface", ValidFormats().surface + "irap_binary", ValidFormats.surface ) assert ext == ".gri" @@ -57,7 +57,7 @@ def test_objectdata_regularsurface_validate_extension_shall_fail(regsurf, edatao with pytest.raises(ConfigurationError): objectdata_provider_factory(regsurf, edataobj1)._validate_get_ext( - "some_invalid", "RegularSurface", ValidFormats().surface + "some_invalid", ValidFormats.surface ) @@ -79,10 +79,7 @@ def test_objectdata_regularsurface_derive_objectdata(regsurf, edataobj1): objdata = objectdata_provider_factory(regsurf, edataobj1) assert isinstance(objdata, RegularSurfaceDataProvider) assert objdata.classname.value == "surface" - - res = objdata.get_objectdata() - assert res.subtype == "RegularSurface" - assert res.extension == ".gri" + assert objdata.extension == ".gri" def test_objectdata_regularsurface_derive_metadata(regsurf, edataobj1):