From 3bccf776e52a7fa1feb9247b9cf0932387dc39fa Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Mon, 6 Jul 2020 16:56:18 +0200 Subject: [PATCH] Fix Parallel wavelet run by having a queue of transformers (#105) * Added codes and tests for parallel WaveletN * Fix PEP8 Errors Co-authored-by: chaithyagr --- mri/operators/linear/wavelet.py | 42 +++++--- mri/tests/test_wavelet_adjoint.py | 156 ++++++++++++++++-------------- 2 files changed, 111 insertions(+), 87 deletions(-) diff --git a/mri/operators/linear/wavelet.py b/mri/operators/linear/wavelet.py index 184c5566..c7f5d774 100644 --- a/mri/operators/linear/wavelet.py +++ b/mri/operators/linear/wavelet.py @@ -20,6 +20,7 @@ from pysap.base.utils import unflatten # Third party import +import joblib from joblib import Parallel, delayed import numpy as np @@ -52,28 +53,39 @@ def __init__(self, wavelet_name, nb_scale=4, verbose=0, dim=2, self.unflatten = unflatten self.n_jobs = n_jobs self.n_coils = n_coils + if self.n_coils == 1 and self.n_jobs != 1: + print("Making n_jobs = 1 for WaveletN as n_coils = 1") + self.n_jobs = 1 self.backend = backend self.verbose = verbose if wavelet_name not in pysap.AVAILABLE_TRANSFORMS: raise ValueError( "Unknown transformation '{0}'.".format(wavelet_name)) transform_klass = pysap.load_transform(wavelet_name) - self.transform = transform_klass( - nb_scale=self.nb_scale, verbose=verbose, dim=dim, **kwargs) + self.transform_queue = [] + n_proc = self.n_jobs + if n_proc < 0: + n_proc = joblib.cpu_count() + self.n_jobs + 1 + # Create transform queue for parallel execution + for i in range(min(n_proc, self.n_coils)): + self.transform_queue.append(transform_klass( + nb_scale=self.nb_scale, + verbose=verbose, + dim=dim, + **kwargs) + ) self.coeffs_shape = None - def get_coeff(self): - return self.transform.analysis_data - - def set_coeff(self, coeffs): - self.transform.analysis_data = coeffs - def _op(self, data): if isinstance(data, np.ndarray): data = pysap.Image(data=data) - self.transform.data = data - self.transform.analysis() - coeffs, coeffs_shape = flatten(self.transform.analysis_data) + # Get the transform from queue + transform = self.transform_queue.pop() + transform.data = data + transform.analysis() + coeffs, coeffs_shape = flatten(transform.analysis_data) + # Add back the transform to the queue + self.transform_queue.append(transform) return coeffs, coeffs_shape def op(self, data): @@ -124,8 +136,12 @@ def _adj_op(self, coeffs, coeffs_shape, dtype="array"): data: ndarray the reconstructed data. """ - self.transform.analysis_data = unflatten(coeffs, coeffs_shape) - image = self.transform.synthesis() + # Get the transform from queue + transform = self.transform_queue.pop() + transform.analysis_data = unflatten(coeffs, coeffs_shape) + image = transform.synthesis() + # Add back the transform to the queue + self.transform_queue.append(transform) if dtype == "array": return image.data return image diff --git a/mri/tests/test_wavelet_adjoint.py b/mri/tests/test_wavelet_adjoint.py index c3233207..9d83a310 100644 --- a/mri/tests/test_wavelet_adjoint.py +++ b/mri/tests/test_wavelet_adjoint.py @@ -28,101 +28,109 @@ def setUp(self): """ self.N = 64 self.max_iter = 10 - self.num_channels = 10 + self.num_channels = [1, 10] def test_Wavelet2D_ISAP(self): """Test the adjoint operator for the 2D Wavelet transform """ - for i in range(self.max_iter): - print("Process Wavelet2D_ISAP test '{0}'...", i) - wavelet_op_adj = WaveletN(wavelet_name="HaarWaveletTransform", - nb_scale=4) - Img = (np.random.randn(self.N, self.N) + - 1j * np.random.randn(self.N, self.N)) - f_p = wavelet_op_adj.op(Img) - f = (np.random.randn(*f_p.shape) + - 1j * np.random.randn(*f_p.shape)) - I_p = wavelet_op_adj.adj_op(f) - x_d = np.vdot(Img, I_p) - x_ad = np.vdot(f_p, f) - np.testing.assert_allclose(x_d, x_ad, rtol=1e-6) + for ch in self.num_channels: + print("Testing with Num Channels : " + str(ch)) + for i in range(self.max_iter): + print("Process Wavelet2D_ISAP test '{0}'...", i) + wavelet_op_adj = WaveletN( + wavelet_name="HaarWaveletTransform", + nb_scale=4, + n_coils=ch, + n_jobs=2 + ) + Img = np.squeeze(np.random.randn(ch, self.N, self.N) + + 1j * np.random.randn(ch, self.N, self.N)) + f_p = wavelet_op_adj.op(Img) + f = (np.random.randn(*f_p.shape) + + 1j * np.random.randn(*f_p.shape)) + I_p = wavelet_op_adj.adj_op(f) + x_d = np.vdot(Img, I_p) + x_ad = np.vdot(f_p, f) + np.testing.assert_allclose(x_d, x_ad, rtol=1e-6) print(" Wavelet2 adjoint test passes") def test_Wavelet2D_PyWt(self): """Test the adjoint operator for the 2D Wavelet transform """ - for i in range(self.max_iter): - print("Process Wavelet2D PyWt test '{0}'...", i) - wavelet_op_adj = WaveletN(wavelet_name="sym8", - nb_scale=4) - Img = (np.random.randn(self.N, self.N) + - 1j * np.random.randn(self.N, self.N)) - f_p = wavelet_op_adj.op(Img) - f = (np.random.randn(*f_p.shape) + - 1j * np.random.randn(*f_p.shape)) - I_p = wavelet_op_adj.adj_op(f) - x_d = np.vdot(Img, I_p) - x_ad = np.vdot(f_p, f) - np.testing.assert_allclose(x_d, x_ad, rtol=1e-6) + for ch in self.num_channels: + print("Testing with Num Channels : " + str(ch)) + for i in range(self.max_iter): + print("Process Wavelet2D PyWt test '{0}'...", i) + wavelet_op_adj = WaveletN( + wavelet_name="sym8", + nb_scale=4, + n_coils=ch, + n_jobs=2 + ) + Img = np.squeeze( + np.random.randn(ch, self.N, self.N) + + 1j * np.random.randn(ch, self.N, self.N) + ) + f_p = wavelet_op_adj.op(Img) + f = (np.random.randn(*f_p.shape) + + 1j * np.random.randn(*f_p.shape)) + I_p = wavelet_op_adj.adj_op(f) + x_d = np.vdot(Img, I_p) + x_ad = np.vdot(f_p, f) + np.testing.assert_allclose(x_d, x_ad, rtol=1e-6) print(" Wavelet2 adjoint test passes") def test_Wavelet3D_PyWt(self): """Test the adjoint operator for the 3D Wavelet transform """ - for i in range(self.max_iter): - print("Process Wavelet3D PyWt test '{0}'...", i) - wavelet_op_adj = WaveletN(wavelet_name="sym8", - nb_scale=4, dim=3, - padding_mode='periodization') - Img = (np.random.randn(self.N, self.N, self.N) + - 1j * np.random.randn(self.N, self.N, self.N)) - f_p = wavelet_op_adj.op(Img) - f = (np.random.randn(*f_p.shape) + - 1j * np.random.randn(*f_p.shape)) - I_p = wavelet_op_adj.adj_op(f) - x_d = np.vdot(Img, I_p) - x_ad = np.vdot(f_p, f) - np.testing.assert_allclose(x_d, x_ad, rtol=1e-6) + for ch in self.num_channels: + print("Testing with Num Channels : " + str(ch)) + for i in range(self.max_iter): + print("Process Wavelet3D PyWt test '{0}'...", i) + wavelet_op_adj = WaveletN( + wavelet_name="sym8", + nb_scale=4, + dim=3, + padding_mode='periodization', + n_coils=ch, + n_jobs=-1, + ) + Img = np.squeeze( + np.random.randn(ch, self.N, self.N, self.N) + + 1j * np.random.randn(ch, self.N, self.N, self.N) + ) + f_p = wavelet_op_adj.op(Img) + f = (np.random.randn(*f_p.shape) + + 1j * np.random.randn(*f_p.shape)) + I_p = wavelet_op_adj.adj_op(f) + x_d = np.vdot(Img, I_p) + x_ad = np.vdot(f_p, f) + np.testing.assert_allclose(x_d, x_ad, rtol=1e-6) print(" Wavelet3 adjoint test passes") def test_Wavelet_UD_2D(self): """Test the adjoint operation for Undecimated wavelet """ - for i in range(self.max_iter): - print("Process Wavelet Undecimated test '{0}'...", i) - wavelet_op = WaveletUD2(nb_scale=4) - img = (np.random.randn(self.N, self.N) + - 1j * np.random.randn(self.N, self.N)) - f_p = wavelet_op.op(img) - f = (np.random.randn(*f_p.shape) + - 1j * np.random.randn(*f_p.shape)) - i_p = wavelet_op.adj_op(f) - x_d = np.vdot(img, i_p) - x_ad = np.vdot(f_p, f) - np.testing.assert_allclose(x_d, x_ad, rtol=1e-6) + for ch in self.num_channels: + print("Testing with Num Channels : " + str(ch)) + for i in range(self.max_iter): + print("Process Wavelet Undecimated test '{0}'...", i) + wavelet_op = WaveletUD2( + nb_scale=4, + n_coils=ch, + n_jobs=2, + ) + img = np.squeeze(np.random.randn(ch, self.N, self.N) + + 1j * np.random.randn(ch, self.N, self.N)) + f_p = wavelet_op.op(img) + f = (np.random.randn(*f_p.shape) + + 1j * np.random.randn(*f_p.shape)) + i_p = wavelet_op.adj_op(f) + x_d = np.vdot(img, i_p) + x_ad = np.vdot(f_p, f) + np.testing.assert_allclose(x_d, x_ad, rtol=1e-6) print("Undecimated Wavelet 2D adjoint test passes") - def test_Wavelet_UD_2D_Multichannel(self): - """Test the adjoint operation for Undecmated wavelet Transform in - multichannel case""" - for i in range(self.max_iter): - print("Process Wavelet Undecimated test '{0}'...", i) - wavelet_op = WaveletUD2( - nb_scale=4, - n_coils=self.num_channels, - n_jobs=2 - ) - img = (np.random.randn(self.num_channels, self.N, self.N) + - 1j * np.random.randn(self.num_channels, self.N, self.N)) - f_p = wavelet_op.op(img) - f = (np.random.randn(*f_p.shape) + - 1j * np.random.randn(*f_p.shape)) - i_p = wavelet_op.adj_op(f) - x_d = np.vdot(img, i_p) - x_ad = np.vdot(f_p, f) - np.testing.assert_allclose(x_d, x_ad, rtol=1e-6) - print("Undecimated Wavelet 2D adjoint test passes for multichannel") - if __name__ == "__main__": unittest.main()