Skip to content

Commit

Permalink
Add function for setting PSF, add extra dimension when single image.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Jul 11, 2024
1 parent 25890e0 commit 230af02
Showing 1 changed file with 33 additions and 30 deletions.
63 changes: 33 additions & 30 deletions lensless/recon/rfft_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,39 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", rgb=None, **kwargs):
dtype = torch.float32
else:
dtype = np.float32
self.dtype = dtype

self.pad = pad # Whether necessary to pad provided data
self.set_psf(psf)

def _crop(self, x):
return x[
..., self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1], :
]

def _pad(self, v):
if len(v.shape) == 5:
batch_size = v.shape[0]
shape = [batch_size] + self._padded_shape
elif len(v.shape) == 4:
shape = self._padded_shape
else:
raise ValueError("Expected 4D or 5D tensor")

if self.is_torch:
vpad = torch.zeros(size=shape, dtype=v.dtype, device=v.device)
else:
vpad = np.zeros(shape).astype(v.dtype)
vpad[
..., self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1], :
] = v
return vpad

def set_psf(self, psf):
if self.is_torch:
self._psf = psf.type(dtype)
self._psf = psf.type(self.dtype)
else:
self._psf = psf.astype(dtype)
self._psf = psf.astype(self.dtype)

self._psf_shape = np.array(self._psf.shape)

Expand All @@ -87,45 +115,20 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", rgb=None, **kwargs):
)
self._start_idx = (self._padded_shape[-3:-1] - self._psf_shape[-3:-1]) // 2
self._end_idx = self._start_idx + self._psf_shape[-3:-1]
self.pad = pad # Whether necessary to pad provided data

# precompute filter in frequency domain
if self.is_torch:
self._H = torch.fft.rfft2(
self._pad(self._psf), norm=norm, dim=(-3, -2), s=self._padded_shape[-3:-1]
self._pad(self._psf), norm=self.norm, dim=(-3, -2), s=self._padded_shape[-3:-1]
)
self._Hadj = torch.conj(self._H)
self._padded_data = (
None # This must be reinitialized each time to preserve differentiability
)
else:
self._H = fft.rfft2(self._pad(self._psf), axes=(-3, -2), norm=norm)
self._H = fft.rfft2(self._pad(self._psf), axes=(-3, -2), norm=self.norm)
self._Hadj = np.conj(self._H)
self._padded_data = np.zeros(self._padded_shape).astype(dtype)

self.dtype = dtype

def _crop(self, x):
return x[
..., self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1], :
]

def _pad(self, v):
if len(v.shape) == 5:
batch_size = v.shape[0]
elif len(v.shape) == 4:
batch_size = 1
else:
raise ValueError("Expected 4D or 5D tensor")
shape = [batch_size] + self._padded_shape
if self.is_torch:
vpad = torch.zeros(size=shape, dtype=v.dtype, device=v.device)
else:
vpad = np.zeros(shape).astype(v.dtype)
vpad[
..., self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1], :
] = v
return vpad
self._padded_data = np.zeros(self._padded_shape).astype(self.dtype)

def convolve(self, x):
"""
Expand Down

0 comments on commit 230af02

Please sign in to comment.