Skip to content

Commit

Permalink
Add protections against duplicate frames (#545)
Browse files Browse the repository at this point in the history
Co-authored-by: Nadia Dencheva <[email protected]>
  • Loading branch information
WilliamJamieson and nden authored Jan 31, 2025
1 parent 1c5acaf commit a3d1cc0
Show file tree
Hide file tree
Showing 15 changed files with 1,374 additions and 1,025 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

- Implement code linting and automatic formatting. [#544]

- Refactor ``WCS`` to use a ``Pipeline`` base class which adds basic checks to ensure that the pipeline is valid. These
include checking for duplicate frame names and that the last transform is ``None``. [#545]


0.22.0 (2024-12-19)
-------------------
Expand Down
10 changes: 2 additions & 8 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import importlib.metadata

try:
from sphinx_astropy.conf.v1 import * # noqa: F403
from sphinx_astropy.conf.v2 import * # noqa: F403
except ImportError:
print( # noqa: T201
"ERROR: the documentation requires the sphinx-astropy package to be installed"
Expand Down Expand Up @@ -108,13 +108,6 @@
# name of a builtin theme or the name of a custom theme in html_theme_path.
# html_theme = None

# See sphinx-bootstrap-theme for documentation of these options
# https://github.com/ryan-roemer/sphinx-bootstrap-theme
html_theme_options = {
"logotext1": "g", # white, semi-bold
"logotext2": "wcs", # orange, light
"logotext3": ":docs", # white, light
}

# Custom sidebar templates, maps document names to template names.
# html_sidebars = {}
Expand Down Expand Up @@ -156,6 +149,7 @@
nitpicky = True
nitpick_ignore = [
("py:class", "gwcs.api.GWCSAPIMixin"),
("py:class", "gwcs.wcs._pipeline.Pipeline"),
("py:obj", "astropy.modeling.projections.projcodes"),
("py:attr", "gwcs.WCS.bounding_box"),
("py:meth", "gwcs.WCS.footprint"),
Expand Down
10 changes: 10 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,18 @@ Reference/API
-------------

.. automodapi:: gwcs.wcs
:inherited-members:

.. automodapi:: gwcs.coordinate_frames
:inherited-members:

.. automodapi:: gwcs.wcstools

.. automodapi:: gwcs.selector
:inherited-members:

.. automodapi:: gwcs.spectroscopy
:inherited-members:

.. automodapi:: gwcs.geometry
:inherited-members:
43 changes: 22 additions & 21 deletions gwcs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def pixel_to_world_values(self, *pixel_arrays):
def array_index_to_world_values(self, *index_arrays):
"""
Convert array indices to world coordinates.
This is the same as `~BaseLowLevelWCS.pixel_to_world_values` except that
the indices should be given in ``(i, j)`` order, where for an image
This is the same as `~astropy.wcs.wcsapi.BaseLowLevelWCS.pixel_to_world_values`
except that the indices should be given in ``(i, j)`` order, where for an image
``i`` is the row and ``j`` is the column (i.e. the opposite order to
`~BaseLowLevelWCS.pixel_to_world_values`).
`~astropy.wcs.wcsapi.BaseLowLevelWCS.pixel_to_world_values`).
"""
pixel_arrays = index_arrays[::-1]
return self.pixel_to_world_values(*pixel_arrays)
Expand All @@ -127,11 +127,11 @@ def world_to_pixel_values(self, *world_arrays):
def world_to_array_index_values(self, *world_arrays):
"""
Convert world coordinates to array indices.
This is the same as `~BaseLowLevelWCS.world_to_pixel_values` except that
the indices should be returned in ``(i, j)`` order, where for an image
``i`` is the row and ``j`` is the column (i.e. the opposite order to
`~BaseLowLevelWCS.pixel_to_world_values`). The indices should be
returned as rounded integers.
This is the same as `~astropy.wcs.wcsapi.BaseLowLevelWCS.world_to_pixel_values`
except that the indices should be returned in ``(i, j)`` order, where for an
image ``i`` is the row and ``j`` is the column (i.e. the opposite order to
`~astropy.wcs.wcsapi.BaseLowLevelWCS.pixel_to_world_values`). The indices should
be returned as rounded integers.
"""
results = self.world_to_pixel_values(*world_arrays)
results = (results,) if self.pixel_n_dim == 1 else results[::-1]
Expand All @@ -143,7 +143,7 @@ def world_to_array_index_values(self, *world_arrays):
def array_shape(self):
"""
The shape of the data that the WCS applies to as a tuple of
length `~BaseLowLevelWCS.pixel_n_dim`.
length `~astropy.wcs.wcsapi.BaseLowLevelWCS.pixel_n_dim`.
If the WCS is valid in the context of a dataset with a particular
shape, then this property can be used to store the shape of the
data. This can be used for example if implementing slicing of WCS
Expand All @@ -167,12 +167,13 @@ def array_shape(self, value):
def pixel_bounds(self):
"""
The bounds (in pixel coordinates) inside which the WCS is defined,
as a list with `~BaseLowLevelWCS.pixel_n_dim` ``(min, max)`` tuples.
The bounds should be given in ``[(xmin, xmax), (ymin, ymax)]``
order. WCS solutions are sometimes only guaranteed to be accurate
within a certain range of pixel values, for example when defining a
WCS that includes fitted distortions. This is an optional property,
and it should return `None` if a shape is not known or relevant.
as a list with `~astropy.wcs.wcsapi.BaseLowLevelWCS.pixel_n_dim`
``(min, max)`` tuples. The bounds should be given in
``[(xmin, xmax), (ymin, ymax)]`` order. WCS solutions are sometimes
only guaranteed to be accurate within a certain range of pixel values,
for example when defining a WCS that includes fitted distortions. This
is an optional property, and it should return `None` if a shape is not
known or relevant.
"""
bounding_box = self.bounding_box
if bounding_box is None:
Expand Down Expand Up @@ -225,12 +226,12 @@ def pixel_shape(self, value):
@property
def axis_correlation_matrix(self):
"""
Returns an (`~BaseLowLevelWCS.world_n_dim`,
`~BaseLowLevelWCS.pixel_n_dim`) matrix that indicates using booleans
whether a given world coordinate depends on a given pixel coordinate.
This defaults to a matrix where all elements are `True` in the absence of
any further information. For completely independent axes, the diagonal
would be `True` and all other entries `False`.
Returns an (`~astropy.wcs.wcsapi.BaseLowLevelWCS.world_n_dim`,
`~astropy.wcs.wcsapi.BaseLowLevelWCS.pixel_n_dim`) matrix that indicates
using booleans whether a given world coordinate depends on a given pixel
coordinate. This defaults to a matrix where all elements are `True` in
the absence of any further information. For completely independent axes,
the diagonal would be `True` and all other entries `False`.
"""
return separable.separability_matrix(self.forward_transform)

Expand Down
14 changes: 8 additions & 6 deletions gwcs/converters/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ def _assert_frame_equal(a, b):
return a == b

assert a.name == b.name # nosec
assert a.axes_order == b.axes_order # nosec
assert a.axes_names == b.axes_names # nosec
assert a.unit == b.unit # nosec
assert a.reference_frame == b.reference_frame # nosec
if not isinstance(a, cf.EmptyFrame):
assert a.axes_order == b.axes_order # nosec
assert a.axes_names == b.axes_names # nosec
assert a.unit == b.unit # nosec
assert a.reference_frame == b.reference_frame # nosec
return None


Expand Down Expand Up @@ -155,12 +156,13 @@ def test_references(tmp_path):
m1 = models.Shift(12.4) & models.Shift(-2)
icrs = cf.CelestialFrame(name="icrs", reference_frame=coord.ICRS())
det = cf.Frame2D(name="detector", axes_order=(0, 1))
det2 = cf.Frame2D(name="detector2", axes_order=(0, 1))
focal = cf.Frame2D(name="focal", axes_order=(0, 1))

pipe1 = [(det, m1), (focal, m1), (icrs, None)]
gw1 = wcs.WCS(pipe1)

pipe2 = [(det, m1), (det, m1), (icrs, None)]
pipe2 = [(det, m1), (det2, m1), (icrs, None)]
gw2 = wcs.WCS(pipe2)

tree = {"wcs1": gw1, "wcs2": gw2}
Expand All @@ -173,4 +175,4 @@ def test_references(tmp_path):
gw2 = af.tree["wcs2"]
assert gw1.pipeline[0].transform is gw1.pipeline[1].transform
assert gw2.pipeline[0].transform is gw2.pipeline[1].transform
assert gw2.pipeline[0].frame is gw2.pipeline[1].frame
assert gw2.pipeline[0].frame is gw1.pipeline[0].frame
9 changes: 8 additions & 1 deletion gwcs/converters/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,14 @@ def from_yaml_tree(self, node, tag, ctx):
return Step(frame=node["frame"], transform=node.get("transform", None))

def to_yaml_tree(self, step, tag, ctx):
return {"frame": step.frame, "transform": step.transform}
from gwcs.coordinate_frames import EmptyFrame

return {
"frame": step.frame.name
if isinstance(step.frame, EmptyFrame)
else step.frame,
"transform": step.transform,
}


class FrameConverter(Converter):
Expand Down
75 changes: 75 additions & 0 deletions gwcs/coordinate_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@
"CelestialFrame",
"CompositeFrame",
"CoordinateFrame",
"EmptyFrame",
"Frame2D",
"SpectralFrame",
"StokesFrame",
Expand Down Expand Up @@ -608,6 +609,80 @@ def from_high_level_coordinates(self, *high_level_coords):
return values


class EmptyFrame(CoordinateFrame):
"""
Represents a "default" detector frame. This is for use as the default value
for input frame by the WCS object.
"""

def __init__(self, name=None):
self._name = "detector" if name is None else name

def __repr__(self):
return f'<{type(self).__name__}(name="{self.name}")>'

def __str__(self):
if self._name is not None:
return self._name
return type(self).__name__

@property
def name(self):
"""A custom name of this frame."""
return self._name

@name.setter
def name(self, val):
"""A custom name of this frame."""
self._name = val

def _raise_error(self) -> None:
msg = "EmptyFrame does not have any information"
raise NotImplementedError(msg)

@property
def naxes(self):
self._raise_error()

@property
def unit(self):
self._raise_error()

@property
def axes_names(self):
self._raise_error()

@property
def axes_order(self):
self._raise_error()

@property
def reference_frame(self):
self._raise_error()

@property
def axes_type(self):
self._raise_error()

@property
def axis_physical_types(self):
self._raise_error()

@property
def world_axis_object_classes(self):
self._raise_error()

@property
def _native_world_axis_object_components(self):
self._raise_error()

def to_high_level_coordinates(self, *values):
self._raise_error()

def from_high_level_coordinates(self, *high_level_coords):
self._raise_error()


class CelestialFrame(CoordinateFrame):
"""
Representation of a Celesital coordinate system.
Expand Down
46 changes: 35 additions & 11 deletions gwcs/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
stokes = cf.StokesFrame(axes_order=(2,))

pipe = [wcs.Step(detector, m1), wcs.Step(focal, m2), wcs.Step(icrs, None)]
pipe_copy = pipe.copy()

# Create some data.
nx, ny = (5, 2)
Expand Down Expand Up @@ -104,28 +105,28 @@ def test_init_no_transform():
"""
gw = wcs.WCS(output_frame="icrs")
assert len(gw._pipeline) == 2
assert gw.pipeline[0].frame == "detector"
assert gw.pipeline[0].frame.name == "detector"
with pytest.warns(
DeprecationWarning, match="Indexing a WCS.pipeline step is deprecated."
):
assert gw.pipeline[0][0] == "detector"
assert gw.pipeline[1].frame == "icrs"
assert gw.pipeline[0][0].name == "detector"
assert gw.pipeline[1].frame.name == "icrs"
with pytest.warns(
DeprecationWarning, match="Indexing a WCS.pipeline step is deprecated."
):
assert gw.pipeline[1][0] == "icrs"
assert gw.pipeline[1][0].name == "icrs"
assert np.isin(gw.available_frames, ["detector", "icrs"]).all()
gw = wcs.WCS(output_frame=icrs, input_frame=detector)
assert gw._pipeline[0].frame == "detector"
assert gw._pipeline[0].frame.name == "detector"
with pytest.warns(
DeprecationWarning, match="Indexing a WCS.pipeline step is deprecated."
):
assert gw._pipeline[0][0] == "detector"
assert gw._pipeline[1].frame == "icrs"
assert gw._pipeline[0][0].name == "detector"
assert gw._pipeline[1].frame.name == "icrs"
with pytest.warns(
DeprecationWarning, match="Indexing a WCS.pipeline step is deprecated."
):
assert gw._pipeline[1][0] == "icrs"
assert gw._pipeline[1][0].name == "icrs"
assert np.isin(gw.available_frames, ["detector", "icrs"]).all()
with pytest.raises(NotImplementedError):
gw(1, 2)
Expand Down Expand Up @@ -732,7 +733,7 @@ def test_units(self):
assert self.wcs.unit == (u.degree, u.degree)

def test_get_transform(self):
with pytest.raises(wcs.CoordinateFrameError):
with pytest.raises(CoordinateFrameError):
assert (
self.wcs.get_transform("x_translation", "sky_rotation").submodel_names
== self.wcs.forward_transform[1:].submodel_names
Expand Down Expand Up @@ -1385,8 +1386,8 @@ def test_initialize_wcs_with_list():
shift2 = models.Shift(3 * u.pix)
pipeline = [("detector", shift1), wcs.Step("extra_step", shift2)]

extra_step = ("extra_step", None)
pipeline.append(extra_step)
end_step = ("end_step", None)
pipeline.append(end_step)

# make sure no warnings occur when creating wcs with this pipeline
with warnings.catch_warnings():
Expand Down Expand Up @@ -1735,3 +1736,26 @@ def test_high_level_objects_in_pipeline_backward(gwcs_with_pipeline_celestial):
with_units=True,
)
assert isinstance(intermediate_world, coord.SkyCoord)


def test_error_with_duplicate_frames():
"""
Test that an error is raised if a frame is used more than once in the pipeline.
"""
pipeline = [(detector, m1), (detector, m2), (focal, None)]

with pytest.raises(ValueError, match="Frame detector is already in the pipeline."):
wcs.WCS(pipeline)


def test_error_with_not_none_last():
"""
Test that an error is raised if the last transform is not None
"""

pipeline = [(detector, m1), (focal, m2)]

with pytest.raises(
ValueError, match="The last step in the pipeline must have a None transform."
):
wcs.WCS(pipeline)
10 changes: 10 additions & 0 deletions gwcs/wcs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from ._exception import GwcsBoundingBoxWarning, NoConvergence
from ._step import Step
from ._wcs import WCS

__all__ = [
"WCS",
"GwcsBoundingBoxWarning",
"NoConvergence",
"Step",
]
Loading

0 comments on commit a3d1cc0

Please sign in to comment.