Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encode Dataset attributes containing Datasets as JSON #10

Merged
merged 3 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions doc/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,53 @@ API
Opening Measurement Sets
------------------------

The standard :func:`xarray.backends.api.open_dataset` and
:func:`xarray.backends.api.open_datatree` methods should
be used to open either a :class:`~xarray.Dataset` or a
:class:`~xarray.DataTree`.

.. code-block:: python

>>> dataset = xarray.open_dataset(
"/data/data.ms",
partition_columns=["DATA_DESC_ID", "FIELD_ID"])
>>> datatree = xarray.backends.api.open_datatree(
"/data/data.ms",
partition_columns=["DATA_DESC_ID", "FIELD_ID"])

These methods defer to the relevant methods on the
`Entrypoint Class <entrypoint-class_>`_.
Consult the method signatures for information on extra
arguments that can be passed.


.. _entrypoint-class:

Entrypoint Class
----------------

Entrypoint class for the MSv2 backend.

.. autoclass:: xarray_ms.backend.msv2.entrypoint.MSv2PartitionEntryPoint
:members: open_dataset, open_datatree


Reading from Zarr
-----------------

Thin wrappers around :func:`xarray.open_zarr`
and :func:`xarray.backends.api.open_datatree` that decode
JSON-encoded :class:`~xarray.Dataset` attributes.

.. autofunction:: xarray_ms.xds_from_zarr
.. autofunction:: xarray_ms.xdt_from_zarr

Writing to Zarr
---------------

Thin wrappers around :meth:`xarray.Dataset.to_zarr`
and :meth:`xarray.DataTree.to_zarr` that encode attributes
holding a :class:`~xarray.Dataset` as JSON.

.. autofunction:: xarray_ms.xds_to_zarr
.. autofunction:: xarray_ms.xdt_to_zarr
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ distributed = {version = "^2024.5.0", optional = true, extras = ["testing"]}
cacheout = "^0.16.0"
arcae = "^0.2.4"
typing-extensions = { version = "^4.12.2", python = "<3.11" }
zarr = {version = "^2.18.3", optional = true, extras = ["testing"]}

[tool.poetry.extras]
testing = ["dask", "distributed", "pytest"]
testing = ["dask", "distributed", "pytest", "zarr"]

[tool.poetry.plugins."xarray.backends"]
"xarray-ms:msv2" = "xarray_ms.backend.msv2.entrypoint:MSv2PartitionEntryPoint"
Expand Down
20 changes: 20 additions & 0 deletions tests/test_zarr_roundtrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import xarray.testing as xt
from xarray.backends.api import open_dataset, open_datatree

from xarray_ms import xds_from_zarr, xds_to_zarr, xdt_from_zarr, xdt_to_zarr


def test_dataset_roundtrip(simmed_ms, tmp_path):
    """Check a Dataset survives a write to and read from a Zarr store.

    The Dataset opened from the ``simmed_ms`` fixture is written with
    :func:`xds_to_zarr`, read back with :func:`xds_from_zarr`, and the
    two are required to be identical.
    """
    store = tmp_path / "test_dataset.zarr"
    original = open_dataset(simmed_ms)
    xds_to_zarr(original, store, compute=True, consolidated=True)
    restored = xds_from_zarr(store)
    xt.assert_identical(original, restored)


def test_datatree_roundtrip(simmed_ms, tmp_path):
    """Check a DataTree survives a write to and read from a Zarr store.

    The DataTree opened from the ``simmed_ms`` fixture is written with
    :func:`xdt_to_zarr`, read back with :func:`xdt_from_zarr`, and the
    two are required to be identical.
    """
    store = tmp_path / "test_datatree.zarr"
    original = open_datatree(simmed_ms)
    xdt_to_zarr(original, store, compute=True, consolidated=True)
    restored = xdt_from_zarr(store)
    xt.assert_identical(original, restored)
3 changes: 3 additions & 0 deletions xarray_ms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Public API of the package: the Zarr read/write wrappers
# imported below from xarray_ms.core.
__all__ = ["xds_from_zarr", "xdt_from_zarr", "xds_to_zarr", "xdt_to_zarr"]

from xarray_ms.core import xds_from_zarr, xds_to_zarr, xdt_from_zarr, xdt_to_zarr
10 changes: 2 additions & 8 deletions xarray_ms/backend/msv2/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from xarray_ms.backend.msv2.antenna_dataset_factory import AntennaDatasetFactory
from xarray_ms.backend.msv2.main_dataset_factory import MainDatasetFactory
from xarray_ms.backend.msv2.structure import (
DEFAULT_PARTITION_COLUMNS,
MSv2Structure,
MSv2StructureFactory,
)
Expand All @@ -30,14 +31,7 @@

from xarray.backends.common import AbstractDataStore

from xarray_ms.backend.msv2.structure import PartitionKeyT


DEFAULT_PARTITION_COLUMNS: List[str] = [
"DATA_DESC_ID",
"FIELD_ID",
"OBSERVATION_ID",
]
from xarray_ms.backend.msv2.structure import DEFAULT_PARTITION_COLUMNS, PartitionKeyT


def promote_chunks(
Expand Down
7 changes: 7 additions & 0 deletions xarray_ms/backend/msv2/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ def is_partition_key(key: PartitionKeyT) -> bool:
)


# Measurement Set column names used for partitioning by default.
# NOTE(review): presumably the fallback for the ``partition_columns``
# argument of the backend entrypoint -- confirm against its signature.
DEFAULT_PARTITION_COLUMNS: List[str] = [
"DATA_DESC_ID",
"FIELD_ID",
"OBSERVATION_ID",
]


SHORT_TO_LONG_PARTITION_COLUMNS: Dict[str, str] = {
"D": "DATA_DESC_ID",
"F": "FIELD_ID",
Expand Down
89 changes: 89 additions & 0 deletions xarray_ms/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import json

import xarray
from xarray.backends.api import open_datatree
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree

try:
import zarr
except ImportError:
zarr = None


def encode_attributes(ds: Dataset) -> Dataset:
    """Return ``ds`` with its ``antenna_xds`` attribute encoded as JSON.

    :param ds: Dataset whose ``antenna_xds`` attribute, if present,
        holds a :class:`~xarray.Dataset`.
    :return: ``ds`` itself when the attribute is absent (or ``None``),
        otherwise the result of ``assign_attrs`` with the attribute
        replaced by a JSON string.
    :raises TypeError: if the attribute is present but is not a
        :class:`~xarray.Dataset`.
    """
    antenna = ds.attrs.get("antenna_xds", None)

    # Nothing to encode
    if antenna is None:
        return ds

    if not isinstance(antenna, Dataset):
        raise TypeError(
            f"antenna_xds attribute must be an xarray Dataset "
            f"but a {type(antenna)} was present"
        )

    # Dataset -> plain dict -> JSON string
    encoded = json.dumps(antenna.to_dict())
    return ds.assign_attrs(antenna_xds=encoded)


def decode_attributes(ds: Dataset) -> Dataset:
    """Return ``ds`` with its ``antenna_xds`` attribute decoded from JSON.

    This is the inverse of :func:`encode_attributes` and, like it,
    leaves a Dataset without an ``antenna_xds`` attribute untouched
    instead of raising :class:`KeyError`.

    :param ds: Dataset whose ``antenna_xds`` attribute, if present, is
        either a JSON string or an already decoded
        :class:`~xarray.Dataset`.
    :return: ``ds`` itself when the attribute is absent, ``None`` or
        already a :class:`~xarray.Dataset`; otherwise the result of
        ``assign_attrs`` with the attribute replaced by the Dataset
        rebuilt from the JSON string.
    :raises TypeError: if the attribute is present but is neither a
        string nor a :class:`~xarray.Dataset`.
    """
    # Use .get so that a missing attribute is handled the same way
    # as in encode_attributes
    ant_xds = ds.attrs.get("antenna_xds", None)

    # Nothing to decode, or already decoded
    if ant_xds is None or isinstance(ant_xds, Dataset):
        return ds

    if isinstance(ant_xds, str):
        # JSON string -> plain dict -> Dataset
        antenna_dict = json.loads(ant_xds)
        return ds.assign_attrs(antenna_xds=Dataset.from_dict(antenna_dict))

    raise TypeError(
        f"antenna_xds must be an xarray Dataset or a JSON encoded Dataset "
        f"but a {type(ant_xds)} was present"
    )


def xds_from_zarr(*args, **kwargs):
    """Read a Measurement Set-like :class:`~xarray.Dataset` from a Zarr store.

    Thin wrapper around :func:`xarray.open_zarr`: all arguments are
    forwarded to it and the opened Dataset is passed through
    ``decode_attributes`` before being returned.

    :raises ImportError: if zarr could not be imported.
    """
    if zarr is None:
        raise ImportError("pip install zarr")

    opened = xarray.open_zarr(*args, **kwargs)
    return decode_attributes(opened)


def xds_to_zarr(ds: Dataset, *args, **kwargs):
    """Write a Measurement Set-like :class:`~xarray.Dataset` to a Zarr store.

    Thin wrapper around :meth:`xarray.Dataset.to_zarr`: ``ds`` is passed
    through ``encode_attributes`` and the remaining arguments are
    forwarded to ``to_zarr`` on the result.

    :param ds: Dataset to write.
    :return: whatever :meth:`xarray.Dataset.to_zarr` returns for the
        given arguments. No return annotation is given because the
        value is not ``None``.
    :raises ImportError: if zarr could not be imported.
    """
    if zarr is None:
        raise ImportError("pip install zarr")

    return encode_attributes(ds).to_zarr(*args, **kwargs)


def xdt_from_zarr(*args, **kwargs):
    """Read a Measurement Set-like :class:`~xarray.core.datatree.DataTree`
    from a Zarr store.

    Thin wrapper around :func:`xarray.backends.api.open_datatree`: all
    arguments are forwarded to it and ``decode_attributes`` is mapped
    over the subtree of the opened DataTree.

    :raises ImportError: if zarr could not be imported.
    """
    if zarr is None:
        raise ImportError("pip install zarr")

    tree = open_datatree(*args, **kwargs)
    return tree.map_over_subtree(decode_attributes)


def xdt_to_zarr(dt: DataTree, *args, **kwargs) -> None:
    """Write a Measurement Set-like :class:`~xarray.core.datatree.DataTree`
    to a Zarr store.

    Thin wrapper around :meth:`xarray.core.datatree.DataTree.to_zarr`:
    ``encode_attributes`` is mapped over the subtree of ``dt`` and the
    remaining arguments are forwarded to ``to_zarr`` on the result.

    :param dt: DataTree to write.
    :raises ImportError: if zarr could not be imported.
    """
    if zarr is None:
        raise ImportError("pip install zarr")

    return dt.map_over_subtree(encode_attributes).to_zarr(*args, **kwargs)