Skip to content

Commit

Permalink
Make preferred_chunks optional
Browse files: browse the repository at this point in the history
  • Loading branch information
sjperkins committed Nov 5, 2024
1 parent 001bd23 commit 5ca92dd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 7 deletions.
40 changes: 35 additions & 5 deletions xarray_ms/backend/msv2/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def open(
partition_key = first_key

if preferred_chunks is None:
preferred_chunks = preferred_chunks or {}
preferred_chunks = {}

return cls(
table_factory,
Expand Down Expand Up @@ -368,6 +368,36 @@ def open_datatree(
.. _preferred_chunk_sizes: https://docs.xarray.dev/en/stable/internals/how-to-add-new-backend.html#preferred-chunk-sizes
"""
groups_dict = self.open_groups_as_dict(
filename_or_obj,
drop_variables=drop_variables,
partition_columns=partition_columns,
preferred_chunks=preferred_chunks,
auto_corrs=auto_corrs,
ninstances=ninstances,
epoch=epoch,
**kwargs,
)

return DataTree.from_dict(groups_dict)

@format_docstring(DEFAULT_PARTITION_COLUMNS=DEFAULT_PARTITION_COLUMNS)
def open_groups_as_dict(
self,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
drop_variables: str | Iterable[str] | None = None,
partition_columns: List[str] | None = None,
preferred_chunks: Dict[str, int] | None = None,
auto_corrs: bool = True,
ninstances: int = 8,
epoch: str | None = None,
structure_factory: MSv2StructureFactory | None = None,
**kwargs,
) -> Dict[str, Dataset]:
"""Create a dictionary of :class:`~xarray.Dataset` presenting an MSv4 view
over a partition of a MSv2 CASA Measurement Set"""

if isinstance(filename_or_obj, os.PathLike):
ms = str(filename_or_obj)
elif isinstance(filename_or_obj, str):
Expand Down Expand Up @@ -402,8 +432,8 @@ def open_datatree(

antenna_factory = AntennaDatasetFactory(structure_factory)

key = ",".join(f"{k}={v}" for k, v in sorted(partition_key))
datasets[key] = ds
datasets[f"{key}/ANTENNA"] = antenna_factory.get_dataset()
path = ",".join(f"{k}={v}" for k, v in sorted(partition_key))
datasets[path] = ds
datasets[f"{path}/ANTENNA"] = antenna_factory.get_dataset()

return DataTree.from_dict(datasets)
return datasets
5 changes: 3 additions & 2 deletions xarray_ms/backend/msv2/main_dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def _variable_from_column(self, column: str) -> Variable:

dims, data, attrs, encoding = unpack_for_decoding(var)

encoding["preferred_chunks"] = self._preferred_chunks
if self._preferred_chunks:
encoding["preferred_chunks"] = self._preferred_chunks

return Variable(dims, LazilyIndexedArray(data), attrs, encoding, fastpath=True)

Expand Down Expand Up @@ -170,7 +171,7 @@ def get_variables(self) -> Mapping[str, Variable]:
("polarization", (("polarization",), partition.corr_type, None)),
]

e = {"preferred_chunks": self._preferred_chunks}
e = {"preferred_chunks": self._preferred_chunks} if self._preferred_chunks else None
coordinates = [(n, Variable(d, v, a, e)) for n, (d, v, a) in coordinates]

# Add time coordinate
Expand Down

0 comments on commit 5ca92dd

Please sign in to comment.