Skip to content

Commit

Permalink
Rebase output off of GribOutputBase
Browse files Browse the repository at this point in the history
  • Loading branch information
HCookie committed Oct 15, 2024
1 parent 3779496 commit 47bfa5c
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions src/ai_models_multio/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import multio
from ai_models.model import Timer
from ai_models.outputs import Output
from ai_models.outputs import GribOutputBase

if TYPE_CHECKING:
import numpy as np
Expand Down Expand Up @@ -53,27 +53,28 @@ def earthkit_to_multio(metadata: Metadata):
return metad


class MultioOutput(Output):
class MultioOutput(GribOutputBase):
"""Multio Output plugin for ai-models"""

_server: multio.Multio = None

def __init__(self, owner, path: str, metadata: dict, plan: PLANS = "to_file", **_):
"""Multio Output plugin for ai-models"""
"""Multio Output plugin for ai-models
Parameters
----------
plan : PLANS, optional
Multio Plan to use, by default "to_file"
"""

self._plan_name = plan
self._path = path
self._owner = owner

metadata.setdefault("stream", "oper")
metadata.setdefault("expver", owner.expver)
metadata.setdefault("type", "fc")
metadata.setdefault("class", "ml")
metadata.setdefault("gribEdition", "2")

self.metadata = metadata
super().__init__(owner, path, metadata)

def get_plan(self, data: np.ndarray, metadata: Metadata) -> multio.plans.Config:
"""Get the plan for the output"""
return get_plan(self._plan_name, values=data, metadata=metadata, path=self._path)
return get_plan(self._plan_name, values=data, metadata=metadata, path=self.path)

def server(self, data: np.ndarray, metadata: dict) -> multio.Multio:
"""Get multio server, with plan configured from data, metadata and path"""
Expand All @@ -86,6 +87,7 @@ def server(self, data: np.ndarray, metadata: dict) -> multio.Multio:

def write(self, data: np.ndarray, *, check_nans: bool = False, **kwargs):
"""Write data to multio"""
del check_nans # Unused

# Skip if data is None
if data is None:
Expand All @@ -95,15 +97,15 @@ def write(self, data: np.ndarray, *, check_nans: bool = False, **kwargs):
step: int = kwargs.pop("step")

metadata_template = dict(earthkit_to_multio(template_metadata))
metadata_template.update(self.metadata)
metadata_template.update(self.grib_keys)
metadata_template.update(kwargs)

metadata_template.update(
{
"step": step,
"trigger": "step",
"globalSize": math.prod(data.shape),
"generatingProcessIdentifier": self._owner.version,
"generatingProcessIdentifier": self.owner.version,
}
)
with self.server(data, metadata_template) as server:
Expand Down

0 comments on commit 47bfa5c

Please sign in to comment.