Skip to content

Commit

Permalink
Deprecate in-place operations (resolves #43)
Browse files Browse the repository at this point in the history
  • Loading branch information
andykee committed Nov 15, 2023
1 parent d52286f commit 59bb57b
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 157 deletions.
20 changes: 0 additions & 20 deletions docs/_img/python/propagate_copy.py

This file was deleted.

4 changes: 2 additions & 2 deletions docs/patterns/radiometry/propagation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ and multiply the source irradiance by the collecting area to get photons/second.
>>> psf = np.zeros((64, 64))
>>> for wl, wt in zip(source.wave, source.value):
... w = lentil.Wavefront(wl*1e-9)
... w *= pupil
... w = w * pupil
... w = w.propagate_image(pixelscale=5e-6, npix=32, oversample=2)
... psf += (w.intensity * (np.pi*(pupil_diameter/2)**2))
>>> plt.imshow(psf, origin='lower')
Expand All @@ -56,7 +56,7 @@ fine features present in the source's spectral response.
>>> psf = np.zeros((64, 64))
>>> for wl, wt in zip(binned_wave, binned_flux):
... w = lentil.Wavefront(wl*1e-9)
... w *= pupil
... w = w * pupil
... w = w.propagate_image(pixelscale=5e-6, npix=32, oversample=2)
... psf += (w.intensity * (np.pi*(pupil_diameter/2)**2))
>>> plt.imshow(psf, origin='lower')
Expand Down
40 changes: 2 additions & 38 deletions docs/user_guide/diffraction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,6 @@ follows the same basic flow:
>>> w2.focal_length
10
It is also possible to perform the multiplication in-place, reducing the memory footprint
of the propagation:

.. code-block:: pycon
>>> w1 *= pupil
.. note::

Additional details on the plane-wavefront interaction can be found in
Expand Down Expand Up @@ -138,35 +131,6 @@ follows the same basic flow:
are required to model the desired optical system, steps 2 and 3 should be
repeated until the |Wavefront| has been propagated through all of the planes.

Performing propagations in-place vs. on copies
----------------------------------------------
By default, all propagation operations operate on a |Wavefront| in-place. If desired,
a copy can be returned instead by providing the argument ``inplace=False``:

.. code-block:: python
:emphasize-lines: 9
import matplotlib.pyplot as plt
import lentil
pupil = lentil.Pupil(amplitude=lentil.circle((256, 256), 120),
pixelscale=1/240, focal_length=10)
w1 = lentil.Wavefront(650e-9)
w2 = w1 * pupil
w3 = w2.propagate_image(pixelscale=5e-6, npix=64, oversample=5, inplace=False)
plt.subplot(121)
plt.imshow(w2.intensity, origin='lower')
plt.title('w2 intensity')
plt.subplot(122)
plt.imshow(w3.intensity**0.1, origin='lower')
plt.title('w3 intensity')
.. plot:: _img/python/propagate_copy.py
:scale: 50

Broadband (multi-wavelength) propagations
-----------------------------------------
The steps outlined above propagate a single monochromatic |Wavefront| through an
Expand All @@ -190,7 +154,7 @@ different wavelengths and accumulates the resulting image plane intensity:

for wl in wavelengths:
w = lentil.Wavefront(wl)
w *= pupil
w = w * pupil
w.propagate_image(pixelscale=5e-6, npix=64, oversample=5)
img += w.intensity

Expand All @@ -213,7 +177,7 @@ wavefront intensity given by ``npix`` * ``oversample``.
for wl in wavelengths:
w = lentil.Wavefront(wl)
w *= pupil
w = w * pupil
w.propagate_image(pixelscale=5e-6, npix=64, oversample=5)
img = w.insert(img)
Expand Down
15 changes: 0 additions & 15 deletions docs/user_guide/optical_systems.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,6 @@ the wavefront's complex data array:
\mathbf{W_1} = \mathbf{A} \exp\left(\frac{2\pi j}{\lambda} \mathbf{\theta}\right) \circ \mathbf{W_0}
The plane's :func:`~lentil.Plane.multiply` method also accepts an ``inplace`` argument
that governs whether the multiplication operation is performed on the wavefront in-place
or using a copy:

.. code:: pycon
>>> w1 = plane.multiply(w0, inplace=True)
>>> w1 is w0
True
The in-place multiplication operator can also be used:

.. code:: pycon
>>> w *= plane
.. If the |Plane| :attr:`~lentil.Plane.tilt` attribute is not empty, its contents are appended
.. to the |Wavefront|. See :ref:`user_guide.planes.fit_tilt` and :ref:`user_guide.diffraction.tilt`
Expand Down
59 changes: 20 additions & 39 deletions lentil/plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def fit_tilt(self, inplace=False):

return plane

def rescale(self, scale, inplace=False):
def rescale(self, scale):
"""
Rescale a plane via interpolation.
Expand All @@ -251,9 +251,6 @@ def rescale(self, scale, inplace=False):
scale : float
Scale factor for interpolation. Scale factors less than 1 shrink the
Plane while scale factors greater than 1 grow it.
inplace : bool, optional
Modify the original object in place (True) or create a copy (False,
default)
Returns
-------
Expand All @@ -268,10 +265,7 @@ def rescale(self, scale, inplace=False):
Plane.resample
"""
if inplace:
plane = self
else:
plane = self.copy()
plane = self.copy()

if plane.amplitude.ndim > 1:
plane.amplitude = lentil.rescale(plane.amplitude, scale=scale, shape=None,
Expand Down Expand Up @@ -302,7 +296,7 @@ def rescale(self, scale, inplace=False):

return plane

def resample(self, pixelscale, inplace=False):
def resample(self, pixelscale):
"""Resample a plane via interpolation.
The following Plane attributes are resampled:
Expand All @@ -317,9 +311,6 @@ def resample(self, pixelscale, inplace=False):
----------
pixelscale : float
Desired Plane pixelscale.
inplace : bool, optional
Modify the original object in place (True) or create a copy (False,
default)
Returns
-------
Expand All @@ -340,19 +331,15 @@ def resample(self, pixelscale, inplace=False):
elif self.pixelscale[0] != self.pixelscale[1]:
raise NotImplementedError("Can't resample non-uniformly sampled Plane")

return self.rescale(scale=self.pixelscale[0]/pixelscale, inplace=inplace)
return self.rescale(scale=self.pixelscale[0]/pixelscale)

def multiply(self, wavefront, inplace=False):
def multiply(self, wavefront):
"""Multiply with a wavefront
Parameters
----------
wavefront : :class:`~lentil.wavefront.Wavefront` object
Wavefront to be multiplied
inplace : bool, optional
If True, the wavefront object is multiplied in-place, otherwise a
copy is created before performing the multiplication. Default is
False.
Note
----
Expand All @@ -374,16 +361,10 @@ def multiply(self, wavefront, inplace=False):
shape = wavefront.shape if self.shape == () else self.shape
data = wavefront.data

if inplace:
out = wavefront
out.data = []
out.shape = shape
out.pixelscale = pixelscale
else:
out = lentil.Wavefront.empty(wavelength=wavefront.wavelength,
pixelscale=pixelscale,
focal_length=wavefront.focal_length,
shape=shape)
out = lentil.Wavefront.empty(wavelength=wavefront.wavelength,
pixelscale=pixelscale,
focal_length=wavefront.focal_length,
shape=shape)


for field in data:
Expand Down Expand Up @@ -526,9 +507,9 @@ def __init__(self, focal_length=None, pixelscale=None, amplitude=1,
def __init_subclass__(cls):
cls._focal_length = None

def multiply(self, wavefront, inplace=False):
def multiply(self, wavefront):

wavefront = super().multiply(wavefront, inplace)
wavefront = super().multiply(wavefront)

# we inherit the plane's focal length as the wavefront's focal length
wavefront.focal_length = self.focal_length
Expand Down Expand Up @@ -584,8 +565,8 @@ def fit_tilt(self, *args, **kwargs):
# # np.s_[...] = Ellipsis -> returns the whole array
# return [np.s_[...]]

def multiply(self, wavefront, inplace=False):
wavefront = super().multiply(wavefront, inplace)
def multiply(self, wavefront):
wavefront = super().multiply(wavefront)
wavefront.ptype = lentil.image
return wavefront

Expand Down Expand Up @@ -616,7 +597,7 @@ class Detector(Image):

class DispersivePhase(Plane):

def multiply(self, wavefront, inplace=False):
def multiply(self, wavefront):
# NOTE: we can handle wavelength-dependent phase terms here (e.g. chromatic
# aberrations). Since the phase will vary by wavelength, we can't fit out the
# tilt pre-propagation and apply the same tilt for each wavelength like we can
Expand All @@ -629,8 +610,8 @@ class DispersiveShift(Plane):
def shift(self, wavelength, x0, y0, **kwargs):
raise NotImplementedError

def multiply(self, wavefront, inplace=False):
wavefront = super().multiply(wavefront, inplace=False)
def multiply(self, wavefront):
wavefront = super().multiply(wavefront)
for field in wavefront.data:
field.tilt.append(self)
return wavefront
Expand Down Expand Up @@ -824,8 +805,8 @@ def __init__(self, x, y):
self.x = y # y tilt is about the x-axis.
self.y = x # x tilt is about the y-axis.

def multiply(self, wavefront, inplace=False):
wavefront = super().multiply(wavefront, inplace)
def multiply(self, wavefront):
wavefront = super().multiply(wavefront)
for field in wavefront.data:
field.tilt.append(self)
return wavefront
Expand Down Expand Up @@ -881,7 +862,7 @@ def __init__(self, angle=0, unit='degrees', order=3):
self.angle = -angle
self.order = order

def multiply(self, wavefront, inplace=False):
def multiply(self, wavefront):
"""Multiply with a wavefront
Parameters
Expand Down Expand Up @@ -932,7 +913,7 @@ def __init__(self, axis=None):
super().__init__()
self.axis = axis

def multiply(self, wavefront, inplace=False):
def multiply(self, wavefront):
"""Multiply with a wavefront
Parameters
Expand Down
20 changes: 5 additions & 15 deletions lentil/propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from lentil.wavefront import Wavefront

def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,
oversample=2, inplace=True):
oversample=2):
"""Propagate a Wavefront using Fraunhofer diffraction.
Parameters
Expand All @@ -25,9 +25,6 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,
``prop_shape`` should not be larger than ``prop``.
oversample : int, optional
Number of times to oversample the output plane. Default is 2.
inplace : bool, optional
If True (default) the Wavefront is propagated in-place, otherwise
a copy is created and propagated.
Returns
-------
Expand All @@ -48,17 +45,10 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,

data = wavefront.data

if inplace:
out = wavefront
out.data = []
out.pixelscale = du/oversample
out.shape = shape_out
out.ptype = ptype_out
else:
out = Wavefront.empty(wavelength=wavefront.wavelength,
pixelscale = du/oversample,
shape = shape_out,
ptype = ptype_out)
out = Wavefront.empty(wavelength=wavefront.wavelength,
pixelscale = du/oversample,
shape = shape_out,
ptype = ptype_out)

for field in data:
# compute the field shift from any embedded tilts. note the return value
Expand Down
5 changes: 1 addition & 4 deletions lentil/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,7 @@ def __init__(self, wavelength, pixelscale=None, diameter=None, focal_length=None
tilt=tilt)]

def __mul__(self, plane):
return plane.multiply(self, inplace=False)

def __imul__(self, plane):
return plane.multiply(self, inplace=True)
return plane.multiply(self)

def __rmul__(self, other):
return self.__mul__(other)
Expand Down
12 changes: 1 addition & 11 deletions tests/test_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,6 @@ def test_wavefront_plane_multiply():
assert np.array_equal(w1.data[0].data, phasor)


def test_wavefront_plane_multiply_inplace():
p = RandomPlane()
w = lentil.Wavefront(650e-9)

w_copy = p.multiply(w, inplace=False)
w_inplace = p.multiply(w, inplace=True)

assert w_copy is not w
assert w_inplace is w

def test_wavefront_plane_multiply_overlapping_segment_slices():
seg = lentil.hexagon((64, 64), 32)
seg = seg[5:60, :]
Expand Down Expand Up @@ -94,7 +84,7 @@ def test_wavefront_pupil_multiply():

def test_pupil_rescale_power():
p = CircularPupil()
pr = p.rescale(3, inplace=False)
pr = p.rescale(3)

amp_power = np.sum(np.abs(p.amplitude)**2)
ampr_power = np.sum(np.abs(pr.amplitude)**2)
Expand Down
14 changes: 1 addition & 13 deletions tests/test_propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def test_propagate_resample():
w *= p
wi = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=10)

p2 = p.rescale(scale=3, inplace=False)
p2 = p.rescale(scale=3)
w2 = lentil.Wavefront(650e-9)
w2 *= p2
w2i = lentil.propagate_dft(w2, shape=(64,64), pixelscale=5e-6, oversample=10)
Expand All @@ -283,15 +283,3 @@ def test_propagate_resample():

assert np.allclose(cent, [320, 320])
assert math.isclose(np.sum(wi.intensity), np.sum(w2i.intensity), rel_tol=1e-2)


def test_propagate_image_inplace():
p = lentil.Pupil(focal_length=10, pixelscale=1 / 240,
amplitude=lentil.circle((256, 256), 120))
w = lentil.Wavefront(650e-9)
w *= p
w_copy = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=2, inplace=False)
w_inplace = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=2, inplace=True)

assert w_copy is not w
assert w_inplace is w

0 comments on commit 59bb57b

Please sign in to comment.