diff --git a/format_code.sh b/format_code.sh index 41bf6e10..f3ca3730 100755 --- a/format_code.sh +++ b/format_code.sh @@ -4,3 +4,4 @@ black *.py -l 100 black lensless/*.py -l 100 black scripts/*.py -l 100 black profile/*.py -l 100 +black test/*.py -l 100 diff --git a/lensless/admm.py b/lensless/admm.py index aefc7894..a51bdf18 100644 --- a/lensless/admm.py +++ b/lensless/admm.py @@ -195,7 +195,7 @@ def _update(self): def _form_image(self): image = self._crop(self._image_est) image[image < 0] = 0 - return image + return image.squeeze() def set_data(self, data): if not self._is_rgb: diff --git a/lensless/apgd.py b/lensless/apgd.py index 9892f01e..025f6d6f 100644 --- a/lensless/apgd.py +++ b/lensless/apgd.py @@ -167,7 +167,7 @@ def set_data(self, data): self._gen = self._apgd.iterates(n=self._max_iter) def reset(self): - self._image_est = np.zeros(self._original_size).astype(self._dtype) + self._image_est = np.zeros(self._original_size, dtype=self._dtype) if self._apgd is not None: self._apgd.reset() @@ -188,7 +188,7 @@ def _progress(self): def _update(self): next(self._gen) - self._image_est = self._apgd.iterand["iterand"] + self._image_est[:] = self._apgd.iterand["iterand"] def _form_image(self): image = self._image_est.reshape(self._original_shape) diff --git a/lensless/gradient_descent.py b/lensless/gradient_descent.py index aa941273..a17694f8 100644 --- a/lensless/gradient_descent.py +++ b/lensless/gradient_descent.py @@ -77,7 +77,7 @@ def reset(self): # for online approach could use last reconstruction psf_flat = self._psf.reshape(-1, self._n_channels) pixel_start = (np.max(psf_flat, axis=0) + np.min(psf_flat, axis=0)) / 2 - x = np.ones(self._psf_shape) * pixel_start + x = np.ones(self._psf_shape, dtype=self._dtype) * pixel_start self._image_est = self._pad(x) # spatial frequency response @@ -105,7 +105,7 @@ def _update(self): self._image_est = self._proj(self._image_est) def _form_image(self): - return self._proj(self._crop(self._image_est)) + return self._proj(self._crop(self._image_est)).squeeze() class NesterovGradientDescent(GradientDescient): diff --git a/lensless/io.py b/lensless/io.py index cf445ff9..29976cb9 100644 --- a/lensless/io.py +++ b/lensless/io.py @@ -33,9 +33,9 @@ def load_image( bayer : bool Whether input data is Bayer. blue_gain : float - Blue gain. + Blue gain for color correction. red_gain : float - Red gain. + Red gain for color correction. black_level : float Black level. Default is to use that of Raspberry Pi HQ camera. ccm : :py:class:`~numpy.ndarray` @@ -140,7 +140,22 @@ def load_psf( Whether to return background level, for removing from data for reconstruction. flip : bool, optional Whether to flip up-down and left-right. - verbose + verbose : bool + Whether to print metadata. + bayer : bool + Whether input data is Bayer. + blue_gain : float + Blue gain for color correction. + red_gain : float + Red gain for color correction. + dtype : float32 or float64 + Data type of returned data. + nbits_out : int + Output bit depth. Default is to use that of input. + single_psf : bool + Whether to sum RGB channels into single PSF, same across channels. Done + in "Learned reconstructions for practical mask-based lensless imaging" + of Kristina Monakhova et. al. Returns ------- @@ -221,8 +236,6 @@ def load_data( Full path to PSF file. data_fp : str Full path to measurement file. - source : "white", "red", "green", or "blue" - Light source used to measure PSF. downsample : int or float Downsampling factor. bg_pix : tuple, optional @@ -231,10 +244,24 @@ def load_data( recommended. plot : bool, optional Whether or not to plot PSF and raw data. - flip : bool, optional - Whether to flip data. - cv : bool, optional - Whether image was saved with OpenCV. If not colors need to be swapped. + flip : bool + Whether to flip data (vertical and horizontal). + bayer : bool + Whether input data is Bayer. + blue_gain : float + Blue gain for color correction. + red_gain : float + Red gain for color correction. + gamma : float, optional + Optional gamma factor to apply, ONLY for plotting. Default is None. + gray : bool + Whether to load as grayscale or RGB. + dtype : float32 or float64 + Data type of returned data. + single_psf : bool + Whether to sum RGB channels into single PSF, same across channels. Done + in "Learned reconstructions for practical mask-based lensless imaging" + of Kristina Monakhova et. al. Returns ------- @@ -283,4 +310,7 @@ def load_data( ax = plot_image(data, gamma=gamma) ax.set_title("Raw data") + psf = np.array(psf, dtype=dtype) + data = np.array(data, dtype=dtype) + return psf, data diff --git a/setup.py b/setup.py index 18e4b3ba..462889fd 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ python_requires=">=3.6", install_requires=[ "opencv-python==4.5.1.48", - "numpy==1.20.3", + "numpy>=1.21", "scipy==1.7.0", "click==8.0.1", "image==1.5.33", diff --git a/test/README.md b/test/README.md new file mode 100644 index 00000000..694aca8b --- /dev/null +++ b/test/README.md @@ -0,0 +1,9 @@ +Install `pytest` +``` +pip install pytest +``` +And then run +``` +pytest test/ +``` +to run all tests. \ No newline at end of file diff --git a/test/test_algos.py b/test/test_algos.py new file mode 100644 index 00000000..18a7167c --- /dev/null +++ b/test/test_algos.py @@ -0,0 +1,35 @@ +import numpy as np +from lensless.io import load_data +from lensless import GradientDescient, NesterovGradientDescent, FISTA, ADMM, APGD + + +psf_fp = "data/psf/tape_rgb.png" +data_fp = "data/raw_data/thumbs_up_rgb.png" +downsample = 16 +n_iter = 5 +disp = None + + +def test_algo(): + for algo in [GradientDescient, NesterovGradientDescent, FISTA, ADMM, APGD]: + for gray in [True, False]: + for dtype in [np.float32, np.float64]: + psf, data = load_data( + psf_fp=psf_fp, + data_fp=data_fp, + downsample=downsample, + plot=False, + gray=gray, + dtype=dtype, + ) + recon = algo(psf, dtype=dtype) + recon.set_data(data) + res = recon.apply(n_iter=n_iter, disp_iter=None, plot=False) + if gray: + assert len(psf.shape) == 2 + else: + assert len(psf.shape) == 3 + assert res.dtype == dtype, f"Got {res.dtype}, expected {dtype}" + + +test_algo() diff --git a/test/test_io.py b/test/test_io.py new file mode 100644 index 00000000..e66d0cb4 --- /dev/null +++ b/test/test_io.py @@ -0,0 +1,29 @@ +from lensless.io import load_data +import numpy as np + +psf_fp = "data/psf/tape_rgb.png" +data_fp = "data/raw_data/thumbs_up_rgb.png" +downsample = 8 + + +def test_load_data(): + for gray in [True, False]: + for dtype in [np.float32, np.float64]: + psf, data = load_data( + psf_fp=psf_fp, + data_fp=data_fp, + downsample=downsample, + plot=False, + gray=gray, + dtype=dtype, + ) + if gray: + assert len(psf.shape) == 2 + else: + assert len(psf.shape) == 3 + assert psf.shape == data.shape + assert psf.dtype == dtype, dtype + assert data.dtype == dtype, dtype + + +test_load_data()