diff --git a/lentil/__init__.py b/lentil/__init__.py index 4fa8302..c7a0852 100644 --- a/lentil/__init__.py +++ b/lentil/__init__.py @@ -29,7 +29,8 @@ Grism, LensletArray, Rotate, - Flip + Flip, + dependent ) from lentil.propagate import ( diff --git a/lentil/plane.py b/lentil/plane.py index 0dda3c8..86730ef 100644 --- a/lentil/plane.py +++ b/lentil/plane.py @@ -12,8 +12,7 @@ class Plane: - """ - Base class for representing a finite geometric plane. + """Base class for representing a finite geometric plane. Parameters ---------- @@ -72,6 +71,7 @@ def __init__(self, amplitude=1, opd=0, mask=None, pixelscale=None, diameter=None self._ptype = lentil.ptype(ptype) self.tilt = [] + self._tilt_opd = 0 def __repr__(self): return f'{self.__class__.__name__}()' @@ -130,8 +130,7 @@ def mask(self): @property def global_mask(self): - """ - Flattened view of :attr:`mask` + """Flattened view of :attr:`mask` Returns ------- @@ -182,8 +181,7 @@ def diameter(self): @property def shape(self): - """ - Plane dimensions computed from :attr:`mask`. + """Plane dimensions computed from :attr:`mask`. Returns (mask.shape[1], mask.shape[2]) if :attr:`size: > 1. Returns None if :attr:`mask` is None. @@ -199,8 +197,7 @@ def shape(self): @property def size(self): - """ - Number of independent masks (segments) in :attr:`mask` + """Number of independent masks (segments) in :attr:`mask` Returns ------- @@ -214,8 +211,7 @@ def size(self): @property def ptt_vector(self): - """ - 2D vector representing piston and tilt in x and y. + """2D vector representing piston and tilt in x and y. Planes with no mask have :attr:`ptt_vector` = None. @@ -253,8 +249,7 @@ def ptt_vector(self): return ptt_vector def copy(self): - """ - Make a copy of this object. + """Make a copy of this object. Returns ------- @@ -263,8 +258,7 @@ def copy(self): return copy.deepcopy(self) def fit_tilt(self, inplace=False): - """ - Fit and remove tilt from Plane :attr:`opd` via least squares. The + """Fit and remove tilt from Plane :attr:`opd` via least squares. The equivalent angular tilt is bookkept in Plane :attr:`tilt`. Parameters @@ -291,23 +285,29 @@ def fit_tilt(self, inplace=False): if self.size == 1: t = np.linalg.lstsq(ptt_vector.T, plane.opd.ravel(), rcond=None)[0] - opd_tilt = np.einsum('ij,i->j', ptt_vector[1:3], t[1:3]) - plane.opd -= opd_tilt.reshape(plane.opd.shape) + tilt_opd = np.einsum('ij,i->j', ptt_vector[1:3], t[1:3]) + tilt_opd = tilt_opd.reshape(plane.opd.shape) plane.tilt.append(Tilt(x=t[1], y=t[2])) else: t = np.empty((self.size, 3)) - opd_no_tilt = np.empty((self.size, plane.opd.shape[0], plane.opd.shape[1])) + #opd_no_tilt = np.empty((self.size, plane.opd.shape[0], plane.opd.shape[1])) # iterate over the segments and compute the tilt term for seg in np.arange(self.size): t[seg] = np.linalg.lstsq(ptt_vector[3 * seg:3 * seg + 3].T, plane.opd.ravel(), rcond=None)[0] seg_tilt = np.einsum('ij,i->j', ptt_vector[3 * seg + 1:3 * seg + 3], t[seg, 1:3]) - opd_no_tilt[seg] = (plane.opd - seg_tilt.reshape(plane.opd.shape)) * self.mask[seg] + tilt_opd += seg_tilt.reshape(plane.opd.shape) * plane.mask[seg] + #opd_no_tilt[seg] = (plane.opd - seg_tilt.reshape(plane.opd.shape)) * self.mask[seg] - plane.opd = np.sum(opd_no_tilt, axis=0) + #plane.opd = np.sum(opd_no_tilt, axis=0) plane.tilt.extend([Tilt(x=t[seg, 1], y=t[seg, 2]) for seg in range(self.size)]) + + if isinstance(getattr(type(self), 'opd'), lentil.dependent): + plane._tilt_opd = tilt_opd + else: + plane.opd -= tilt_opd return plane @@ -540,7 +540,6 @@ def _plane_slice(mask): -------- helper.boundary_slice Plane.slice_offset - """ # self.mask may still return None so we catch that here @@ -594,7 +593,6 @@ class Pupil(Plane): aberrations in the optical system appear as deviations from this perfect sphere. The primary use of :class:`Pupil` is to represent this spherical wavefront. - """ def __init__(self, amplitude=1, opd=0, mask=None, pixelscale=None, @@ -683,8 +681,8 @@ class TiltInterface(Plane): -------- Tilt DispersiveTilt - """ + def __init__(self, **kwargs): # if ptype is provided as a kwarg use that, otherwise default # to lentil.tilt @@ -733,7 +731,6 @@ class Tilt(TiltInterface): Radians of tilt about the x-axis y : float Radians of tilt about the y-axis - """ def __init__(self, x, y, **kwargs): super().__init__(**kwargs) @@ -819,8 +816,8 @@ class DispersiveTilt(TiltInterface): speed and precision of modeling grisms with higher order trace and/or dispersion functions. In cases where speed or accuracy are extremely important, a custom solution may be required. - """ + def __init__(self, trace, dispersion, **kwargs): super().__init__(**kwargs) @@ -858,8 +855,8 @@ def _dispersion(self, wavelength): ------- distance: float Distance along spectral trace relative to self.dispersion[-1] - """ + if self._dispersion_order == 1: # For a first order polynomial we have lambda = dispersion[0] * dist + dispersion[1] # Solving for distance gives dist = (lambda - dispersion[1])/dispersion[0] @@ -922,8 +919,8 @@ def _arc_len(dist_func, a, b): References ---------- https://en.wikipedia.org/wiki/Arc_length#Finding_arc_lengths_by_integrating - """ + return scipy.integrate.quad(dist_func, a, b)[0] @@ -998,8 +995,8 @@ class Grism(DispersiveTilt): .. deprecated:: 1.0.0 `Grism` will be removed in Lentil v1.0.0, it is replaced by `DispersiveTilt`. - """ + def __init__(self, trace, dispersion, **kwargs): warn('lentil.Grism will be deprecated in v1.0.0, it is ' 'replaced by lentil.DispersiveTilt.', @@ -1030,8 +1027,8 @@ class Rotate(Plane): If the angle is an even multiple of 90 degrees, ``numpy.rot90`` is used to perform the rotation rather than ``scipy.ndimage.rotate``. In this case, the order parameter is irrelevant because no interpolation occurs. - """ + def __init__(self, angle=0, unit='degrees', order=3): super().__init__() @@ -1052,7 +1049,6 @@ def multiply(self, wavefront): ------- wavefront : :class:`~lentil.wavefront.Wavefront` object Updated wavefront - """ pixelscale = lentil.field.multiply_pixelscale(self.pixelscale, wavefront.pixelscale) @@ -1085,8 +1081,8 @@ class Flip(Plane): Axis or axes along which to flip over. The default, axis=None, will flip over all of the axes of the input array. If axis is negative it counts from the last to the first axis. - """ + def __init__(self, axis=None): super().__init__() self.axis = axis @@ -1103,8 +1099,8 @@ def multiply(self, wavefront): ------- wavefront : :class:`~lentil.wavefront.Wavefront` object Updated wavefront - """ + out = wavefront.copy() for field in out.data: field.data = np.flip(field.data, axis=self.axis) @@ -1114,18 +1110,46 @@ def multiply(self, wavefront): class Quadratic(Plane): """Base class for representing an optical plane with a quadratic phase term. - """ pass class Conic(Plane): """Base class for representing an optical plane with a conic phase term. - """ pass +class dependent: + """Decorate a :class:`~lentil.Plane` ``opd`` method to mark its return + value as being calculated rather than static. + """ + + def __init__(self, func): + if func.__name__ != 'opd': + raise AttributeError('dependent can only decorate opd method') + self.func = func + + def __get__(self, plane, type=None): + + # This if statement allows us to access the descriptor + # from within Plane via getattr(type(self), 'opd'). By checking + # the type of the return, we know if we have a regular attribute + # (primitive type), a property, or a dependent instance. + # https://stackoverflow.com/a/21629855 + if plane is None: + return self + + opd = self.func(plane) + return opd - plane._tilt + + def __set__(self, *args, **kwargs): + raise AttributeError('can\'t set dependent opd') + + def setter(self, *args, **kwargs): + raise AttributeError('can\'t set dependent opd') + + # def Q(self, wave, pixelscale, oversample=1): # return (self.f_number*wave*oversample)/pixelscale #