Skip to content

Commit

Permalink
Merge pull request #20 from LCAV/fix/reconstruction_shape
Browse files Browse the repository at this point in the history
Fix possible shape mismatch when using RFFT, add RGB APGD example.
  • Loading branch information
ebezzam authored May 31, 2022
2 parents 15b3848 + 88d9ed4 commit 8e4f997
Show file tree
Hide file tree
Showing 11 changed files with 136 additions and 44 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

#### Added

-
- Example of RGB reconstruction with complex-valued FFT: `scripts/recon/apgd_pycsou.py`

#### Changed

-

#### Bugfix

-
- Possible shape mismatch when using the real-valued FFT: forward and backward.

## 1.0.1 - (2022-04-26)

Expand Down
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- [Remote capture](#capture).
- [Remote display](#display).
- [Collecting MNIST](#mnist).
- [Citing this work](#cite).

This package provides functionalities to perform imaging with a lensless camera.
We make use of a low-cost implementation of the DiffuserCam [[1]](#1) where we
Expand Down Expand Up @@ -250,6 +251,19 @@ back and forth):
python scripts/collect_mnist.py --hostname <IP_ADDRESS> --output_dir MNIST_meas
```

## Citing this work <a name="cite"></a>

If you use these tools in your own research, please cite the following:
```
@misc{lenslesspicam,
url = {https://infoscience.epfl.ch/record/294041?&ln=en},
author = {Bezzam, Eric and Kashani, Sepand and Vetterli, Martin and Simeoni, Matthieu},
title = {Lensless{P}i{C}am: A Hardware and Software Platform for Lensless Computational Imaging with a {R}aspberry {P}i},
publisher = {Infoscience},
year = {2022},
}
```

## References
<a id="1">[1]</a>
Antipa, N., Kuo, G., Heckel, R., Mildenhall, B., Bostan, E., Ng, R., & Waller, L. (2018). DiffuserCam: lensless single-exposure 3D imaging. Optica, 5(1), 1-9.
Expand Down
1 change: 1 addition & 0 deletions format_code.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
black *.py -l 100
black lensless/*.py -l 100
black scripts/*.py -l 100
black scripts/recon/*.py -l 100
black profile/*.py -l 100
black test/*.py -l 100
21 changes: 15 additions & 6 deletions lensless/admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
psi=None,
psi_adj=None,
psi_gram=None,
**kwargs
):
"""
Expand Down Expand Up @@ -109,26 +110,34 @@ def _forward(self):
"""Convolution with frequency response."""
return fft.ifftshift(
fft.irfft2(
fft.rfft2(self._image_est, axes=(0, 1)) * self._H,
fft.rfft2(self._image_est, axes=(0, 1), s=self._padded_shape[:2]) * self._H,
axes=(0, 1),
s=self._padded_shape[:2],
),
axes=(0, 1),
)

def _backward(self, x):
"""adjoint of forward / convolution"""
return fft.ifftshift(
fft.irfft2(fft.rfft2(x, axes=(0, 1)) * np.conj(self._H), axes=(0, 1)),
fft.irfft2(
fft.rfft2(x, axes=(0, 1), s=self._padded_shape[:2]) * np.conj(self._H),
axes=(0, 1),
s=self._padded_shape[:2],
),
axes=(0, 1),
)

def reset(self):
# spatial frequency response
self._H = fft.rfft2(self._pad(self._psf), axes=(0, 1)).astype(self._complex_dtype)
self._H = fft.rfft2(self._pad(self._psf), axes=(0, 1), s=self._padded_shape[:2]).astype(
self._complex_dtype
)

self._X = np.zeros(self._padded_shape, dtype=self._dtype)
self._U = np.zeros(np.r_[self._padded_shape, [2]], dtype=self._dtype)
# self._U = np.zeros(np.r_[self._padded_shape, [2]], dtype=self._dtype)
self._image_est = np.zeros_like(self._X)
self._U = np.zeros_like(self._Psi(self._image_est))
self._W = np.zeros_like(self._X)
if self._image_est.max():
# if non-zero
Expand Down Expand Up @@ -160,8 +169,8 @@ def _image_update(self):
+ self._PsiT(self._mu2 * self._U - self._eta)
+ self._backward(self._mu1 * self._X - self._xi)
)
freq_space_result = self._R_divmat * fft.rfft2(rk, axes=(0, 1))
self._image_est = fft.irfft2(freq_space_result, axes=(0, 1))
freq_space_result = self._R_divmat * fft.rfft2(rk, axes=(0, 1), s=self._padded_shape[:2])
self._image_est = fft.irfft2(freq_space_result, axes=(0, 1), s=self._padded_shape[:2])

def _W_update(self):
"""Non-negativity update"""
Expand Down
1 change: 1 addition & 0 deletions lensless/apgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
diff_lambda=0.001,
prox_lambda=0.001,
realconv=True,
**kwargs
):
"""
Wrapper for Pycsou's APGD (accelerated proximal gradient descent)
Expand Down
24 changes: 16 additions & 8 deletions lensless/gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def non_neg(xi):


class GradientDescient(ReconstructionAlgorithm):
def __init__(self, psf, dtype=np.float32, proj=non_neg):
def __init__(self, psf, dtype=np.float32, proj=non_neg, **kwargs):
"""
Object for applying projected gradient descent.
Expand Down Expand Up @@ -81,7 +81,9 @@ def reset(self):
self._image_est = self._pad(x)

# spatial frequency response
self._H = fft.rfft2(self._pad(self._psf), norm="ortho", axes=(0, 1))
self._H = fft.rfft2(
self._pad(self._psf), norm="ortho", axes=(0, 1), s=self._padded_shape[:2]
)
self._Hadj = np.conj(self._H)

Hadj_flat = self._Hadj.reshape(-1, self._n_channels)
Expand All @@ -93,12 +95,18 @@ def _grad(self):
return self._backward(diff)

def _forward(self):
Vk = fft.rfft2(self._image_est, axes=(0, 1))
return self._crop(fft.ifftshift(fft.irfft2(self._H * Vk, axes=(0, 1)), axes=(0, 1)))
Vk = fft.rfft2(self._image_est, axes=(0, 1), s=self._padded_shape[:2])
return self._crop(
fft.ifftshift(
fft.irfft2(self._H * Vk, axes=(0, 1), s=self._padded_shape[:2]), axes=(0, 1)
)
)

def _backward(self, x):
X = fft.rfft2(self._pad(x), axes=(0, 1))
return fft.ifftshift(fft.irfft2(self._Hadj * X, axes=(0, 1)), axes=(0, 1))
X = fft.rfft2(self._pad(x), axes=(0, 1), s=self._padded_shape[:2])
return fft.ifftshift(
fft.irfft2(self._Hadj * X, axes=(0, 1), s=self._padded_shape[:2]), axes=(0, 1)
)

def _update(self):
self._image_est -= self._alpha * self._grad()
Expand All @@ -117,7 +125,7 @@ class NesterovGradientDescent(GradientDescient):
"""

def __init__(self, psf, dtype=np.float32, proj=non_neg, p=0, mu=0.9):
def __init__(self, psf, dtype=np.float32, proj=non_neg, p=0, mu=0.9, **kwargs):
self._p = p
self._mu = mu
super(NesterovGradientDescent, self).__init__(psf, dtype, proj)
Expand All @@ -143,7 +151,7 @@ class FISTA(GradientDescient):
"""

def __init__(self, psf, dtype=np.float32, proj=non_neg, tk=1):
def __init__(self, psf, dtype=np.float32, proj=non_neg, tk=1, **kwargs):

super(FISTA, self).__init__(psf, dtype, proj)
self._tk = tk
Expand Down
4 changes: 3 additions & 1 deletion lensless/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def load_psf(
def load_data(
psf_fp,
data_fp,
downsample,
downsample=None,
bg_pix=(5, 25),
plot=True,
flip=False,
Expand Down Expand Up @@ -277,6 +277,8 @@ def load_data(

assert os.path.isfile(psf_fp)
assert os.path.isfile(data_fp)
if shape is None:
assert downsample is not None

# load and process PSF data
psf, bg = load_psf(
Expand Down
100 changes: 77 additions & 23 deletions scripts/recon/apgd_pycsou.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
"""
Apply Accelerated Proximal Gradient Descent (APDG) with a non-negativity prior
for grayscale reconstruction. Example using Pycsou:
Apply Accelerated Proximal Gradient Descent (APGD) with a desired prior.
Pycsou documentation of (APGD):
https://matthieumeo.github.io/pycsou/html/api/algorithms/pycsou.opt.proxalgs.html?highlight=apgd#pycsou.opt.proxalgs.AcceleratedProximalGradientDescent
Example
Example (default to non-negativity prior):
```
python scripts/recon/apgd_pycsou.py --psf_fp data/psf/tape_rgb.png \
--data_fp data/raw_data/thumbs_up_rgb.png --real_conv
python scripts/recon/apgd_pycsou.py --psf_fp data/psf/tape_rgb.png --data_fp \
data/raw_data/thumbs_up_rgb.png
```
Note that the `RealFFTConvolve2D` has to be implemented in `lensless/realfftconv.py`.
Note that RGB reconstruction will not plot intermediate results as each channel
is solved separately.
Otherwise, grayscale reconstruction with the non-optimized FFT convolution can
be readily used (RGB is not supported):
A faster approach can be applied by implementing `RealFFTConvolve2D` such that
the real-valued FFT is used and the 2-D FFT simulateneously applied across
channels
```
python scripts/recon/apgd_pycsou.py --psf_fp data/psf/tape_rgb.png --data_fp \
data/raw_data/thumbs_up_rgb.png --gray
python scripts/recon/apgd_pycsou.py --psf_fp data/psf/tape_rgb.png \
--data_fp data/raw_data/thumbs_up_rgb.png --real_conv
```
Note that `RealFFTConvolve2D` has to be implemented in `lensless/realfftconv.py`.
If you are an instructor and/or would like the solution, please send an email to
eric[dot]bezzam[at]epfl[dot]ch.
"""

Expand All @@ -26,6 +32,7 @@
import click
import matplotlib.pyplot as plt
from lensless.io import load_data
from lensless.plot import plot_image
from lensless import APGD, APGDPriors
import os
import pathlib as plib
Expand Down Expand Up @@ -141,7 +148,7 @@ def apgd(
no_plot,
single_psf,
real_conv,
shape
shape,
):

plot = not no_plot
Expand All @@ -157,7 +164,7 @@ def apgd(
gamma=gamma,
gray=gray,
single_psf=single_psf,
shape=shape
shape=shape,
)

if save:
Expand All @@ -167,26 +174,73 @@ def apgd(
save = plib.Path(__file__).parent / save
save.mkdir(exist_ok=False)

start_time = time.time()
if prior == APGDPriors.L2:
recon = APGD(
psf=psf, max_iter=max_iter, diff_penalty=prior, prox_penalty=None, realconv=real_conv
)
diff_penalty = prior
prox_penalty = None
else:
diff_penalty = None
prox_penalty = prior

start_time = time.time()

if real_conv or gray:

# for `real_conv` parallelize RGB channels with custom operator
recon = APGD(
psf=psf, max_iter=max_iter, diff_penalty=None, prox_penalty=prior, realconv=real_conv
psf=psf,
max_iter=max_iter,
diff_penalty=diff_penalty,
prox_penalty=prox_penalty,
realconv=real_conv,
)
recon.set_data(data)
print(f"Setup time : {time.time() - start_time} s")
recon.set_data(data)
print(f"Setup time : {time.time() - start_time} s")

start_time = time.time()
res = recon.apply(n_iter=max_iter, disp_iter=disp, save=save, gamma=gamma, plot=not no_plot)
print(f"Processing time : {time.time() - start_time} s")
start_time = time.time()
res = recon.apply(n_iter=max_iter, disp_iter=disp, save=save, gamma=gamma, plot=not no_plot)
print(f"Processing time : {time.time() - start_time} s")

final_img = res[0]

else:

# loop over RGB channels (naive approach with complex-valued FFT)
recon = [
APGD(
psf=psf[:, :, i],
max_iter=max_iter,
diff_penalty=diff_penalty,
prox_penalty=prox_penalty,
realconv=real_conv,
)
for i in range(psf.shape[2])
]
[recon[i].set_data(data[:, :, i]) for i in range(data.shape[2])]
print(f"Setup time : {time.time() - start_time} s")

start_time = time.time()
final_img = []
print("Looping over channels...")
for i in range(data.shape[2]):
print(f"-- channel {i}", end="")
final_img.append(
recon[i].apply(
n_iter=max_iter, disp_iter=max_iter + 1, save=False, gamma=gamma, plot=False
)
)
print(f", {time.time() - start_time} s")
print(f"Processing time : {time.time() - start_time} s")

final_img = np.transpose(np.array(final_img), (1, 2, 0))
ax = plot_image(final_img, gamma=gamma)
ax.set_title("Final reconstruction after {} iterations".format(max_iter))
if save:
plt.savefig(plib.Path(save) / "final_reconstruction.png")

if not no_plot:
plt.show()
if save:
np.save(plib.Path(save) / "final_reconstruction.npy", res[0])
np.save(plib.Path(save) / "final_reconstruction.npy", final_img)
print(f"Files saved to : {save}")


Expand Down
2 changes: 1 addition & 1 deletion scripts/recon/gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def gradient_descent(
save,
no_plot,
single_psf,
shape
shape,
):

psf, data = load_data(
Expand Down
4 changes: 2 additions & 2 deletions scripts/recon/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def reconstruction(
save,
no_plot,
single_psf,
shape
shape,
):
psf, data = load_data(
psf_fp=psf_fp,
Expand All @@ -130,7 +130,7 @@ def reconstruction(
gamma=gamma,
gray=gray,
single_psf=single_psf,
shape=shape
shape=shape,
)

if save:
Expand Down
5 changes: 4 additions & 1 deletion test/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def test_algo():
gray=gray,
dtype=dtype,
)
recon = algo(psf, dtype=dtype)
if algo == APGD:
if not gray:
continue
recon = algo(psf, dtype=dtype, realconv=False)
recon.set_data(data)
res = recon.apply(n_iter=n_iter, disp_iter=None, plot=False)
if gray:
Expand Down

0 comments on commit 8e4f997

Please sign in to comment.