From 230af029275403cd05a614a9140b436981a9b889 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Thu, 11 Jul 2024 13:09:06 +0000 Subject: [PATCH] Add function for setting PSF, add extra dimension when single image. --- lensless/recon/rfft_convolve.py | 63 +++++++++++++++++---------------- 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/lensless/recon/rfft_convolve.py b/lensless/recon/rfft_convolve.py index c0a58236..e7b9be74 100644 --- a/lensless/recon/rfft_convolve.py +++ b/lensless/recon/rfft_convolve.py @@ -71,11 +71,39 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", rgb=None, **kwargs): dtype = torch.float32 else: dtype = np.float32 + self.dtype = dtype + + self.pad = pad # Whether necessary to pad provided data + self.set_psf(psf) + + def _crop(self, x): + return x[ + ..., self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1], : + ] + + def _pad(self, v): + if len(v.shape) == 5: + batch_size = v.shape[0] + shape = [batch_size] + self._padded_shape + elif len(v.shape) == 4: + shape = self._padded_shape + else: + raise ValueError("Expected 4D or 5D tensor") + + if self.is_torch: + vpad = torch.zeros(size=shape, dtype=v.dtype, device=v.device) + else: + vpad = np.zeros(shape).astype(v.dtype) + vpad[ + ..., self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1], : + ] = v + return vpad + def set_psf(self, psf): if self.is_torch: - self._psf = psf.type(dtype) + self._psf = psf.type(self.dtype) else: - self._psf = psf.astype(dtype) + self._psf = psf.astype(self.dtype) self._psf_shape = np.array(self._psf.shape) @@ -87,45 +115,20 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", rgb=None, **kwargs): ) self._start_idx = (self._padded_shape[-3:-1] - self._psf_shape[-3:-1]) // 2 self._end_idx = self._start_idx + self._psf_shape[-3:-1] - self.pad = pad # Whether necessary to pad provided data # precompute filter in frequency domain if self.is_torch: self._H = torch.fft.rfft2( - self._pad(self._psf), norm=norm, dim=(-3, -2), s=self._padded_shape[-3:-1] + self._pad(self._psf), norm=self.norm, dim=(-3, -2), s=self._padded_shape[-3:-1] ) self._Hadj = torch.conj(self._H) self._padded_data = ( None # This must be reinitialized each time to preserve differentiability ) else: - self._H = fft.rfft2(self._pad(self._psf), axes=(-3, -2), norm=norm) + self._H = fft.rfft2(self._pad(self._psf), axes=(-3, -2), norm=self.norm) self._Hadj = np.conj(self._H) - self._padded_data = np.zeros(self._padded_shape).astype(dtype) - - self.dtype = dtype - - def _crop(self, x): - return x[ - ..., self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1], : - ] - - def _pad(self, v): - if len(v.shape) == 5: - batch_size = v.shape[0] - elif len(v.shape) == 4: - batch_size = 1 - else: - raise ValueError("Expected 4D or 5D tensor") - shape = [batch_size] + self._padded_shape - if self.is_torch: - vpad = torch.zeros(size=shape, dtype=v.dtype, device=v.device) - else: - vpad = np.zeros(shape).astype(v.dtype) - vpad[ - ..., self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1], : - ] = v - return vpad + self._padded_data = np.zeros(self._padded_shape).astype(self.dtype) def convolve(self, x): """