diff --git a/README.rst b/README.rst
index e5bd8b0..2ad1a68 100644
--- a/README.rst
+++ b/README.rst
@@ -24,28 +24,28 @@ to be developed on well-understood MSv2 data.
    >>> import xarray_ms
    >>> from xarray.backends.api import open_datatree
    >>> dt = open_datatree("/data/L795830_SB001_uv.MS/",
-                          partition_chunks={"time": 2000, "baseline": 1000})
+                          preferred_chunks={"time": 2000, "baseline_id": 1000})
    >>> dt
    Group: /
    └── Group: /DATA_DESC_ID=0,FIELD_ID=0,OBSERVATION_ID=0
-       │   Dimensions:        (time: 28760, baseline: 2775, frequency: 16,
+       │   Dimensions:        (time: 28760, baseline_id: 2775, frequency: 16,
        │                       polarization: 4, uvw_label: 3)
        │   Coordinates:
-       │     antenna1_name  (baseline) object 22kB ...
-       │     antenna2_name  (baseline) object 22kB ...
-       │     baseline_id    (baseline) int64 22kB ...
+       │     antenna1_name  (baseline_id) object 22kB ...
+       │     antenna2_name  (baseline_id) object 22kB ...
+       │     baseline_id    (baseline_id) int64 22kB ...
        │   * frequency      (frequency) float64 128B 1.202e+08 ... 1.204e+08
        │   * polarization   (polarization) ...
diff --git a/doc/source/changelog.rst b/doc/source/changelog.rst
--- a/doc/source/changelog.rst
+++ b/doc/source/changelog.rst
+* Support xarray >= 2024.9.0 (:pr:`44`)
+* Change ``partition_chunks`` to ``preferred_chunks`` (:pr:`44`)
 * Allow arcae to vary in the 0.2.x range (:pr:`42`)
 * Pin xarray to 2024.9.0 (:pr:`42`)
 * Add test case for irregular grids (:pr:`39`, :pr:`40`, :pr:`41`)
diff --git a/doc/source/tutorial.rst b/doc/source/tutorial.rst
index fa175cf..8a6b9a7 100644
--- a/doc/source/tutorial.rst
+++ b/doc/source/tutorial.rst
@@ -64,7 +64,7 @@ For example, one could select some specific dimensions out:
     dt = open_datatree(ms, partition_columns=[
       "DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"])
 
-    subdt = dt.isel(time=slice(1, 3), baseline=[1, 3, 5], frequency=slice(2, 4))
+    subdt = dt.isel(time=slice(1, 3), baseline_id=[1, 3, 5], frequency=slice(2, 4))
     subdt
 
 At this point, the ``subdt`` DataTree is still lazy -- no Data variables have been loaded
@@ -103,22 +103,22 @@ Per-partition chunking
 
 Different chunking may be desired, especially when applied
 to different channelisation and polarisation configurations.
-In these cases, the ``partition_chunks`` argument can be used
+In these cases, the ``preferred_chunks`` argument can be used
 to specify different chunking setups for each partition.
 
 .. ipython:: python
 
     dt = open_datatree(ms, partition_columns=[
      "DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"],
-     partition_chunks={
+     chunks={},
+     preferred_chunks={
        (("DATA_DESC_ID", 0),): {"time": 2, "frequency": 4},
        (("DATA_DESC_ID", 1),): {"time": 3, "frequency": 2}})
 
-See the ``partition_chunks`` argument of
-:meth:`xarray_ms.backend.msv2.entrypoint.MSv2EntryPoint.open_datatree`
+See the ``preferred_chunks`` argument of
+:meth:`~xarray_ms.backend.msv2.entrypoint.MSv2EntryPoint.open_datatree`
 for more information.
 
-
 .. ipython:: python
 
     dt
@@ -139,7 +139,8 @@ this to a zarr_ store.
 
     dt = open_datatree(ms, partition_columns=[
      "DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"],
-     partition_chunks={
+     chunks={},
+     preferred_chunks={
        (("DATA_DESC_ID", 0),): {"time": 2, "frequency": 4},
        (("DATA_DESC_ID", 1),): {"time": 3, "frequency": 2}})
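Taken together, the tutorial hunks above reduce to a call of the following shape. This is a minimal sketch rather than project documentation: the Measurement Set path is a placeholder, the partition keys assume two data descriptors exist, and ``chunks={}`` is what opts in to the backend's preferred chunks.

.. code-block:: python

    import xarray_ms  # noqa: F401 - registers the "xarray-ms:msv2" engine
    from xarray.backends.api import open_datatree

    # Hypothetical Measurement Set path and DATA_DESC_ID values
    dt = open_datatree(
        "/data/example.ms",
        partition_columns=["DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"],
        chunks={},  # empty dict: defer to the backend's preferred chunks
        preferred_chunks={
            (("DATA_DESC_ID", 0),): {"time": 2, "frequency": 4},
            (("DATA_DESC_ID", 1),): {"time": 3, "frequency": 2},
        },
    )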
diff --git a/pyproject.toml b/pyproject.toml
index c29047a..0b9e3a0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ readme = "README.rst"
 [tool.poetry.dependencies]
 python = "^3.10"
 pytest = {version = "^8.0.0", optional = true, extras = ["testing"]}
-xarray = "^2024.9.0, < 2024.10.0"
+xarray = "^2024.9.0"
 dask = {version = "^2024.5.0", optional = true, extras = ["testing"]}
 distributed = {version = "^2024.5.0", optional = true, extras = ["testing"]}
 cacheout = "^0.16.0"
diff --git a/tests/test_backend.py b/tests/test_backend.py
index 7fb73e0..544ff7b 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -156,7 +156,7 @@ def test_open_datatree(simmed_ms):
 
  # Works with default dask scheduler
  with ExitStack() as stack:
-    dt = open_datatree(simmed_ms, partition_chunks=chunks)
+    dt = open_datatree(simmed_ms, preferred_chunks=chunks)
    for ds in dt.values():
      del ds.attrs["creation_date"]
    xt.assert_identical(dt, mem_dt)
@@ -165,7 +165,7 @@ def test_open_datatree(simmed_ms):
  with ExitStack() as stack:
    cluster = stack.enter_context(LocalCluster(processes=True, n_workers=4))
    stack.enter_context(Client(cluster))
-    dt = open_datatree(simmed_ms, partition_chunks=chunks)
+    dt = open_datatree(simmed_ms, preferred_chunks=chunks)
    for ds in dt.values():
      del ds.attrs["creation_date"]
    xt.assert_identical(dt, mem_dt)
@@ -186,7 +186,8 @@ def test_open_datatree_chunking(simmed_ms):
  and partition-specific chunking"""
  dt = open_datatree(
    simmed_ms,
-    partition_chunks={"time": 3, "frequency": 2},
+    chunks={},
+    preferred_chunks={"time": 3, "frequency": 2},
  )
 
  for child in dt.children:
@@ -194,7 +195,7 @@
    if ds.attrs["data_description_id"] == 0:
      assert dict(ds.chunks) == {
        "time": (3, 2),
-        "baseline": (6,),
+        "baseline_id": (6,),
        "frequency": (2, 2, 2, 2),
        "polarization": (4,),
        "uvw_label": (3,),
@@ -202,7 +203,7 @@
    elif ds.attrs["data_description_id"] == 1:
      assert dict(ds.chunks) == {
        "time": (3, 2),
-        "baseline": (6,),
+        "baseline_id": (6,),
        "frequency": (2, 2),
        "polarization": (2,),
        "uvw_label": (3,),
@@ -210,8 +211,9 @@
 
  dt = open_datatree(
    simmed_ms,
-    partition_chunks={
-      "D=0": {"time": 2, "baseline": 2},
+    chunks={},
+    preferred_chunks={
+      "D=0": {"time": 2, "baseline_id": 2},
      "D=1": {"time": 3, "frequency": 2},
    },
  )
@@ -221,7 +223,7 @@
    if ds.attrs["data_description_id"] == 0:
      assert ds.chunks == {
        "time": (2, 2, 1),
-        "baseline": (2, 2, 2),
+        "baseline_id": (2, 2, 2),
        "frequency": (8,),
        "polarization": (4,),
        "uvw_label": (3,),
@@ -229,18 +231,18 @@
    elif ds.attrs["data_description_id"] == 1:
      assert ds.chunks == {
        "time": (3, 2),
-        "baseline": (6,),
+        "baseline_id": (6,),
        "frequency": (2, 2),
        "polarization": (2,),
        "uvw_label": (3,),
      }
 
-  with pytest.warns(UserWarning, match="`partition_chunks` overriding `chunks`"):
-    dt = open_datatree(
-      simmed_ms,
-      chunks={},
-      partition_chunks={
-        "D=0": {"time": 2, "baseline": 2},
-        "D=1": {"time": 3, "frequency": 2},
-      },
-    )
+  # with pytest.warns(UserWarning, match="`preferred_chunks` overriding `chunks`"):
+  #   dt = open_datatree(
+  #     simmed_ms,
+  #     chunks={},
+  #     preferred_chunks={
+  #       "D=0": {"time": 2, "baseline_id": 2},
+  #       "D=1": {"time": 3, "frequency": 2},
+  #     },
+  #   )
diff --git a/tests/test_read.py b/tests/test_read.py
index 488928d..037284d 100644
--- a/tests/test_read.py
+++ b/tests/test_read.py
@@ -54,7 +54,7 @@ def _excise_rows(data_dict):
 
 
 @pytest.mark.filterwarnings(
-  r"ignore:.*?rows missing from the full \(time, baseline\) grid"
+  r"ignore:.*?rows missing from the full \(time, baseline_id\) grid"
 )
 @pytest.mark.parametrize(
  "simmed_ms",
diff --git a/xarray_ms/backend/msv2/array.py b/xarray_ms/backend/msv2/array.py
index c24181a..1f8c4d0 100644
--- a/xarray_ms/backend/msv2/array.py
+++ b/xarray_ms/backend/msv2/array.py
@@ -57,7 +57,7 @@ def __init__(
    self.shape = shape
    self.dtype = np.dtype(dtype)
 
-    assert len(shape) >= 2, "(time, baseline) required"
+    assert len(shape) >= 2, "(time, baseline_id) required"
 
  def __getitem__(self, key) -> npt.NDArray:
    return explicit_indexing_adapter(
@@ -67,7 +67,7 @@ def __getitem__(self, key) -> npt.NDArray:
  def _getitem(self, key) -> npt.NDArray:
    assert len(key) == len(self.shape)
    expected_shape = tuple(slice_length(k, s) for k, s in zip(key, self.shape))
-    # Map the (time, baseline) coordinates onto row indices
+    # Map the (time, baseline_id) coordinates onto row indices
    rows = self._structure_factory()[self._partition].row_map[key[:2]]
    xkey = (rows.ravel(),) + key[2:]
    row_shape = (rows.size,) + expected_shape[2:]
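The ``row_map`` lookup above converts a dense ``(time, baseline_id)`` grid selection into MAIN table row numbers before any column data is read. A toy illustration of the idea, assuming a plain NumPy map in which ``-1`` marks grid cells with no backing row (names and values here are hypothetical, not the actual xarray-ms internals):

.. code-block:: python

    import numpy as np

    # Hypothetical map: row_map[t, b] is the MAIN table row for time index t
    # and baseline_id b, or -1 where that (time, baseline_id) cell has no row
    row_map = np.array([[0, 1, 2],
                        [3, -1, 4]])

    key = (slice(0, 2), slice(0, 2))  # leading (time, baseline_id) selection
    rows = row_map[key]               # (2, 2) block of row numbers
    flat_rows = rows.ravel()          # rows to read from the column;
                                      # cells with row -1 are padded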
diff --git a/xarray_ms/backend/msv2/entrypoint.py b/xarray_ms/backend/msv2/entrypoint.py
index f1a835d..5356fdc 100644
--- a/xarray_ms/backend/msv2/entrypoint.py
+++ b/xarray_ms/backend/msv2/entrypoint.py
@@ -105,6 +105,7 @@ class MSv2Store(AbstractWritableDataStore):
    "_structure_factory",
    "_partition_columns",
    "_partition_key",
+    "_preferred_chunks",
    "_auto_corrs",
    "_ninstances",
    "_epoch",
@@ -113,6 +114,7 @@ class MSv2Store(AbstractWritableDataStore):
  _table_factory: TableFactory
  _structure_factory: MSv2StructureFactory
  _partition_columns: List[str]
+  _preferred_chunks: Dict[str, int]
  _partition: PartitionKeyT
  _autocorrs: bool
  _ninstances: int
@@ -124,6 +126,7 @@ def __init__(
    structure_factory: MSv2StructureFactory,
    partition_columns: List[str],
    partition_key: PartitionKeyT,
+    preferred_chunks: Dict[str, int],
    auto_corrs: bool,
    ninstances: int,
    epoch: str,
@@ -132,6 +135,7 @@ def __init__(
    self._structure_factory = structure_factory
    self._partition_columns = partition_columns
    self._partition_key = partition_key
+    self._preferred_chunks = preferred_chunks
    self._auto_corrs = auto_corrs
    self._ninstances = ninstances
    self._epoch = epoch
@@ -143,6 +147,7 @@ def open(
    drop_variables=None,
    partition_columns: List[str] | None = None,
    partition_key: PartitionKeyT | None = None,
+    preferred_chunks: Dict[str, int] | None = None,
    auto_corrs: bool = True,
    ninstances: int = 1,
    epoch: str | None = None,
@@ -177,11 +182,15 @@ def open(
      )
      partition_key = first_key
 
+    if preferred_chunks is None:
+      preferred_chunks = {}
+
    return cls(
      table_factory,
      structure_factory,
      partition_columns=partition_columns,
      partition_key=partition_key,
+      preferred_chunks=preferred_chunks,
      auto_corrs=auto_corrs,
      ninstances=ninstances,
      epoch=epoch,
@@ -191,9 +200,14 @@ def close(self, **kwargs):
    pass
 
  def get_variables(self):
-    return MainDatasetFactory(
-      self._partition_key, self._table_factory, self._structure_factory
-    ).get_variables()
+    factory = MainDatasetFactory(
+      self._partition_key,
+      self._preferred_chunks,
+      self._table_factory,
+      self._structure_factory,
+    )
+
+    return factory.get_variables()
 
  def get_attrs(self):
    try:
@@ -219,6 +233,7 @@ class MSv2EntryPoint(BackendEntrypoint):
    "filename_or_obj",
    "partition_columns",
    "partition_key",
+    "preferred_chunks",
    "auto_corrs",
    "ninstances",
    "epoch",
@@ -253,6 +268,7 @@ def open_dataset(
    drop_variables: str | Iterable[str] | None = None,
    partition_columns: List[str] | None = None,
    partition_key: PartitionKeyT | None = None,
+    preferred_chunks: Dict[str, int] | None = None,
    auto_corrs: bool = True,
    ninstances: int = 8,
    epoch: str | None = None,
@@ -269,6 +285,7 @@ def open_dataset(
      partition_key: A key corresponding to an individual partition.
        For example :code:`(('DATA_DESC_ID', 0), ('FIELD_ID', 0))`.
        If :code:`None`, the first partition will be opened.
+      preferred_chunks: The preferred chunks for this partition.
      auto_corrs: Include/Exclude auto-correlations.
      ninstances: The number of Measurement Set instances to open for parallel I/O.
      epoch: A unique string identifying the creation of this Dataset.
@@ -281,11 +298,13 @@ def open_dataset(
      partition specified by :code:`partition_columns` and :code:`partition_key`.
    """
    filename_or_obj = _normalize_path(filename_or_obj)
+
    store = MSv2Store.open(
      filename_or_obj,
      drop_variables=drop_variables,
      partition_columns=partition_columns,
      partition_key=partition_key,
+      preferred_chunks=preferred_chunks,
      auto_corrs=auto_corrs,
      ninstances=ninstances,
      epoch=epoch,
@@ -299,7 +318,7 @@ def open_datatree(
    self,
    filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
    *,
-    partition_chunks: Dict[str, Any] | None = None,
+    preferred_chunks: Dict[str, Any] | None = None,
    drop_variables: str | Iterable[str] | None = None,
    partition_columns: List[str] | None = None,
    auto_corrs: bool = True,
@@ -312,7 +331,7 @@ def open_datatree(
 
    Args:
      filename_or_obj: The path to the MSv2 CASA Measurement Set file.
-      partition_chunks: Chunk sizes along each dimension,
+      preferred_chunks: Chunk sizes along each dimension,
        e.g. :code:`{{"time": 10, "frequency": 16}}`.
        Individual partitions can be chunked differently by
        partially (or fully) specifying a partition key: e.g.
@@ -332,10 +351,12 @@ def open_datatree(
            "D=0,F=1": {{"time": 20, "frequency": 32}},
          }}
 
-        .. note:: This argument overrides the reserved ``chunks`` argument
-          used by xarray to control chunking in Datasets and DataTrees.
-          It should be used instead of ``chunks`` when different
-          chunking is desired for different partitions.
+        .. note:: xarray's reserved ``chunks`` argument must be specified
+          to enable this functionality and fine-grained chunking
+          in Datasets and DataTrees.
+          See xarray's backend documentation on
+          `Preferred chunk sizes <preferred_chunk_sizes_>`_
+          for more information.
 
      drop_variables: Variables to drop from the dataset.
      partition_columns: The columns to use for partitioning the Measurement Set.
@@ -347,7 +368,39 @@ def open_datatree(
 
    Returns:
      An xarray :class:`~xarray.core.datatree.DataTree`
+
+    .. _preferred_chunk_sizes: https://docs.xarray.dev/en/stable/internals/how-to-add-new-backend.html#preferred-chunk-sizes
    """
+    groups_dict = self.open_groups_as_dict(
+      filename_or_obj,
+      drop_variables=drop_variables,
+      partition_columns=partition_columns,
+      preferred_chunks=preferred_chunks,
+      auto_corrs=auto_corrs,
+      ninstances=ninstances,
+      epoch=epoch,
+      **kwargs,
+    )
+
+    return DataTree.from_dict(groups_dict)
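Put differently, ``preferred_chunks`` is now advisory and xarray's reserved ``chunks`` argument decides whether it takes effect. A sketch of the resulting modes, assuming a hypothetical ``/data/example.ms``; the precedence rules are xarray's documented preferred-chunks semantics rather than anything specific to this backend:

.. code-block:: python

    import xarray_ms  # noqa: F401
    from xarray.backends.api import open_datatree

    ms = "/data/example.ms"  # hypothetical path

    # chunks omitted: lazily indexed variables, no dask;
    # preferred_chunks is recorded in encoding but not applied
    dt = open_datatree(ms, preferred_chunks={"time": 10, "frequency": 16})

    # chunks={}: dask arrays, chunked along the backend's preferred chunks
    dt = open_datatree(ms, chunks={}, preferred_chunks={"time": 10, "frequency": 16})

    # Explicit entries in chunks take precedence over the preferences
    dt = open_datatree(ms, chunks={"time": 100}, preferred_chunks={"frequency": 16})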
+
+  @format_docstring(DEFAULT_PARTITION_COLUMNS=DEFAULT_PARTITION_COLUMNS)
+  def open_groups_as_dict(
+    self,
+    filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+    *,
+    drop_variables: str | Iterable[str] | None = None,
+    partition_columns: List[str] | None = None,
+    preferred_chunks: Dict[str, int] | None = None,
+    auto_corrs: bool = True,
+    ninstances: int = 8,
+    epoch: str | None = None,
+    structure_factory: MSv2StructureFactory | None = None,
+    **kwargs,
+  ) -> Dict[str, Dataset]:
+    """Create a dictionary of :class:`~xarray.Dataset` objects presenting
+    MSv4 views over the partitions of an MSv2 CASA Measurement Set"""
    if isinstance(filename_or_obj, os.PathLike):
      ms = str(filename_or_obj)
    elif isinstance(filename_or_obj, str):
@@ -361,14 +414,7 @@ def open_datatree(
    structure = structure_factory()
    datasets = {}
-
-    if not partition_chunks:
-      partition_chunks = kwargs.pop("chunks", None)
-    elif "chunks" in kwargs:
-      kwargs.pop("chunks", None)
-      warnings.warn("`partition_chunks` overriding `chunks`")
-
-    pchunks = promote_chunks(structure, partition_chunks)
+    pchunks = promote_chunks(structure, preferred_chunks)
 
    for partition_key in structure:
      ds = xarray.open_dataset(
@@ -377,18 +423,20 @@ def open_datatree(
        engine="xarray-ms:msv2",
        partition_columns=partition_columns,
        partition_key=partition_key,
+        preferred_chunks=pchunks[partition_key]
+        if isinstance(pchunks, Mapping)
+        else pchunks,
        auto_corrs=auto_corrs,
        ninstances=ninstances,
        epoch=epoch,
        structure_factory=structure_factory,
-        chunks=pchunks[partition_key] if isinstance(pchunks, Mapping) else pchunks,
        **kwargs,
      )
 
      antenna_factory = AntennaDatasetFactory(structure_factory)
 
-      key = ",".join(f"{k}={v}" for k, v in sorted(partition_key))
-      datasets[key] = ds
-      datasets[f"{key}/ANTENNA"] = antenna_factory.get_dataset()
+      path = ",".join(f"{k}={v}" for k, v in sorted(partition_key))
+      datasets[path] = ds
+      datasets[f"{path}/ANTENNA"] = antenna_factory.get_dataset()
 
-    return DataTree.from_dict(datasets)
+    return datasets
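After this refactor, ``open_datatree`` is a thin wrapper: ``open_groups_as_dict`` returns a flat mapping from group paths to Datasets, and ``DataTree.from_dict`` assembles the tree. A sketch of that mapping's shape, using empty stand-in Datasets and an illustrative partition key (the ``DataTree`` import location varies across xarray releases):

.. code-block:: python

    import xarray
    from xarray.core.datatree import DataTree  # xarray.DataTree in newer releases

    # Empty Datasets standing in for the per-partition datasets built above
    groups = {
        "DATA_DESC_ID=0,FIELD_ID=0,OBSERVATION_ID=0": xarray.Dataset(),
        "DATA_DESC_ID=0,FIELD_ID=0,OBSERVATION_ID=0/ANTENNA": xarray.Dataset(),
    }

    dt = DataTree.from_dict(groups)  # one partition group with an ANTENNA child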
diff --git a/xarray_ms/backend/msv2/main_dataset_factory.py b/xarray_ms/backend/msv2/main_dataset_factory.py
index 82cfdef..89f02b6 100644
--- a/xarray_ms/backend/msv2/main_dataset_factory.py
+++ b/xarray_ms/backend/msv2/main_dataset_factory.py
@@ -1,6 +1,6 @@
 import dataclasses
 import warnings
-from typing import Any, Mapping, Tuple, Type
+from typing import Any, Dict, Mapping, Tuple, Type
 
 import numpy as np
 from xarray import Variable
@@ -47,16 +47,19 @@ class MSv2ColumnSchema:
 class MainDatasetFactory:
  _partition_key: PartitionKeyT
+  _preferred_chunks: Dict[str, int]
  _table_factory: TableFactory
  _structure_factory: MSv2StructureFactory
 
  def __init__(
    self,
    partition_key: PartitionKeyT,
+    preferred_chunks: Dict[str, int],
    table_factory: TableFactory,
    structure_factory: MSv2StructureFactory,
  ):
    self._partition_key = partition_key
+    self._preferred_chunks = preferred_chunks
    self._table_factory = table_factory
    self._structure_factory = structure_factory
 
@@ -78,13 +81,13 @@ def _variable_from_column(self, column: str) -> Variable:
 
    dim_sizes = {
      "time": len(partition.time),
-      "baseline": structure.nbl,
+      "baseline_id": structure.nbl,
      "frequency": len(partition.chan_freq),
      "polarization": len(partition.corr_type),
      **FIXED_DIMENSION_SIZES,
    }
 
-    dims = ("time", "baseline") + schema.dims
+    dims = ("time", "baseline_id") + schema.dims
 
    try:
      shape = tuple(dim_sizes[d] for d in dims)
@@ -103,7 +106,7 @@ def _variable_from_column(self, column: str) -> Variable:
      default,
    )
 
-    var = Variable(dims, data)
+    var = Variable(dims, data, fastpath=True)
 
    # Apply any measures encoding
    if schema.coder:
@@ -111,6 +114,10 @@ def _variable_from_column(self, column: str) -> Variable:
      var = coder.decode(var)
 
    dims, data, attrs, encoding = unpack_for_decoding(var)
+
+    if self._preferred_chunks:
+      encoding["preferred_chunks"] = self._preferred_chunks
+
    return Variable(dims, LazilyIndexedArray(data), attrs, encoding, fastpath=True)
 
  def get_variables(self) -> Mapping[str, Variable]:
@@ -129,7 +136,7 @@ def get_variables(self) -> Mapping[str, Variable]:
    if missing > 0:
      warnings.warn(
        f"{missing} / {row_map.size} ({100. * missing / row_map.size:.1f}%) "
-        f"rows missing from the full (time, baseline) grid "
+        f"rows missing from the full (time, baseline_id) grid "
        f"in partition {self._partition_key}. "
        f"Dataset variables will be padded",
        IrregularGridWarning,
@@ -151,20 +158,21 @@ def get_variables(self) -> Mapping[str, Variable]:
 
    coordinates = [
      (
        "baseline_id",
-        (("baseline",), np.arange(len(ant1)), {"coordinates": "baseline_id"}),
+        (("baseline_id",), np.arange(len(ant1)), {"coordinates": "baseline_id"}),
      ),
      (
        "baseline_antenna1_name",
-        (("baseline",), ant1_names, {"coordinates": "baseline_antenna1_name"}),
+        (("baseline_id",), ant1_names, {"coordinates": "baseline_antenna1_name"}),
      ),
      (
        "baseline_antenna2_name",
-        (("baseline",), ant2_names, {"coordinates": "baseline_antenna2_name"}),
+        (("baseline_id",), ant2_names, {"coordinates": "baseline_antenna2_name"}),
      ),
      ("polarization", (("polarization",), partition.corr_type, None)),
    ]
 
-    coordinates = [(n, Variable(d, v, a)) for n, (d, v, a) in coordinates]
+    e = {"preferred_chunks": self._preferred_chunks} if self._preferred_chunks else None
+    coordinates = [(n, Variable(d, v, a, e)) for n, (d, v, a) in coordinates]
 
    # Add time coordinate
    time_coder = TimeCoder("TIME", structure.column_descs["MAIN"])
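The ``encoding["preferred_chunks"]`` assignments above use xarray's standard backend mechanism: a variable advertises its preferred chunk sizes in its encoding, and xarray consults them when a caller requests ``chunks={}`` or ``chunks="auto"``. A self-contained sketch of the mechanism on a toy Variable (not the xarray-ms factory itself):

.. code-block:: python

    import numpy as np
    from xarray import Variable

    # Toy (time, baseline_id) variable standing in for a lazily read MSv2 column
    var = Variable(("time", "baseline_id"), np.zeros((10, 6)))
    var.encoding["preferred_chunks"] = {"time": 5, "baseline_id": 3}

    # A store serving this variable yields, under open_dataset(..., chunks={}),
    # a dask array chunked as ((5, 5), (3, 3))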