From b797160e1acec9460b47dbdc459ae218d17655da Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Wed, 7 Aug 2024 16:51:11 +0000 Subject: [PATCH] Formatting. --- lensless/recon/multi_wiener.py | 5 +++-- lensless/recon/trainable_inversion.py | 2 -- lensless/recon/trainable_recon.py | 1 + 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lensless/recon/multi_wiener.py b/lensless/recon/multi_wiener.py index cb53d5a8..cadad0c4 100644 --- a/lensless/recon/multi_wiener.py +++ b/lensless/recon/multi_wiener.py @@ -96,7 +96,7 @@ def __init__( skip_pre=False, ): """ - Constructor for Multi-Wiener Deconvolution Network (MWDN) as proposed in: + Constructor for Multi-Wiener Deconvolution Network (MWDN) as proposed in: https://opg.optica.org/oe/fulltext.cfm?uri=oe-31-23-39088&id=541387 Parameters @@ -169,7 +169,7 @@ def __init__( def _prepare_process_block(self, process): """ Method for preparing the pre or post process block. - + Parameters ---------- process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional @@ -283,6 +283,7 @@ def reconstruction_error(self, prediction, lensless): convolver = self._convolver if not convolver.pad: prediction = convolver._pad(prediction) + Fx = convolver.convolve(prediction) Fy = lensless diff --git a/lensless/recon/trainable_inversion.py b/lensless/recon/trainable_inversion.py index a4e82cf0..6dba7880 100644 --- a/lensless/recon/trainable_inversion.py +++ b/lensless/recon/trainable_inversion.py @@ -9,8 +9,6 @@ class TrainableInversion(TrainableReconstructionAlgorithm): - """ """ - def __init__(self, psf, dtype=None, K=1e-4, **kwargs): """ Constructor for trainable inversion component as proposed in diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index 77474730..f106633d 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -92,6 +92,7 @@ def __init__( compensation_residual : bool, optional Whether to use residual connection in compensation layer. """ + assert isinstance(psf, torch.Tensor), "PSF must be a torch.Tensor" super(TrainableReconstructionAlgorithm, self).__init__( psf, dtype=dtype, n_iter=n_iter, **kwargs