Skip to content

Commit

Permalink
First prototype of stateful dectoraor
Browse files Browse the repository at this point in the history
  • Loading branch information
andykee committed Apr 1, 2024
1 parent e34e368 commit eacfc82
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 34 deletions.
3 changes: 2 additions & 1 deletion lentil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
Grism,
LensletArray,
Rotate,
Flip
Flip,
dependent
)

from lentil.propagate import (
Expand Down
90 changes: 57 additions & 33 deletions lentil/plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@


class Plane:
"""
Base class for representing a finite geometric plane.
"""Base class for representing a finite geometric plane.
Parameters
----------
Expand Down Expand Up @@ -72,6 +71,7 @@ def __init__(self, amplitude=1, opd=0, mask=None, pixelscale=None, diameter=None
self._ptype = lentil.ptype(ptype)

self.tilt = []
self._tilt_opd = 0

def __repr__(self):
return f'{self.__class__.__name__}()'
Expand Down Expand Up @@ -130,8 +130,7 @@ def mask(self):

@property
def global_mask(self):
"""
Flattened view of :attr:`mask`
"""Flattened view of :attr:`mask`
Returns
-------
Expand Down Expand Up @@ -182,8 +181,7 @@ def diameter(self):

@property
def shape(self):
"""
Plane dimensions computed from :attr:`mask`.
"""Plane dimensions computed from :attr:`mask`.
Returns (mask.shape[1], mask.shape[2]) if :attr:`size: > 1. Returns
None if :attr:`mask` is None.
Expand All @@ -199,8 +197,7 @@ def shape(self):

@property
def size(self):
"""
Number of independent masks (segments) in :attr:`mask`
"""Number of independent masks (segments) in :attr:`mask`
Returns
-------
Expand All @@ -214,8 +211,7 @@ def size(self):

@property
def ptt_vector(self):
"""
2D vector representing piston and tilt in x and y.
"""2D vector representing piston and tilt in x and y.
Planes with no mask have :attr:`ptt_vector` = None.
Expand Down Expand Up @@ -253,8 +249,7 @@ def ptt_vector(self):
return ptt_vector

def copy(self):
"""
Make a copy of this object.
"""Make a copy of this object.
Returns
-------
Expand All @@ -263,8 +258,7 @@ def copy(self):
return copy.deepcopy(self)

def fit_tilt(self, inplace=False):
"""
Fit and remove tilt from Plane :attr:`opd` via least squares. The
"""Fit and remove tilt from Plane :attr:`opd` via least squares. The
equivalent angular tilt is bookkept in Plane :attr:`tilt`.
Parameters
Expand All @@ -291,23 +285,29 @@ def fit_tilt(self, inplace=False):

if self.size == 1:
t = np.linalg.lstsq(ptt_vector.T, plane.opd.ravel(), rcond=None)[0]
opd_tilt = np.einsum('ij,i->j', ptt_vector[1:3], t[1:3])
plane.opd -= opd_tilt.reshape(plane.opd.shape)
tilt_opd = np.einsum('ij,i->j', ptt_vector[1:3], t[1:3])
tilt_opd = tilt_opd.reshape(plane.opd.shape)
plane.tilt.append(Tilt(x=t[1], y=t[2]))

else:
t = np.empty((self.size, 3))
opd_no_tilt = np.empty((self.size, plane.opd.shape[0], plane.opd.shape[1]))
#opd_no_tilt = np.empty((self.size, plane.opd.shape[0], plane.opd.shape[1]))

# iterate over the segments and compute the tilt term
for seg in np.arange(self.size):
t[seg] = np.linalg.lstsq(ptt_vector[3 * seg:3 * seg + 3].T, plane.opd.ravel(),
rcond=None)[0]
seg_tilt = np.einsum('ij,i->j', ptt_vector[3 * seg + 1:3 * seg + 3], t[seg, 1:3])
opd_no_tilt[seg] = (plane.opd - seg_tilt.reshape(plane.opd.shape)) * self.mask[seg]
tilt_opd += seg_tilt.reshape(plane.opd.shape) * plane.mask[seg]
#opd_no_tilt[seg] = (plane.opd - seg_tilt.reshape(plane.opd.shape)) * self.mask[seg]

plane.opd = np.sum(opd_no_tilt, axis=0)
#plane.opd = np.sum(opd_no_tilt, axis=0)
plane.tilt.extend([Tilt(x=t[seg, 1], y=t[seg, 2]) for seg in range(self.size)])

if isinstance(getattr(type(self), 'opd'), lentil.dependent):
plane._tilt_opd = tilt_opd
else:
plane.opd -= tilt_opd

return plane

Expand Down Expand Up @@ -540,7 +540,6 @@ def _plane_slice(mask):
--------
helper.boundary_slice
Plane.slice_offset
"""

# self.mask may still return None so we catch that here
Expand Down Expand Up @@ -594,7 +593,6 @@ class Pupil(Plane):
aberrations in the optical system appear as deviations from this perfect
sphere. The primary use of :class:`Pupil` is to represent this spherical
wavefront.
"""

def __init__(self, amplitude=1, opd=0, mask=None, pixelscale=None,
Expand Down Expand Up @@ -683,8 +681,8 @@ class TiltInterface(Plane):
--------
Tilt
DispersiveTilt
"""

def __init__(self, **kwargs):
# if ptype is provided as a kwarg use that, otherwise default
# to lentil.tilt
Expand Down Expand Up @@ -733,7 +731,6 @@ class Tilt(TiltInterface):
Radians of tilt about the x-axis
y : float
Radians of tilt about the y-axis
"""
def __init__(self, x, y, **kwargs):
super().__init__(**kwargs)
Expand Down Expand Up @@ -819,8 +816,8 @@ class DispersiveTilt(TiltInterface):
speed and precision of modeling grisms with higher order trace and/or
dispersion functions. In cases where speed or accuracy are extremely important,
a custom solution may be required.
"""

def __init__(self, trace, dispersion, **kwargs):
super().__init__(**kwargs)

Expand Down Expand Up @@ -858,8 +855,8 @@ def _dispersion(self, wavelength):
-------
distance: float
Distance along spectral trace relative to self.dispersion[-1]
"""

if self._dispersion_order == 1:
# For a first order polynomial we have lambda = dispersion[0] * dist + dispersion[1]
# Solving for distance gives dist = (lambda - dispersion[1])/dispersion[0]
Expand Down Expand Up @@ -922,8 +919,8 @@ def _arc_len(dist_func, a, b):
References
----------
https://en.wikipedia.org/wiki/Arc_length#Finding_arc_lengths_by_integrating
"""

return scipy.integrate.quad(dist_func, a, b)[0]


Expand Down Expand Up @@ -998,8 +995,8 @@ class Grism(DispersiveTilt):
.. deprecated:: 1.0.0
`Grism` will be removed in Lentil v1.0.0, it is replaced by
`DispersiveTilt`.
"""

def __init__(self, trace, dispersion, **kwargs):
warn('lentil.Grism will be deprecated in v1.0.0, it is '
'replaced by lentil.DispersiveTilt.',
Expand Down Expand Up @@ -1030,8 +1027,8 @@ class Rotate(Plane):
If the angle is an even multiple of 90 degrees, ``numpy.rot90`` is used to
perform the rotation rather than ``scipy.ndimage.rotate``. In this case,
the order parameter is irrelevant because no interpolation occurs.
"""

def __init__(self, angle=0, unit='degrees', order=3):
super().__init__()

Expand All @@ -1052,7 +1049,6 @@ def multiply(self, wavefront):
-------
wavefront : :class:`~lentil.wavefront.Wavefront` object
Updated wavefront
"""

pixelscale = lentil.field.multiply_pixelscale(self.pixelscale, wavefront.pixelscale)
Expand Down Expand Up @@ -1085,8 +1081,8 @@ class Flip(Plane):
Axis or axes along which to flip over. The default, axis=None, will
flip over all of the axes of the input array. If axis is negative it
counts from the last to the first axis.
"""

def __init__(self, axis=None):
super().__init__()
self.axis = axis
Expand All @@ -1103,8 +1099,8 @@ def multiply(self, wavefront):
-------
wavefront : :class:`~lentil.wavefront.Wavefront` object
Updated wavefront
"""

out = wavefront.copy()
for field in out.data:
field.data = np.flip(field.data, axis=self.axis)
Expand All @@ -1114,18 +1110,46 @@ def multiply(self, wavefront):
class Quadratic(Plane):
"""Base class for representing an optical plane with a quadratic phase
term.
"""
pass


class Conic(Plane):
"""Base class for representing an optical plane with a conic phase term.
"""
pass


class dependent:
"""Decorate a :class:`~lentil.Plane` ``opd`` method to mark its return
value as being calculated rather than static.
"""

def __init__(self, func):
if func.__name__ != 'opd':
raise AttributeError('dependent can only decorate opd method')
self.func = func

def __get__(self, plane, type=None):

# This if statement allows us to access the descriptor
# from within Plane via getattr(type(self), 'opd'). By checking
# the type of the return, we know if we have a regular attribute
# (primitive type), a property, or a dependent instance.
# https://stackoverflow.com/a/21629855
if plane is None:
return self

opd = self.func(plane)
return opd - plane._tilt

def __set__(self, *args, **kwargs):
raise AttributeError('can\'t set dependent opd')

def setter(self, *args, **kwargs):
raise AttributeError('can\'t set dependent opd')


# def Q(self, wave, pixelscale, oversample=1):
# return (self.f_number*wave*oversample)/pixelscale
#
Expand Down

0 comments on commit eacfc82

Please sign in to comment.