Skip to content

Commit

Permalink
Encode Dataset attributes containing Datasets as JSON (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins authored Sep 11, 2024
1 parent ec064c2 commit 94b6630
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 9 deletions.
48 changes: 48 additions & 0 deletions doc/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,53 @@ API
Opening Measurement Sets
------------------------

The standard :func:`xarray.backends.api.open_dataset` and
:func:`xarray.backends.api.open_datatree` methods should
be used to open either a :class:`~xarray.Dataset` or a
:class:`~xarray.DataTree`.

.. code-block:: python

  >>> dataset = xarray.open_dataset(
        "/data/data.ms",
        partition_columns=["DATA_DESC_ID", "FIELD_ID"])
  >>> datatree = xarray.backends.api.open_datatree(
        "/data/data.ms",
        partition_columns=["DATA_DESC_ID", "FIELD_ID"])

These methods defer to the relevant methods on the
`Entrypoint Class <entrypoint-class_>`_.
Consult the method signatures for information on extra
arguments that can be passed.


.. _entrypoint-class:

Entrypoint Class
----------------

Entrypoint class for the MSv2 backend.

.. autoclass:: xarray_ms.backend.msv2.entrypoint.MSv2PartitionEntryPoint
  :members: open_dataset, open_datatree


Reading from Zarr
-----------------

Thin wrappers around :func:`xarray.open_zarr`
and :func:`xarray.backends.api.open_datatree` that decode
JSON-encoded :class:`~xarray.Dataset` attributes.

.. autofunction:: xarray_ms.xds_from_zarr
.. autofunction:: xarray_ms.xdt_from_zarr

Writing to Zarr
---------------

Thin wrappers around :meth:`xarray.Dataset.to_zarr`
and :meth:`xarray.DataTree.to_zarr` that encode
:class:`~xarray.Dataset` attributes as JSON.

.. autofunction:: xarray_ms.xds_to_zarr
.. autofunction:: xarray_ms.xdt_to_zarr
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ distributed = {version = "^2024.5.0", optional = true, extras = ["testing"]}
cacheout = "^0.16.0"
arcae = "^0.2.4"
typing-extensions = { version = "^4.12.2", python = "<3.11" }
zarr = {version = "^2.18.3", optional = true, extras = ["testing"]}

[tool.poetry.extras]
testing = ["dask", "distributed", "pytest"]
testing = ["dask", "distributed", "pytest", "zarr"]

[tool.poetry.plugins."xarray.backends"]
"xarray-ms:msv2" = "xarray_ms.backend.msv2.entrypoint:MSv2PartitionEntryPoint"
Expand Down
20 changes: 20 additions & 0 deletions tests/test_zarr_roundtrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import xarray.testing as xt
from xarray.backends.api import open_dataset, open_datatree

from xarray_ms import xds_from_zarr, xds_to_zarr, xdt_from_zarr, xdt_to_zarr


def test_dataset_roundtrip(simmed_ms, tmp_path):
    """A Dataset opened from the Measurement Set survives a Zarr round trip."""
    expected = open_dataset(simmed_ms)
    store = tmp_path / "test_dataset.zarr"

    # Write eagerly with consolidated metadata, then read the store back
    xds_to_zarr(expected, store, compute=True, consolidated=True)
    actual = xds_from_zarr(store)

    xt.assert_identical(expected, actual)


def test_datatree_roundtrip(simmed_ms, tmp_path):
    """A DataTree opened from the Measurement Set survives a Zarr round trip."""
    expected = open_datatree(simmed_ms)
    store = tmp_path / "test_datatree.zarr"

    # Write eagerly with consolidated metadata, then read the store back
    xdt_to_zarr(expected, store, compute=True, consolidated=True)
    actual = xdt_from_zarr(store)

    xt.assert_identical(expected, actual)
3 changes: 3 additions & 0 deletions xarray_ms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__all__ = ["xds_from_zarr", "xdt_from_zarr", "xds_to_zarr", "xdt_to_zarr"]

from xarray_ms.core import xds_from_zarr, xds_to_zarr, xdt_from_zarr, xdt_to_zarr
10 changes: 2 additions & 8 deletions xarray_ms/backend/msv2/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from xarray_ms.backend.msv2.antenna_dataset_factory import AntennaDatasetFactory
from xarray_ms.backend.msv2.main_dataset_factory import MainDatasetFactory
from xarray_ms.backend.msv2.structure import (
DEFAULT_PARTITION_COLUMNS,
MSv2Structure,
MSv2StructureFactory,
)
Expand All @@ -30,14 +31,7 @@

from xarray.backends.common import AbstractDataStore

from xarray_ms.backend.msv2.structure import PartitionKeyT


DEFAULT_PARTITION_COLUMNS: List[str] = [
"DATA_DESC_ID",
"FIELD_ID",
"OBSERVATION_ID",
]
from xarray_ms.backend.msv2.structure import DEFAULT_PARTITION_COLUMNS, PartitionKeyT


def promote_chunks(
Expand Down
7 changes: 7 additions & 0 deletions xarray_ms/backend/msv2/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ def is_partition_key(key: PartitionKeyT) -> bool:
)


# Column names used for partitioning when none are given explicitly
# (going by the constant's name -- confirm against the callers).
# Defined here so that other modules can import it from the structure module.
DEFAULT_PARTITION_COLUMNS: List[str] = [
"DATA_DESC_ID",
"FIELD_ID",
"OBSERVATION_ID",
]


SHORT_TO_LONG_PARTITION_COLUMNS: Dict[str, str] = {
"D": "DATA_DESC_ID",
"F": "FIELD_ID",
Expand Down
89 changes: 89 additions & 0 deletions xarray_ms/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import json

import xarray
from xarray.backends.api import open_datatree
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree

try:
import zarr
except ImportError:
zarr = None


def encode_attributes(ds: Dataset) -> Dataset:
    """Return ``ds`` with its ``antenna_xds`` attribute serialised as JSON.

    A dataset whose ``antenna_xds`` attribute is absent (or ``None``) is
    returned as is. Otherwise the attribute must be a
    :class:`~xarray.Dataset`: it is converted with ``to_dict``, dumped to a
    JSON string and attached to the dataset returned by ``assign_attrs``.

    Raises:
      TypeError: if ``antenna_xds`` is present but is not a Dataset.
    """
    ant_xds = ds.attrs.get("antenna_xds", None)

    # Nothing to encode
    if ant_xds is None:
        return ds

    if not isinstance(ant_xds, Dataset):
        raise TypeError(
            f"antenna_xds attribute must be an xarray Dataset "
            f"but a {type(ant_xds)} was present"
        )

    encoded = json.dumps(ant_xds.to_dict())
    return ds.assign_attrs(antenna_xds=encoded)


def decode_attributes(ds: Dataset) -> Dataset:
    """Decode the ``antenna_xds`` attribute of a Dataset.

    Inverse of :func:`encode_attributes`: a JSON string stored in the
    ``antenna_xds`` attribute is parsed and rebuilt into a
    :class:`~xarray.Dataset` with ``Dataset.from_dict``.

    A dataset whose ``antenna_xds`` attribute is absent (or ``None``), or is
    already a Dataset, is returned unchanged.

    Raises:
      TypeError: if ``antenna_xds`` is present but is neither a string
        nor a Dataset.
    """
    # Use .get rather than indexing: encode_attributes passes datasets
    # without an antenna_xds attribute straight through, so such datasets
    # must also be readable instead of failing with a KeyError.
    ant_xds = ds.attrs.get("antenna_xds")
    if ant_xds is None:
        return ds

    if isinstance(ant_xds, str):
        antenna_dict = json.loads(ant_xds)
        ant_ds = Dataset.from_dict(antenna_dict)
        return ds.assign_attrs(antenna_xds=ant_ds)

    # Already decoded
    if isinstance(ant_xds, Dataset):
        return ds

    raise TypeError(
        f"antenna_xds must be an xarray Dataset or a JSON encoded Dataset "
        f"but a {type(ant_xds)} was present"
    )


def xds_from_zarr(*args, **kwargs):
    """Open a Measurement Set-like :class:`~xarray.Dataset` held in a Zarr store.

    All arguments are forwarded to :func:`xarray.open_zarr`; the opened
    dataset is then passed through :func:`decode_attributes`.

    Raises:
      ImportError: if the optional ``zarr`` package could not be imported.
    """
    if zarr is None:
        raise ImportError("pip install zarr")

    opened = xarray.open_zarr(*args, **kwargs)
    return decode_attributes(opened)


def xds_to_zarr(ds: Dataset, *args, **kwargs) -> None:
    """Store a Measurement Set-like :class:`~xarray.Dataset` in a Zarr store.

    ``ds`` is first passed through :func:`encode_attributes`; the remaining
    arguments are forwarded to :meth:`xarray.Dataset.to_zarr`.

    NOTE(review): the signature is annotated ``-> None`` but the value
    returned by ``to_zarr`` is handed back to the caller.

    Raises:
      ImportError: if the optional ``zarr`` package could not be imported.
    """
    if zarr is None:
        raise ImportError("pip install zarr")

    encoded = encode_attributes(ds)
    return encoded.to_zarr(*args, **kwargs)


def xdt_from_zarr(*args, **kwargs):
    """Open a Measurement Set-like :class:`~xarray.core.datatree.DataTree`
    held in a Zarr store.

    All arguments are forwarded to
    :func:`xarray.backends.api.open_datatree`; :func:`decode_attributes` is
    then mapped over the opened tree with ``map_over_subtree``.

    Raises:
      ImportError: if the optional ``zarr`` package could not be imported.
    """
    if zarr is None:
        raise ImportError("pip install zarr")

    tree = open_datatree(*args, **kwargs)
    return tree.map_over_subtree(decode_attributes)


def xdt_to_zarr(dt: DataTree, *args, **kwargs) -> None:
    """Write a Measurement Set-like :class:`~xarray.core.datatree.DataTree`
    to a Zarr store.

    :func:`encode_attributes` is mapped over ``dt`` with
    ``map_over_subtree``; the remaining arguments are forwarded to
    :meth:`xarray.core.datatree.DataTree.to_zarr` on the resulting tree.

    Raises:
      ImportError: if the optional ``zarr`` package could not be imported.
    """
    if zarr is None:
        raise ImportError("pip install zarr")

    encoded_tree = dt.map_over_subtree(encode_attributes)
    return encoded_tree.to_zarr(*args, **kwargs)

0 comments on commit 94b6630

Please sign in to comment.