diff --git a/doc/source/api.rst b/doc/source/api.rst index fc42564..9680614 100644 --- a/doc/source/api.rst +++ b/doc/source/api.rst @@ -4,5 +4,53 @@ API Opening Measurement Sets ------------------------ +The standard :func:`xarray.backends.api.open_dataset` and +:func:`xarray.backends.api.open_datatree` functions should +be used to open either a :class:`~xarray.Dataset` or a +:class:`~xarray.DataTree`. + +.. code-block:: python + + >>> dataset = xarray.open_dataset( + "/data/data.ms", + partition_columns=["DATA_DESC_ID", "FIELD_ID"]) + >>> datatree = xarray.backends.api.open_datatree( + "/data/data.ms", + partition_columns=["DATA_DESC_ID", "FIELD_ID"]) + +These functions defer to the relevant methods on the +`Entrypoint Class <entrypoint-class_>`_. +Consult the method signatures for information on extra +arguments that can be passed. + + +.. _entrypoint-class: + +Entrypoint Class +---------------- + +Entrypoint class for the MSv2 backend. + .. autoclass:: xarray_ms.backend.msv2.entrypoint.MSv2PartitionEntryPoint :members: open_dataset, open_datatree + + +Reading from Zarr +----------------- + +Thin wrappers around :func:`xarray.open_zarr` +and :func:`xarray.backends.api.open_datatree` that decode +JSON-encoded :class:`~xarray.Dataset` attributes. + +.. autofunction:: xarray_ms.xds_from_zarr +.. autofunction:: xarray_ms.xdt_from_zarr + +Writing to Zarr +--------------- + +Thin wrappers around :meth:`xarray.Dataset.to_zarr` +and :meth:`xarray.DataTree.to_zarr` that encode +:class:`~xarray.Dataset` attributes as JSON. + +.. autofunction:: xarray_ms.xds_to_zarr +..
autofunction:: xarray_ms.xdt_to_zarr diff --git a/pyproject.toml b/pyproject.toml index 77d0f55..2e65841 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,9 +14,10 @@ distributed = {version = "^2024.5.0", optional = true, extras = ["testing"]} cacheout = "^0.16.0" arcae = "^0.2.4" typing-extensions = { version = "^4.12.2", python = "<3.11" } +zarr = {version = "^2.18.3", optional = true, extras = ["testing"]} [tool.poetry.extras] -testing = ["dask", "distributed", "pytest"] +testing = ["dask", "distributed", "pytest", "zarr"] [tool.poetry.plugins."xarray.backends"] "xarray-ms:msv2" = "xarray_ms.backend.msv2.entrypoint:MSv2PartitionEntryPoint" diff --git a/tests/test_zarr_roundtrip.py b/tests/test_zarr_roundtrip.py new file mode 100644 index 0000000..4f1dadc --- /dev/null +++ b/tests/test_zarr_roundtrip.py @@ -0,0 +1,20 @@ +import xarray.testing as xt +from xarray.backends.api import open_dataset, open_datatree + +from xarray_ms import xds_from_zarr, xds_to_zarr, xdt_from_zarr, xdt_to_zarr + + +def test_dataset_roundtrip(simmed_ms, tmp_path): + ds = open_dataset(simmed_ms) + zarr_path = tmp_path / "test_dataset.zarr" + xds_to_zarr(ds, zarr_path, compute=True, consolidated=True) + ds2 = xds_from_zarr(zarr_path) + xt.assert_identical(ds, ds2) + + +def test_datatree_roundtrip(simmed_ms, tmp_path): + dt = open_datatree(simmed_ms) + zarr_path = tmp_path / "test_datatree.zarr" + xdt_to_zarr(dt, zarr_path, compute=True, consolidated=True) + dt2 = xdt_from_zarr(zarr_path) + xt.assert_identical(dt, dt2) diff --git a/xarray_ms/__init__.py b/xarray_ms/__init__.py index e69de29..157430f 100644 --- a/xarray_ms/__init__.py +++ b/xarray_ms/__init__.py @@ -0,0 +1,3 @@ +__all__ = ["xds_from_zarr", "xdt_from_zarr", "xds_to_zarr", "xdt_to_zarr"] + +from xarray_ms.core import xds_from_zarr, xds_to_zarr, xdt_from_zarr, xdt_to_zarr diff --git a/xarray_ms/backend/msv2/entrypoint.py b/xarray_ms/backend/msv2/entrypoint.py index 6f44edb..8812384 100644 --- 
a/xarray_ms/backend/msv2/entrypoint.py +++ b/xarray_ms/backend/msv2/entrypoint.py @@ -18,6 +18,7 @@ from xarray_ms.backend.msv2.antenna_dataset_factory import AntennaDatasetFactory from xarray_ms.backend.msv2.main_dataset_factory import MainDatasetFactory from xarray_ms.backend.msv2.structure import ( + DEFAULT_PARTITION_COLUMNS, MSv2Structure, MSv2StructureFactory, ) @@ -30,14 +31,7 @@ from xarray.backends.common import AbstractDataStore - from xarray_ms.backend.msv2.structure import PartitionKeyT - - -DEFAULT_PARTITION_COLUMNS: List[str] = [ - "DATA_DESC_ID", - "FIELD_ID", - "OBSERVATION_ID", -] + from xarray_ms.backend.msv2.structure import DEFAULT_PARTITION_COLUMNS, PartitionKeyT def promote_chunks( diff --git a/xarray_ms/backend/msv2/structure.py b/xarray_ms/backend/msv2/structure.py index b590119..3110f1d 100644 --- a/xarray_ms/backend/msv2/structure.py +++ b/xarray_ms/backend/msv2/structure.py @@ -103,6 +103,13 @@ def is_partition_key(key: PartitionKeyT) -> bool: ) +DEFAULT_PARTITION_COLUMNS: List[str] = [ + "DATA_DESC_ID", + "FIELD_ID", + "OBSERVATION_ID", +] + + SHORT_TO_LONG_PARTITION_COLUMNS: Dict[str, str] = { "D": "DATA_DESC_ID", "F": "FIELD_ID", diff --git a/xarray_ms/core.py b/xarray_ms/core.py new file mode 100644 index 0000000..b1418cd --- /dev/null +++ b/xarray_ms/core.py @@ -0,0 +1,89 @@ +import json + +import xarray +from xarray.backends.api import open_datatree +from xarray.core.dataset import Dataset +from xarray.core.datatree import DataTree + +try: + import zarr +except ImportError: + zarr = None + + +def encode_attributes(ds: Dataset) -> Dataset: + """Encode the antenna_xds attribute of a Dataset as a JSON string.""" + + # Datasets without an antenna_xds attribute are returned unchanged + ant_xds = ds.attrs.get("antenna_xds", None) + if ant_xds is None: + return ds + elif isinstance(ant_xds, Dataset): + ant_xds = json.dumps(ant_xds.to_dict()) + return ds.assign_attrs(antenna_xds=ant_xds) + else: + raise TypeError( + f"antenna_xds attribute must be an xarray Dataset " +
f"but a {type(ant_xds)} was present" + ) + + +def decode_attributes(ds: Dataset) -> Dataset: + """Decode the antenna_xds attribute of a Dataset, if present, from JSON.""" + # A missing attribute is passed through, mirroring encode_attributes + ant_xds = ds.attrs.get("antenna_xds", None) + if isinstance(ant_xds, str): + antenna_dict = json.loads(ant_xds) + ant_ds = Dataset.from_dict(antenna_dict) + return ds.assign_attrs(antenna_xds=ant_ds) + elif ant_xds is None or isinstance(ant_xds, Dataset): + return ds + else: + raise TypeError( + f"antenna_xds must be an xarray Dataset or a JSON encoded Dataset " + f"but a {type(ant_xds)} was present" + ) + + +def xds_from_zarr(*args, **kwargs): + """Read a Measurement Set-like :class:`~xarray.Dataset` from a Zarr store. + + Thin wrapper around :func:`xarray.open_zarr`.""" + if zarr is None: + raise ImportError("pip install zarr") + + return decode_attributes(xarray.open_zarr(*args, **kwargs)) + + +def xds_to_zarr(ds: Dataset, *args, **kwargs) -> None: + """Write a Measurement Set-like :class:`~xarray.Dataset` to a Zarr store. + + Thin wrapper around :meth:`xarray.Dataset.to_zarr`. + """ + if zarr is None: + raise ImportError("pip install zarr") + + return encode_attributes(ds).to_zarr(*args, **kwargs) + + +def xdt_from_zarr(*args, **kwargs): + """Read a Measurement Set-like :class:`~xarray.core.datatree.DataTree` + from a Zarr store. + + Thin wrapper around :func:`xarray.backends.api.open_datatree`.""" + if zarr is None: + raise ImportError("pip install zarr") + + return open_datatree(*args, **kwargs).map_over_subtree(decode_attributes) + + +def xdt_to_zarr(dt: DataTree, *args, **kwargs) -> None: + """Write a Measurement Set-like :class:`~xarray.core.datatree.DataTree` + to a Zarr store. + + Thin wrapper around :meth:`xarray.core.datatree.DataTree.to_zarr`. + """ + if zarr is None: + raise ImportError("pip install zarr") + + return dt.map_over_subtree(encode_attributes).to_zarr(*args, **kwargs)