Skip to content

Commit

Permalink
Test irregular grid correctness
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Oct 16, 2024
1 parent 3a8637f commit ef3ca24
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 4 deletions.
72 changes: 71 additions & 1 deletion tests/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
],
indirect=True,
)
def test_basic_read(simmed_ms):
def test_regular_read(simmed_ms):
"""Test for ramp function values produced by simulator"""
xdt = open_datatree(simmed_ms)

Expand All @@ -33,3 +33,73 @@ def test_basic_read(simmed_ms):
nelements = reduce(mul, uvw.shape, 1)
expected = np.arange(nelements, dtype=np.float64).reshape(uvw.shape)
assert_array_equal(uvw, expected)


ANT1_SUBSET = [0, 0, 1]
ANT2_SUBSET = [0, 1, 2]


def _select_rows(antenna1, antenna2, ant1_subset, ant2_subset):
dtype = [("a1", antenna1.dtype), ("a2", antenna2.dtype)]
baselines = np.rec.fromarrays([antenna1, antenna2], dtype=dtype)
desired = np.rec.fromarrays([ant1_subset, ant2_subset], dtype=dtype)
return np.isin(baselines, desired)


def _excise_rows(data_dict):
_, ant1 = data_dict["ANTENNA1"]
_, ant2 = data_dict["ANTENNA2"]
index = _select_rows(ant1, ant2, ANT1_SUBSET, ANT2_SUBSET)
return {k: (d, v[index]) for k, (d, v) in data_dict.items()}


@pytest.mark.parametrize(
"simmed_ms",
[
{
"name": "backend.ms",
"nantenna": 3,
"data_description": [(8, ["XX", "XY", "YX", "YY"]), (4, ["RR", "LL"])],
"transform_data": _excise_rows,
}
],
indirect=True,
)
def test_irregular_read(simmed_ms):
xdt = open_datatree(simmed_ms)

for node in xdt.subtree:
if "data_description_id" in node.attrs:
bl_index = _select_rows(
node.baseline_antenna1_name.values,
node.baseline_antenna2_name.values,
[f"ANTENNA-{i}" for i in ANT1_SUBSET],
[f"ANTENNA-{i}" for i in ANT2_SUBSET],
)

vis = node.VISIBILITY.values
# Selected baseline elements are as expected
nelements = reduce(mul, vis.shape, 1)
expected = np.arange(nelements, dtype=np.float32)
expected = (expected + expected * 1j).reshape(vis.shape)
assert_array_equal(vis[:, bl_index], expected[:, bl_index])
# Other baseline elements are nan
vis = node.VISIBILITY.values
assert np.all(np.isnan((vis[:, ~bl_index])))

uvw = node.UVW.values
# Selected baseline elements are as expected
nelements = reduce(mul, uvw.shape, 1)
expected = np.arange(nelements, dtype=np.float64).reshape(uvw.shape)
assert_array_equal(uvw[:, bl_index], expected[:, bl_index])
# Other baseline elements are nan
assert np.all(np.isnan((uvw[:, ~bl_index, ...])))

flag = node.FLAG.values
# Selected baseline elements are as expected
nelements = reduce(mul, flag.shape, 1)
expected = np.where(np.arange(nelements) & 0x1, 0, 1)
expected = expected.reshape(flag.shape)
assert_array_equal(flag[:, bl_index], expected[:, bl_index])
# Other baseline elements are flagged
assert np.all(flag[:, ~bl_index, ...] == 1)
10 changes: 7 additions & 3 deletions xarray_ms/testing/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import tempfile
import typing
from collections.abc import Callable
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -88,6 +89,7 @@ class PartitionDescriptor:


DDIDArgType = List[Tuple[npt.NDArray[np.float64], List[str]]]
PartitionDataType = Dict[str, Tuple[Tuple[str, ...], npt.NDArray]]


class MSStructureSimulator:
Expand Down Expand Up @@ -124,6 +126,7 @@ def __init__(
partition: Tuple[str, ...] = ("PROCESSOR_ID", "FIELD_ID", "DATA_DESC_ID"),
auto_corrs: bool = True,
simulate_data: bool = True,
transform_data: Callable[[PartitionDataType], PartitionDataType] | None = None,
):
assert ntime >= 1
assert time_chunks > 0
Expand Down Expand Up @@ -178,6 +181,7 @@ def __init__(
self.simulate_data = simulate_data
self.partition_names = cbp_names
self.partition_indices = bcbp_indices
self.transform_data = transform_data
self.model = {
"data_description": self.data_description,
"feed_map": self.feeds,
Expand All @@ -199,6 +203,8 @@ def simulate_ms(self, output_ms: str) -> None:

for chunk_desc in self.generate_descriptors():
data_dict = self.data_factory(chunk_desc)
if self.transform_data is not None:
data_dict = self.transform_data(data_dict)
(nrow,) = data_dict["TIME"][1].shape
T.addrows(nrow)

Expand Down Expand Up @@ -311,9 +317,7 @@ def broadcast_partition_indices(
return np.stack([a.ravel() for a in np.broadcast_arrays(*bparts)], axis=1)

@staticmethod
def data_factory(
desc: PartitionDescriptor,
) -> Dict[str, Tuple[Tuple[str, ...], npt.NDArray]]:
def data_factory(desc: PartitionDescriptor) -> PartitionDataType:
"""Creates simulated MS data from a partition descriptor"""
try:
ddid = desc.DATA_DESC_ID.item()
Expand Down

0 comments on commit ef3ca24

Please sign in to comment.