Skip to content

Commit

Permalink
Merge pull request #17 from LCAV/tests
Browse files Browse the repository at this point in the history
Add unit tests, and update numpy.
  • Loading branch information
ebezzam authored Mar 21, 2022
2 parents b481beb + 4f45e73 commit 5a9ba47
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 15 deletions.
1 change: 1 addition & 0 deletions format_code.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ black *.py -l 100
black lensless/*.py -l 100
black scripts/*.py -l 100
black profile/*.py -l 100
black test/*.py -l 100
2 changes: 1 addition & 1 deletion lensless/admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def _update(self):
def _form_image(self):
image = self._crop(self._image_est)
image[image < 0] = 0
return image
return image.squeeze()

def set_data(self, data):
if not self._is_rgb:
Expand Down
4 changes: 2 additions & 2 deletions lensless/apgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def set_data(self, data):
self._gen = self._apgd.iterates(n=self._max_iter)

def reset(self):
self._image_est = np.zeros(self._original_size).astype(self._dtype)
self._image_est = np.zeros(self._original_size, dtype=self._dtype)
if self._apgd is not None:
self._apgd.reset()

Expand All @@ -188,7 +188,7 @@ def _progress(self):

def _update(self):
next(self._gen)
self._image_est = self._apgd.iterand["iterand"]
self._image_est[:] = self._apgd.iterand["iterand"]

def _form_image(self):
image = self._image_est.reshape(self._original_shape)
Expand Down
4 changes: 2 additions & 2 deletions lensless/gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def reset(self):
# for online approach could use last reconstruction
psf_flat = self._psf.reshape(-1, self._n_channels)
pixel_start = (np.max(psf_flat, axis=0) + np.min(psf_flat, axis=0)) / 2
x = np.ones(self._psf_shape) * pixel_start
x = np.ones(self._psf_shape, dtype=self._dtype) * pixel_start
self._image_est = self._pad(x)

# spatial frequency response
Expand Down Expand Up @@ -105,7 +105,7 @@ def _update(self):
self._image_est = self._proj(self._image_est)

def _form_image(self):
return self._proj(self._crop(self._image_est))
return self._proj(self._crop(self._image_est)).squeeze()


class NesterovGradientDescent(GradientDescient):
Expand Down
48 changes: 39 additions & 9 deletions lensless/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def load_image(
bayer : bool
Whether input data is Bayer.
blue_gain : float
Blue gain.
Blue gain for color correction.
red_gain : float
Red gain.
Red gain for color correction.
black_level : float
Black level. Default is to use that of Raspberry Pi HQ camera.
ccm : :py:class:`~numpy.ndarray`
Expand Down Expand Up @@ -140,7 +140,22 @@ def load_psf(
Whether to return background level, for removing from data for reconstruction.
flip : bool, optional
Whether to flip up-down and left-right.
verbose
verbose : bool
Whether to print metadata.
bayer : bool
Whether input data is Bayer.
blue_gain : float
Blue gain for color correction.
red_gain : float
Red gain for color correction.
dtype : float32 or float64
Data type of returned data.
nbits_out : int
Output bit depth. Default is to use that of input.
single_psf : bool
Whether to sum RGB channels into single PSF, same across channels. Done
in "Learned reconstructions for practical mask-based lensless imaging"
of Kristina Monakhova et. al.
Returns
-------
Expand Down Expand Up @@ -221,8 +236,6 @@ def load_data(
Full path to PSF file.
data_fp : str
Full path to measurement file.
source : "white", "red", "green", or "blue"
Light source used to measure PSF.
downsample : int or float
Downsampling factor.
bg_pix : tuple, optional
Expand All @@ -231,10 +244,24 @@ def load_data(
recommended.
plot : bool, optional
Whether or not to plot PSF and raw data.
flip : bool, optional
Whether to flip data.
cv : bool, optional
Whether image was saved with OpenCV. If not colors need to be swapped.
flip : bool
Whether to flip data (vertical and horizontal).
bayer : bool
Whether input data is Bayer.
blue_gain : float
Blue gain for color correction.
red_gain : float
Red gain for color correction.
gamma : float, optional
Optional gamma factor to apply, ONLY for plotting. Default is None.
gray : bool
Whether to load as grayscale or RGB.
dtype : float32 or float64
Data type of returned data.
single_psf : bool
Whether to sum RGB channels into single PSF, same across channels. Done
in "Learned reconstructions for practical mask-based lensless imaging"
of Kristina Monakhova et. al.
Returns
-------
Expand Down Expand Up @@ -283,4 +310,7 @@ def load_data(
ax = plot_image(data, gamma=gamma)
ax.set_title("Raw data")

psf = np.array(psf, dtype=dtype)
data = np.array(data, dtype=dtype)

return psf, data
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
python_requires=">=3.6",
install_requires=[
"opencv-python==4.5.1.48",
"numpy==1.20.3",
"numpy>=1.21",
"scipy==1.7.0",
"click==8.0.1",
"image==1.5.33",
Expand Down
9 changes: 9 additions & 0 deletions test/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Install `pytest`
```
pip install pytest
```
And then run
```
pytest test/
```
to run all tests.
35 changes: 35 additions & 0 deletions test/test_algos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
from lensless.io import load_data
from lensless import GradientDescient, NesterovGradientDescent, FISTA, ADMM, APGD


psf_fp = "data/psf/tape_rgb.png"
data_fp = "data/raw_data/thumbs_up_rgb.png"
downsample = 16
n_iter = 5
disp = None


def test_algo():
for algo in [GradientDescient, NesterovGradientDescent, FISTA, ADMM, APGD]:
for gray in [True, False]:
for dtype in [np.float32, np.float64]:
psf, data = load_data(
psf_fp=psf_fp,
data_fp=data_fp,
downsample=downsample,
plot=False,
gray=gray,
dtype=dtype,
)
recon = algo(psf, dtype=dtype)
recon.set_data(data)
res = recon.apply(n_iter=n_iter, disp_iter=None, plot=False)
if gray:
assert len(psf.shape) == 2
else:
assert len(psf.shape) == 3
assert res.dtype == dtype, f"Got {res.dtype}, expected {dtype}"


test_algo()
29 changes: 29 additions & 0 deletions test/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from lensless.io import load_data
import numpy as np

psf_fp = "data/psf/tape_rgb.png"
data_fp = "data/raw_data/thumbs_up_rgb.png"
downsample = 8


def test_load_data():
for gray in [True, False]:
for dtype in [np.float32, np.float64]:
psf, data = load_data(
psf_fp=psf_fp,
data_fp=data_fp,
downsample=downsample,
plot=False,
gray=gray,
dtype=dtype,
)
if gray:
assert len(psf.shape) == 2
else:
assert len(psf.shape) == 3
assert psf.shape == data.shape
assert psf.dtype == dtype, dtype
assert data.dtype == dtype, dtype


test_load_data()

0 comments on commit 5a9ba47

Please sign in to comment.