forked from equinor/fmu-dataio
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CLN: Further split up ObjectData providers
This structure makes adding new object types simpler, more maintainable, and more scalable.
- Loading branch information
Showing 9 changed files with 162 additions and 184 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING, Final | ||
|
||
import pandas as pd | ||
from fmu.dataio._definitions import STANDARD_TABLE_INDEX_COLUMNS, ValidFormats | ||
from fmu.dataio._logging import null_logger | ||
from fmu.dataio.datastructure.meta import specification | ||
|
||
from ._base import ( | ||
DerivedObjectDescriptor, | ||
ObjectDataProvider, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
import pyarrow | ||
|
||
logger: Final = null_logger(__name__) | ||
|
||
|
||
def _check_index_in_columns(index: list[str], columns: list[str]) -> None: | ||
"""Check the table index. | ||
Args: | ||
index (list): list of column names | ||
Raises: | ||
KeyError: if index contains names that are not in self | ||
""" | ||
|
||
not_founds = (item for item in index if item not in columns) | ||
for not_found in not_founds: | ||
raise KeyError(f"{not_found} is not in table") | ||
|
||
|
||
def _derive_index(table_index: list[str] | None, columns: list[str]) -> list[str]:
    """Work out which columns make up the table index.

    Args:
        table_index: index columns requested by the caller, or None to pick
            them from STANDARD_TABLE_INDEX_COLUMNS based on ``columns``.
        columns: names of the columns present in the table.

    Returns:
        A new list of index column names. "REAL" is included whenever it is
        one of ``columns``. The list passed as ``table_index`` is never
        modified.

    Raises:
        KeyError: if a name in the index is not found in ``columns``.
    """
    index: list[str] = []

    if table_index is None:
        logger.debug("Finding index to include")
        for context, standard_cols in STANDARD_TABLE_INDEX_COLUMNS.items():
            for valid_col in standard_cols:
                # A column may be listed under several contexts; keep it once.
                if valid_col in columns and valid_col not in index:
                    index.append(valid_col)
            if index:
                logger.info("Context is %s ", context)
        logger.debug("Proudly presenting the index: %s", index)
    else:
        # Copy so that appending "REAL" below cannot leak into the caller's
        # list (which would otherwise grow on every call).
        index = list(table_index)

    if "REAL" in columns and "REAL" not in index:
        index.append("REAL")
    _check_index_in_columns(index, columns)
    return index
|
||
|
||
@dataclass
class DataFrameDataProvider(ObjectDataProvider):
    """Object data provider for tables held as ``pd.DataFrame``."""

    obj: pd.DataFrame

    def get_spec(self) -> dict:
        """Return data.spec for the DataFrame as a JSON-ready dict.

        The spec holds the column names and the DataFrame's ``size``;
        entries that are None are left out of the dump.
        """
        logger.info("Get spec for pd.DataFrame (tables)")

        table_spec = specification.TableSpecification(
            columns=list(self.obj.columns),
            size=int(self.obj.size),
        )
        return table_spec.model_dump(mode="json", exclude_none=True)

    def get_bbox(self) -> dict:
        """Return data.bbox for the DataFrame, which is always empty."""
        logger.info("Get bbox for pd.DataFrame (tables)")
        return {}

    def get_objectdata(self) -> DerivedObjectDescriptor:
        """Assemble the derived object descriptor for the DataFrame.

        Raises:
            KeyError: if the derived table index names a column that the
                DataFrame does not have.
        """
        column_names = list(self.obj.columns)
        index_columns = _derive_index(self.dataio.table_index, column_names)

        table_format = self.dataio.table_fformat
        file_extension = self._validate_get_ext(
            table_format, "DataFrame", ValidFormats().table
        )
        table_spec = self.get_spec()
        table_bbox = self.get_bbox()

        return DerivedObjectDescriptor(
            subtype="DataFrame",
            classname="table",
            layout="table",
            efolder="tables",
            fmt=table_format,
            extension=file_extension,
            spec=table_spec,
            bbox=table_bbox,
            table_index=index_columns,
        )
|
||
|
||
@dataclass
class ArrowTableDataProvider(ObjectDataProvider):
    """Object data provider for tables held as ``pyarrow.Table``.

    Decorated with ``@dataclass`` so that the ``obj`` field override is
    registered the same way as in DataFrameDataProvider.
    """

    obj: pyarrow.Table

    def get_spec(self) -> dict:
        """Return data.spec for the arrow table as a JSON-ready dict.

        The spec holds the column names and the number of cells
        (columns times rows); entries that are None are left out of the dump.
        """
        logger.info("Get spec for pyarrow (tables)")

        return specification.TableSpecification(
            columns=list(self.obj.column_names),
            size=self.obj.num_columns * self.obj.num_rows,
        ).model_dump(
            mode="json",
            exclude_none=True,
        )

    def get_bbox(self) -> dict:
        """Return data.bbox for the arrow table, which is always empty."""
        logger.info("Get bbox for pyarrow (tables)")
        return {}

    def get_objectdata(self) -> DerivedObjectDescriptor:
        """Assemble the derived object descriptor for the arrow table.

        Raises:
            KeyError: if the derived table index names a column that the
                table does not have.
        """
        table_index = _derive_index(self.dataio.table_index, self.obj.column_names)
        return DerivedObjectDescriptor(
            subtype="ArrowTable",
            classname="table",
            layout="table",
            efolder="tables",
            fmt=(fmt := self.dataio.arrow_fformat),
            extension=self._validate_get_ext(fmt, "ArrowTable", ValidFormats().table),
            spec=self.get_spec(),
            bbox=self.get_bbox(),
            table_index=table_index,
        )
Oops, something went wrong.