Skip to content

Commit

Permalink
Utilise xarray's preferred_chunks functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Nov 5, 2024
1 parent 7e1f8cd commit 8f5a935
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 41 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ to be developed on well-understood MSv2 data.
>>> import xarray_ms
>>> from xarray.backends.api import datatree
>>> dt = open_datatree("/data/L795830_SB001_uv.MS/",
partition_chunks={"time": 2000, "baseline": 1000})
preferred_chunks={"time": 2000, "baseline": 1000})
>>> dt
<xarray.DataTree>
Group: /
Expand Down
11 changes: 6 additions & 5 deletions doc/source/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,22 @@ Per-partition chunking

Different chunking may be desired, especially when applied to
different channelisation and polarisation configurations.
In these cases, the ``partition_chunks`` argument can be used
In these cases, the ``preferred_chunks`` argument can be used
to specify different chunking setups for each partition.

.. ipython:: python
dt = open_datatree(ms, partition_columns=[
"DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"],
partition_chunks={
chunks={},
preferred_chunks={
(("DATA_DESC_ID", 0),): {"time": 2, "frequency": 4},
(("DATA_DESC_ID", 1),): {"time": 3, "frequency": 2}})
See the ``partition_chunks`` argument of
See the ``preferred_chunks`` argument of
:meth:`xarray_ms.backend.msv2.entrypoint.MSv2EntryPoint.open_datatree`
for more information.


.. ipython:: python
dt
Expand All @@ -139,7 +139,8 @@ this to a zarr_ store.
dt = open_datatree(ms, partition_columns=[
"DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"],
partition_chunks={
chunks={},
preferred_chunks={
(("DATA_DESC_ID", 0),): {"time": 2, "frequency": 4},
(("DATA_DESC_ID", 1),): {"time": 3, "frequency": 2}})
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ readme = "README.rst"
[tool.poetry.dependencies]
python = "^3.10"
pytest = {version = "^8.0.0", optional = true, extras = ["testing"]}
xarray = "^2024.9.0, < 2024.10.0"
xarray = "^2024.9.0"
dask = {version = "^2024.5.0", optional = true, extras = ["testing"]}
distributed = {version = "^2024.5.0", optional = true, extras = ["testing"]}
cacheout = "^0.16.0"
Expand Down
28 changes: 15 additions & 13 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_open_datatree(simmed_ms):

# Works with default dask scheduler
with ExitStack() as stack:
dt = open_datatree(simmed_ms, partition_chunks=chunks)
dt = open_datatree(simmed_ms, preferred_chunks=chunks)
for ds in dt.values():
del ds.attrs["creation_date"]
xt.assert_identical(dt, mem_dt)
Expand All @@ -165,7 +165,7 @@ def test_open_datatree(simmed_ms):
with ExitStack() as stack:
cluster = stack.enter_context(LocalCluster(processes=True, n_workers=4))
stack.enter_context(Client(cluster))
dt = open_datatree(simmed_ms, partition_chunks=chunks)
dt = open_datatree(simmed_ms, preferred_chunks=chunks)
for ds in dt.values():
del ds.attrs["creation_date"]
xt.assert_identical(dt, mem_dt)
Expand All @@ -186,7 +186,8 @@ def test_open_datatree_chunking(simmed_ms):
and partition-specific chunking"""
dt = open_datatree(
simmed_ms,
partition_chunks={"time": 3, "frequency": 2},
chunks={},
preferred_chunks={"time": 3, "frequency": 2},
)

for child in dt.children:
Expand All @@ -210,7 +211,8 @@ def test_open_datatree_chunking(simmed_ms):

dt = open_datatree(
simmed_ms,
partition_chunks={
chunks={},
preferred_chunks={
"D=0": {"time": 2, "baseline": 2},
"D=1": {"time": 3, "frequency": 2},
},
Expand All @@ -235,12 +237,12 @@ def test_open_datatree_chunking(simmed_ms):
"uvw_label": (3,),
}

with pytest.warns(UserWarning, match="`partition_chunks` overriding `chunks`"):
dt = open_datatree(
simmed_ms,
chunks={},
partition_chunks={
"D=0": {"time": 2, "baseline": 2},
"D=1": {"time": 3, "frequency": 2},
},
)
# with pytest.warns(UserWarning, match="`preferred_chunks` overriding `chunks`"):
# dt = open_datatree(
# simmed_ms,
# chunks={},
# preferred_chunks={
# "D=0": {"time": 2, "baseline": 2},
# "D=1": {"time": 3, "frequency": 2},
# },
# )
51 changes: 33 additions & 18 deletions xarray_ms/backend/msv2/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class MSv2Store(AbstractWritableDataStore):
"_structure_factory",
"_partition_columns",
"_partition_key",
"_preferred_chunks",
"_auto_corrs",
"_ninstances",
"_epoch",
Expand All @@ -113,6 +114,7 @@ class MSv2Store(AbstractWritableDataStore):
_table_factory: TableFactory
_structure_factory: MSv2StructureFactory
_partition_columns: List[str]
_preferred_chunks: Dict[str, int]
_partition: PartitionKeyT
_autocorrs: bool
_ninstances: int
Expand All @@ -124,6 +126,7 @@ def __init__(
structure_factory: MSv2StructureFactory,
partition_columns: List[str],
partition_key: PartitionKeyT,
preferred_chunks: Dict[str, int],
auto_corrs: bool,
ninstances: int,
epoch: str,
Expand All @@ -132,6 +135,7 @@ def __init__(
self._structure_factory = structure_factory
self._partition_columns = partition_columns
self._partition_key = partition_key
self._preferred_chunks = preferred_chunks
self._auto_corrs = auto_corrs
self._ninstances = ninstances
self._epoch = epoch
Expand All @@ -143,6 +147,7 @@ def open(
drop_variables=None,
partition_columns: List[str] | None = None,
partition_key: PartitionKeyT | None = None,
preferred_chunks: Dict[str, int] | None = None,
auto_corrs: bool = True,
ninstances: int = 1,
epoch: str | None = None,
Expand Down Expand Up @@ -177,11 +182,15 @@ def open(
)
partition_key = first_key

if preferred_chunks is None:
preferred_chunks = preferred_chunks or {}

return cls(
table_factory,
structure_factory,
partition_columns=partition_columns,
partition_key=partition_key,
preferred_chunks=preferred_chunks,
auto_corrs=auto_corrs,
ninstances=ninstances,
epoch=epoch,
Expand All @@ -191,9 +200,14 @@ def close(self, **kwargs):
pass

def get_variables(self):
return MainDatasetFactory(
self._partition_key, self._table_factory, self._structure_factory
).get_variables()
factory = MainDatasetFactory(
self._partition_key,
self._preferred_chunks,
self._table_factory,
self._structure_factory,
)

return factory.get_variables()

def get_attrs(self):
try:
Expand All @@ -219,6 +233,7 @@ class MSv2EntryPoint(BackendEntrypoint):
"filename_or_obj",
"partition_columns",
"partition_key",
"preferred_chunks",
"auto_corrs",
"ninstances",
"epoch",
Expand Down Expand Up @@ -253,6 +268,7 @@ def open_dataset(
drop_variables: str | Iterable[str] | None = None,
partition_columns: List[str] | None = None,
partition_key: PartitionKeyT | None = None,
preferred_chunks: Dict[str, int] | None = None,
auto_corrs: bool = True,
ninstances: int = 8,
epoch: str | None = None,
Expand All @@ -269,6 +285,7 @@ def open_dataset(
partition_key: A key corresponding to an individual partition.
For example :code:`(('DATA_DESC_ID', 0), ('FIELD_ID', 0))`.
If :code:`None`, the first partition will be opened.
preferred_chunks: The preferred chunks for each partition.
auto_corrs: Include/Exclude auto-correlations.
ninstances: The number of Measurement Set instances to open for parallel I/O.
epoch: A unique string identifying the creation of this Dataset.
Expand All @@ -281,11 +298,13 @@ def open_dataset(
partition specified by :code:`partition_columns` and :code:`partition_key`.
"""
filename_or_obj = _normalize_path(filename_or_obj)

store = MSv2Store.open(
filename_or_obj,
drop_variables=drop_variables,
partition_columns=partition_columns,
partition_key=partition_key,
preferred_chunks=preferred_chunks,
auto_corrs=auto_corrs,
ninstances=ninstances,
epoch=epoch,
Expand All @@ -299,7 +318,7 @@ def open_datatree(
self,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
partition_chunks: Dict[str, Any] | None = None,
preferred_chunks: Dict[str, Any] | None = None,
drop_variables: str | Iterable[str] | None = None,
partition_columns: List[str] | None = None,
auto_corrs: bool = True,
Expand All @@ -312,7 +331,7 @@ def open_datatree(
Args:
filename_or_obj: The path to the MSv2 CASA Measurement Set file.
partition_chunks: Chunk sizes along each dimension,
preferred_chunks: Chunk sizes along each dimension,
e.g. :code:`{{"time": 10, "frequency": 16}}`.
Individual partitions can be chunked differently by
partially (or fully) specifying a partition key: e.g.
Expand All @@ -332,10 +351,9 @@ def open_datatree(
"D=0,F=1": {{"time": 20, "frequency": 32}},
}}
.. note:: This argument overrides the reserved ``chunks`` argument
used by xarray to control chunking in Datasets and DataTrees.
It should be used instead of ``chunks`` when different
chunking is desired for different partitions.
.. note:: This argument should be used in conjunction with
the reserved ``chunks`` argument used by xarray to control chunking
in Datasets and DataTrees. See preferred_chunk_sizes_ for more information.
drop_variables: Variables to drop from the dataset.
partition_columns: The columns to use for partitioning the Measurement set.
Expand All @@ -347,6 +365,8 @@ def open_datatree(
Returns:
An xarray :class:`~xarray.core.datatree.DataTree`
.. _preferred_chunk_sizes: https://docs.xarray.dev/en/stable/internals/how-to-add-new-backend.html#preferred-chunk-sizes
"""
if isinstance(filename_or_obj, os.PathLike):
ms = str(filename_or_obj)
Expand All @@ -361,14 +381,7 @@ def open_datatree(

structure = structure_factory()
datasets = {}

if not partition_chunks:
partition_chunks = kwargs.pop("chunks", None)
elif "chunks" in kwargs:
kwargs.pop("chunks", None)
warnings.warn("`partition_chunks` overriding `chunks`")

pchunks = promote_chunks(structure, partition_chunks)
pchunks = promote_chunks(structure, preferred_chunks)

for partition_key in structure:
ds = xarray.open_dataset(
Expand All @@ -377,11 +390,13 @@ def open_datatree(
engine="xarray-ms:msv2",
partition_columns=partition_columns,
partition_key=partition_key,
preferred_chunks=pchunks[partition_key]
if isinstance(pchunks, Mapping)
else pchunks,
auto_corrs=auto_corrs,
ninstances=ninstances,
epoch=epoch,
structure_factory=structure_factory,
chunks=pchunks[partition_key] if isinstance(pchunks, Mapping) else pchunks,
**kwargs,
)

Expand Down
13 changes: 10 additions & 3 deletions xarray_ms/backend/msv2/main_dataset_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import dataclasses
import warnings
from typing import Any, Mapping, Tuple, Type
from typing import Any, Dict, Mapping, Tuple, Type

import numpy as np
from xarray import Variable
Expand Down Expand Up @@ -47,16 +47,19 @@ class MSv2ColumnSchema:

class MainDatasetFactory:
_partition_key: PartitionKeyT
_preferred_chunks: Dict[str, int]
_table_factory: TableFactory
_structure_factory: MSv2StructureFactory

def __init__(
    self,
    partition_key: PartitionKeyT,
    preferred_chunks: Dict[str, int],
    table_factory: TableFactory,
    structure_factory: MSv2StructureFactory,
):
    """Store the partition identity, chunking preferences and factories.

    Args:
        partition_key: Key identifying the Measurement Set partition this
            factory builds variables for.
        preferred_chunks: Preferred chunk sizes per dimension name
            (e.g. ``{"time": 10, "frequency": 16}``); later attached to
            variable encodings so xarray can honour them when chunking.
        table_factory: Factory producing the underlying table object.
        structure_factory: Factory producing the MSv2 structure description.
    """
    self._partition_key = partition_key
    self._preferred_chunks = preferred_chunks
    self._table_factory = table_factory
    self._structure_factory = structure_factory

Expand Down Expand Up @@ -103,14 +106,17 @@ def _variable_from_column(self, column: str) -> Variable:
default,
)

var = Variable(dims, data)
var = Variable(dims, data, fastpath=True)

# Apply any measures encoding
if schema.coder:
coder = schema.coder(schema.name, structure.column_descs["MAIN"])
var = coder.decode(var)

dims, data, attrs, encoding = unpack_for_decoding(var)

encoding["preferred_chunks"] = self._preferred_chunks

return Variable(dims, LazilyIndexedArray(data), attrs, encoding, fastpath=True)

def get_variables(self) -> Mapping[str, Variable]:
Expand Down Expand Up @@ -164,7 +170,8 @@ def get_variables(self) -> Mapping[str, Variable]:
("polarization", (("polarization",), partition.corr_type, None)),
]

coordinates = [(n, Variable(d, v, a)) for n, (d, v, a) in coordinates]
e = {"preferred_chunks": self._preferred_chunks}
coordinates = [(n, Variable(d, v, a, e)) for n, (d, v, a) in coordinates]

# Add time coordinate
time_coder = TimeCoder("TIME", structure.column_descs["MAIN"])
Expand Down

0 comments on commit 8f5a935

Please sign in to comment.