From 09f87d5e7a634a87f1a858d08b85c7c76887bdb4 Mon Sep 17 00:00:00 2001 From: Andy Kee Date: Thu, 30 Nov 2023 21:17:45 -0800 Subject: [PATCH] Allow 'amp' as an alias for 'amplitude' --- lentil/plane.py | 44 +++++++++++++++++++++++++++++++++----------- tests/test_plane.py | 10 ++++++++++ 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/lentil/plane.py b/lentil/plane.py index 094fe98..47c230e 100644 --- a/lentil/plane.py +++ b/lentil/plane.py @@ -21,7 +21,8 @@ class Plane: Electric field amplitude transmission. Amplitude should be normalized with :func:`~lentil.normalize_power` if conservation of power through diffraction propagation is required. If not specified (default), - amplitude is created which has no effect on wavefront propagation. + amplitude is created which has no effect on wavefront propagation. Can + also be specified using the ``amp`` keyword. phase : array_like, optional Phase change caused by plane. If not specified (default), phase is created which has no effect on wavefront propagation. @@ -33,7 +34,7 @@ class Plane: .. plot:: _img/python/segmask.py :scale: 50 - + pixelscale : float or (2,) array_like, optional Physical sampling of each pixel in the plane. If ``pixelscale`` is a scalar, uniform sampling in x and y is assumed. If None (default), @@ -43,10 +44,17 @@ class Plane: from the boundary of :attr:`mask`. ptype : ptype object Plane type - """ + def __init__(self, amplitude=1, phase=0, mask=None, pixelscale=None, diameter=None, - ptype=None): + ptype=None, **kwargs): + + if 'amp' in kwargs.keys(): + if amplitude != 1: + raise TypeError("Got both 'amplitude' and 'amp', " + "which are aliases of one another") + amplitude = kwargs['amp'] + self.amplitude = np.asarray(amplitude) self.phase = np.asarray(phase) self.mask = mask @@ -560,7 +568,7 @@ class Pupil(Plane): with :func:`~lentil.normalize_power` if conservation of power through a diffraction propagation is required. If not specified, a default amplitude is created which has no effect on wavefront - propagation. + propagation. Can also be specified using the ``amp`` keyword. phase : array_like, optional Phase change caused by plane. If not specified, a default phase is created which has no effect on wavefront propagation. @@ -582,10 +590,10 @@ class Pupil(Plane): """ def __init__(self, focal_length=None, pixelscale=None, amplitude=1, - phase=0, mask=None): + phase=0, mask=None, **kwargs): super().__init__(pixelscale=pixelscale, amplitude=amplitude, phase=phase, - mask=mask, ptype=lentil.pupil) + mask=mask, ptype=lentil.pupil, **kwargs) self.focal_length = focal_length @@ -613,6 +621,19 @@ class Image(Plane): Number of pixels as (rows, cols). If a single value is provided, :class:`Image` is assumed to be square with nrows = ncols = shape. Default is None. + amplitude : array_like, optional + Electric field amplitude transmission. Amplitude should be normalized + with :func:`~lentil.normalize_power` if conservation of power + through a diffraction propagation is required. If not specified, a + default amplitude is created which has no effect on wavefront + propagation. Can also be specified using the ``amp`` keyword. + mask : array_like, optional + Binary mask. If not specified, a mask is created from the amplitude. + + Other Parameters + ---------------- + **kwargs : :class:`Plane` parameters + Keyword arguments passed to :class:`~lentil.Plane` constructor Notes ----- @@ -625,10 +646,11 @@ class Image(Plane): """ - def __init__(self, shape=None, pixelscale=None, amplitude=1, phase=0, - mask=None): - super().__init__(amplitude=amplitude, phase=phase, mask=mask, - pixelscale=pixelscale, ptype=lentil.image) + def __init__(self, shape=None, pixelscale=None, amplitude=1, mask=None, + **kwargs): + super().__init__(amplitude=amplitude, mask=mask, + pixelscale=pixelscale, ptype=lentil.image, + **kwargs) self.shape = shape @property diff --git a/tests/test_plane.py b/tests/test_plane.py index 2cf9463..3d80de5 100644 --- a/tests/test_plane.py +++ b/tests/test_plane.py @@ -1,6 +1,7 @@ import math import numpy as np +import pytest import lentil @@ -26,6 +27,15 @@ def test_default_plane(): assert p.mask == p.amplitude +def test_amp_alias(): + p = lentil.Plane(amp=10) + assert p.amplitude == 10 + + +def test_amp_alias_error(): + with pytest.raises(TypeError): + lentil.Plane(amplitude=10, amp=10) + def test_plane_fit_tilt_inplace(): p = RandomPlane() p_copy = p.fit_tilt(inplace=False)