Skip to content

Commit

Permalink
Renamed optika.systems.AbstractSystem.__call__() to image() for c…
Browse files Browse the repository at this point in the history
…larity and added `axis_wavelength`, `axis_field`, and `axis_pupil` arguments. (#105)
  • Loading branch information
byrdie authored Nov 9, 2024
1 parent 62bf58c commit 5ad11f7
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 33 deletions.
4 changes: 2 additions & 2 deletions optika/_tests/test_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ class AbstractTestAbstractSystem(
)
],
)
def test__call__(
def test_image(
self,
a: optika.systems.AbstractSystem,
scene: na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar],
):
result = a(scene)
result = a.image(scene)
assert isinstance(result, na.FunctionArray)
assert isinstance(result.inputs, na.SpectralPositionalVectorArray)
assert isinstance(result.outputs, na.AbstractScalar)
Expand Down
106 changes: 76 additions & 30 deletions optika/systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class AbstractSystem(
"""

@abc.abstractmethod
def __call__(
def image(
self,
scene: na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar],
**kwargs: Any,
Expand All @@ -48,7 +48,7 @@ def __call__(
----------
scene
The spectral radiance of the scene as a function of wavelength
and field position
and field position.
kwargs
Additional keyword arguments used by subclass implementations
of this method.
Expand Down Expand Up @@ -900,6 +900,9 @@ def _rayfunction_from_vertices(
area_field = field.volume_cell(axis_field)
area_pupil = optika.direction(pupil).solid_angle_cell(axis_pupil)

area_field = np.abs(area_field)
area_pupil = np.abs(area_pupil)

area_field = area_field.cell_centers(
axis=axis_wavelength,
)
Expand Down Expand Up @@ -947,12 +950,49 @@ def _rayfunction_from_vertices(
normalized_pupil=False,
)

def __call__(
def image(
self,
scene: na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar],
grid_pupil: None | na.AbstractCartesian2dVectorArray = None,
pupil: None | na.AbstractCartesian2dVectorArray = None,
axis_wavelength: None | str = None,
axis_field: None | tuple[str, str] = None,
axis_pupil: None | tuple[str, str] = None,
**kwargs,
) -> na.FunctionArray[na.SpectralPositionalVectorArray, na.AbstractScalar]:
"""
Forward model of the optical system.
Maps the given spectral radiance of a scene to detector counts.
Parameters
----------
scene
The spectral radiance of the scene as a function of wavelength
and field position.
The inputs must be cell vertices.
pupil
An optional grid of pupil positions to use when simulating the
optical system.
Must be evaluated on cell vertices.
If :obj:`None`, ``self.grid_input.pupil`` is used.
axis_wavelength
The logical axis corresponding to changing wavelength coordinate.
If :obj:`None`,
``set(scene.inputs.wavelength.shape) - set(self.shape)``,
should have only one element.
axis_field
The two logical axes corresponding to changing field coordinate.
If :obj:`None`,
``set(scene.inputs.field.shape) - set(self.shape) - {axis_wavelength}``,
should have exactly two elements.
axis_pupil
The two logical axes corresponding to changing pupil coordinate.
If :obj:`None`,
``set(pupil.shape) - set(self.shape) - {axis_wavelength,} - set(axis_field)``,
should have exactly two elements.
kwargs
Additional keyword arguments used by subclass implementations
of this method.
"""

shape = self.shape

Expand All @@ -961,7 +1001,6 @@ def __call__(
wavelength = scene.inputs.wavelength
field = scene.inputs.position

pupil = grid_pupil
if pupil is None:
pupil = self.grid_input.pupil

Expand All @@ -971,29 +1010,36 @@ def __call__(
normalized_field = unit_field.is_equivalent(u.dimensionless_unscaled)
normalized_pupil = unit_pupil.is_equivalent(u.dimensionless_unscaled)

shape_wavelength = na.broadcast_shapes(shape, wavelength.shape)
shape_field = na.broadcast_shapes(shape, field.shape)
shape_pupil = na.broadcast_shapes(shape, pupil.shape)

shape_wavelength = {
axis: shape_wavelength[axis]
for axis in shape_wavelength
if axis not in shape
}
shape_field = {
axis: shape_field[axis]
for axis in shape_field
if axis not in shape | shape_wavelength
}
shape_pupil = {
axis: shape_pupil[axis]
for axis in shape_pupil
if axis not in shape | shape_wavelength | shape_field
}

(axis_wavelength,) = tuple(shape_wavelength)
axis_field = tuple(shape_field)
axis_pupil = tuple(shape_pupil)
if axis_wavelength is None:
axis_wavelength = set(wavelength.shape) - set(shape)
axis_wavelength = tuple(axis_wavelength)
if len(axis_wavelength) != 1: # pragma: nocover
raise ValueError(
"if `axis_wavelength` is `None`, "
f"the wavelength axis must be unambiguous, "
f"got {axis_wavelength} as possibilities."
)
(axis_wavelength,) = axis_wavelength

if axis_field is None:
axis_field = set(field.shape) - set(shape)
axis_field = tuple(axis_field - {axis_wavelength})
if len(axis_field) != 2: # pragma: nocover
raise ValueError(
"if `axis_field` is `None`, "
"the two field axes must be unambiguous, "
f"got {axis_field} as possibilities."
)

if axis_pupil is None:
axis_pupil = set(pupil.shape) - set(shape)
axis_pupil = tuple(axis_pupil - {axis_wavelength} - set(axis_field))
if len(axis_pupil) != 2: # pragma: nocover
raise ValueError(
"if `axis_pupil` is `None`, "
"the two pupil axes must be unambiguous, "
f"got {axis_pupil} as possibilities."
)

rayfunction = self._rayfunction_from_vertices(
radiance=scene.outputs,
Expand All @@ -1009,7 +1055,7 @@ def __call__(

return self.sensor.readout(
rays=rayfunction.outputs,
axis=tuple(shape_field | shape_pupil),
axis=axis_field + axis_pupil,
)

def plot(
Expand Down Expand Up @@ -1265,7 +1311,7 @@ class SequentialSystem(
)
# Simulate an image of the scene using the optical system
image = system(scene)
image = system.image(scene)
# Plot the original scene and the simulated image
with astropy.visualization.quantity_support():
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ classifiers = [
"Programming Language :: Python :: 3",
]
dependencies = [
"astropy",
"astropy!=6.1.5",
"named-arrays==0.16.0",
"svglib",
"rlPyCairo",
Expand Down

0 comments on commit 5ad11f7

Please sign in to comment.