Make partitioning columns configurable #3

Merged: 2 commits, Sep 10, 2024
8 changes: 6 additions & 2 deletions tests/test_structure.py
@@ -19,9 +19,13 @@ def test_baseline_id(na, auto_corrs):

 @pytest.mark.parametrize("simmed_ms", [{"name": "proxy.ms"}], indirect=True)
 def test_structure_factory(simmed_ms):
+  partition_columns = ["FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID"]
   table_factory = TableFactory(Table.from_filename, simmed_ms)
-  structure_factory = MSv2StructureFactory(table_factory)
+  structure_factory = MSv2StructureFactory(table_factory, partition_columns)
   assert pickle.loads(pickle.dumps(structure_factory)) == structure_factory

-  structure_factory2 = MSv2StructureFactory(table_factory)
+  structure_factory2 = MSv2StructureFactory(table_factory, partition_columns)
   assert structure_factory() is structure_factory2()
+
+  keys = tuple(k for kv in structure_factory().keys() for k, _ in kv)
+  assert tuple(sorted(partition_columns)) == keys
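
The test above pins down the shape of the partition keys produced by MSv2StructureFactory: each key is a sequence of (column, value) pairs whose column names are the requested partition columns in sorted order. Below is a rough illustrative sketch of that shape, mirroring the lookup pattern used by MSv2Store.get_attrs in the entrypoint diff that follows; the literal values are invented and the concrete PartitionKeyT alias lives in xarray_ms.backend.msv2.structure.

# Illustrative only: the values below are made up.
partition_key = (
  ("DATA_DESC_ID", 0),
  ("FIELD_ID", 0),
  ("OBSERVATION_ID", 0),
)  # column names appear in sorted order, as the keys assertion expects

# Extracting a column value mirrors MSv2Store.get_attrs in the diff below.
ddid = next(v for k, v in partition_key if k == "DATA_DESC_ID")
assert ddid == 0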
127 changes: 83 additions & 44 deletions xarray_ms/backend/msv2/entrypoint.py
@@ -3,7 +3,7 @@
 import os
 import warnings
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Any, Dict, Iterable
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple
 from uuid import uuid4

 import xarray
@@ -33,19 +33,11 @@
 from xarray_ms.backend.msv2.structure import PartitionKeyT


-def table_factory_factory(ms: str, ninstances: int) -> TableFactory:
-  """
-  Ensures consistency when creating a TableFactory.
-  Multiple calls to this method with the same argument will
-  resolve to the same cached instance.
-  """
-  return TableFactory(
-    Table.from_filename,
-    ms,
-    ninstances=ninstances,
-    readonly=True,
-    lockoptions="nolock",
-  )
+DEFAULT_PARTITION_COLUMNS: List[str] = [
+  "DATA_DESC_ID",
+  "FIELD_ID",
+  "OBSERVATION_ID",
+]


 def promote_chunks(
@@ -78,20 +70,62 @@ def promote_chunks(
   return return_chunks


+def initialise_default_args(
+  ms: str,
+  ninstances: int,
+  auto_corrs: bool,
+  epoch: str | None,
+  table_factory: TableFactory | None,
+  partition_columns: List[str] | None,
+  partition_key: PartitionKeyT | None,
+  structure_factory: MSv2StructureFactory | None,
+) -> Tuple[str, TableFactory, List[str], PartitionKeyT, MSv2StructureFactory]:
+  """
+  Ensures consistency when initialising default arguments from multiple locations
+  """
+  if not os.path.exists(ms):
+    raise ValueError(f"MS {ms} does not exist")
+
+  table_factory = table_factory or TableFactory(
+    Table.from_filename,
+    ms,
+    ninstances=ninstances,
+    readonly=True,
+    lockoptions="nolock",
+  )
+  epoch = epoch or uuid4().hex[:8]
+  partition_columns = partition_columns or DEFAULT_PARTITION_COLUMNS
+  structure_factory = structure_factory or MSv2StructureFactory(
+    table_factory, partition_columns, auto_corrs=auto_corrs
+  )
+  structure = structure_factory()
+  if partition_key is None:
+    partition_key = next(iter(structure.keys()))
+    warnings.warn(
+      f"No partition_key was supplied. Selected first partition {partition_key}"
+    )
+  elif partition_key not in structure:
+    raise ValueError(f"{partition_key} not in {list(structure.keys())}")
+
+  return epoch, table_factory, partition_columns, partition_key, structure_factory
+
+
 class MSv2Store(AbstractWritableDataStore):
   """Store for reading and writing MSv2 data"""

   __slots__ = (
     "_table_factory",
     "_structure_factory",
-    "_partition",
+    "_partition_columns",
+    "_partition_key",
     "_auto_corrs",
     "_ninstances",
     "_epoch",
   )

   _table_factory: TableFactory
   _structure_factory: MSv2StructureFactory
+  _partition_columns: List[str]
   _partition: PartitionKeyT
   _autocorrs: bool
   _ninstances: int
@@ -101,14 +135,16 @@ def __init__(
     self,
     table_factory: TableFactory,
     structure_factory: MSv2StructureFactory,
-    partition: PartitionKeyT,
+    partition_columns: List[str],
+    partition_key: PartitionKeyT,
     auto_corrs: bool,
     ninstances: int,
     epoch: str,
   ):
     self._table_factory = table_factory
     self._structure_factory = structure_factory
-    self._partition = partition
+    self._partition_columns = partition_columns
+    self._partition_key = partition_key
     self._auto_corrs = auto_corrs
     self._ninstances = ninstances
     self._epoch = epoch
@@ -118,7 +154,8 @@ def open(
     cls,
     ms: str,
     drop_variables=None,
-    partition: PartitionKeyT | None = None,
+    partition_columns: List[str] | None = None,
+    partition_key: PartitionKeyT | None = None,
     auto_corrs: bool = True,
     ninstances: int = 1,
     epoch: str | None = None,
@@ -127,23 +164,24 @@
     if not isinstance(ms, str):
       raise ValueError("Measurement Sets paths must be strings")

-    table_factory = table_factory_factory(ms, ninstances)
-    epoch = epoch or uuid4().hex[:8]
-    structure_factory = structure_factory or MSv2StructureFactory(
-      table_factory, auto_corrs
+    epoch, table_factory, partition_columns, partition_key, structure_factory = (
+      initialise_default_args(
+        ms,
+        ninstances,
+        auto_corrs,
+        epoch,
+        None,
+        partition_columns,
+        partition_key,
+        structure_factory,
+      )
     )
-    structure = structure_factory()
-
-    if partition is None:
-      partition = next(iter(structure.keys()))
-      warnings.warn(f"No partition was supplied. Selected first partition {partition}")
-    elif partition not in structure:
-      raise ValueError(f"{partition} not in {list(structure.keys())}")

     return cls(
       table_factory,
       structure_factory,
-      partition=partition,
+      partition_columns=partition_columns,
+      partition_key=partition_key,
       auto_corrs=auto_corrs,
       ninstances=ninstances,
       epoch=epoch,
@@ -154,12 +192,12 @@ def close(self, **kwargs):

   def get_variables(self):
     return MainDatasetFactory(
-      self._partition, self._table_factory, self._structure_factory
+      self._partition_key, self._table_factory, self._structure_factory
     ).get_variables()

   def get_attrs(self):
     try:
-      ddid = next(iter(v for k, v in self._partition if k == "DATA_DESC_ID"))
+      ddid = next(iter(v for k, v in self._partition_key if k == "DATA_DESC_ID"))
     except StopIteration:
       raise KeyError("DATA_DESC_ID not found in partition")

@@ -183,7 +221,7 @@ def get_encoding(self):
 class MSv2PartitionEntryPoint(BackendEntrypoint):
   open_dataset_parameters = [
     "filename_or_obj",
-    "partition",
+    "partition_columns", "partition_key",
     "auto_corrs",
     "ninstances",
     "epoch",
@@ -212,14 +250,11 @@ def guess_can_open(

   def open_dataset(
     self,
-    filename_or_obj: str
-    | os.PathLike[Any]
-    | BufferedIOBase
-    | AbstractDataStore
-    | TableFactory,
+    filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
     *,
     drop_variables: str | Iterable[str] | None = None,
-    partition=None,
+    partition_columns=None,
+    partition_key=None,
     auto_corrs=True,
     ninstances=8,
     epoch=None,
@@ -229,7 +264,8 @@
     store = MSv2Store.open(
       filename_or_obj,
       drop_variables=drop_variables,
-      partition=partition,
+      partition_columns=partition_columns,
+      partition_key=partition_key,
       auto_corrs=auto_corrs,
       ninstances=ninstances,
       epoch=epoch,
@@ -243,6 +279,7 @@ def open_datatree(
     filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
     *,
     drop_variables: str | Iterable[str] | None = None,
+    partition_columns=None,
     auto_corrs=True,
     ninstances=8,
     epoch=None,
@@ -255,10 +292,11 @@
     else:
       raise ValueError("Measurement Set paths must be strings")

-    table_factory = table_factory_factory(ms, ninstances)
-    structure_factory = MSv2StructureFactory(table_factory, auto_corrs=auto_corrs)
-    structure = structure_factory()
+    epoch, _, partition_columns, _, structure_factory = initialise_default_args(
+      ms, ninstances, auto_corrs, epoch, None, partition_columns, None, None
+    )

+    structure = structure_factory()
     datasets = {}
     chunks = kwargs.pop("chunks", None)
     pchunks = promote_chunks(structure, chunks)
@@ -267,7 +305,8 @@
       ds = xarray.open_dataset(
         ms,
         drop_variables=drop_variables,
-        partition=partition_key,
+        partition_columns=partition_columns,
+        partition_key=partition_key,
         auto_corrs=auto_corrs,
         ninstances=ninstances,
         epoch=epoch,
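
For orientation, a minimal usage sketch of how the new keywords surface through the xarray backend entrypoint above. This is not part of the diff: the MS path and partition values are invented, guess_can_open suggests .ms paths may be detected automatically (otherwise an explicit engine= argument would be needed, and the engine name is not shown here), and xarray.open_datatree requires an xarray version that provides it.

import xarray

# Open a single partition. If partition_key is omitted, initialise_default_args
# selects the first partition and emits a warning.
ds = xarray.open_dataset(
  "observation.ms",  # hypothetical path
  partition_columns=["FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID"],
  partition_key=(("DATA_DESC_ID", 0), ("FIELD_ID", 0), ("OBSERVATION_ID", 0)),
)

# Open every partition at once; open_datatree accepts partition_columns only
# and produces one dataset per partition key.
dt = xarray.open_datatree(
  "observation.ms",
  partition_columns=["FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID"],
)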