Skip to content

Commit

Permalink
Fix Parallel wavelet run by having a queue of transformers (#105)
Browse files Browse the repository at this point in the history
* Added codes and tests for parallel WaveletN

* Fix PEP8 Errors

Co-authored-by: chaithyagr <[email protected]>
  • Loading branch information
chaithyagr and chaithyagr authored Jul 6, 2020
1 parent db04fe2 commit 3bccf77
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 87 deletions.
42 changes: 29 additions & 13 deletions mri/operators/linear/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pysap.base.utils import unflatten

# Third party import
import joblib
from joblib import Parallel, delayed
import numpy as np

Expand Down Expand Up @@ -52,28 +53,39 @@ def __init__(self, wavelet_name, nb_scale=4, verbose=0, dim=2,
self.unflatten = unflatten
self.n_jobs = n_jobs
self.n_coils = n_coils
if self.n_coils == 1 and self.n_jobs != 1:
print("Making n_jobs = 1 for WaveletN as n_coils = 1")
self.n_jobs = 1
self.backend = backend
self.verbose = verbose
if wavelet_name not in pysap.AVAILABLE_TRANSFORMS:
raise ValueError(
"Unknown transformation '{0}'.".format(wavelet_name))
transform_klass = pysap.load_transform(wavelet_name)
self.transform = transform_klass(
nb_scale=self.nb_scale, verbose=verbose, dim=dim, **kwargs)
self.transform_queue = []
n_proc = self.n_jobs
if n_proc < 0:
n_proc = joblib.cpu_count() + self.n_jobs + 1
# Create transform queue for parallel execution
for i in range(min(n_proc, self.n_coils)):
self.transform_queue.append(transform_klass(
nb_scale=self.nb_scale,
verbose=verbose,
dim=dim,
**kwargs)
)
self.coeffs_shape = None

def get_coeff(self):
return self.transform.analysis_data

def set_coeff(self, coeffs):
self.transform.analysis_data = coeffs

def _op(self, data):
if isinstance(data, np.ndarray):
data = pysap.Image(data=data)
self.transform.data = data
self.transform.analysis()
coeffs, coeffs_shape = flatten(self.transform.analysis_data)
# Get the transform from queue
transform = self.transform_queue.pop()
transform.data = data
transform.analysis()
coeffs, coeffs_shape = flatten(transform.analysis_data)
# Add back the transform to the queue
self.transform_queue.append(transform)
return coeffs, coeffs_shape

def op(self, data):
Expand Down Expand Up @@ -124,8 +136,12 @@ def _adj_op(self, coeffs, coeffs_shape, dtype="array"):
data: ndarray
the reconstructed data.
"""
self.transform.analysis_data = unflatten(coeffs, coeffs_shape)
image = self.transform.synthesis()
# Get the transform from queue
transform = self.transform_queue.pop()
transform.analysis_data = unflatten(coeffs, coeffs_shape)
image = transform.synthesis()
# Add back the transform to the queue
self.transform_queue.append(transform)
if dtype == "array":
return image.data
return image
Expand Down
156 changes: 82 additions & 74 deletions mri/tests/test_wavelet_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,101 +28,109 @@ def setUp(self):
"""
self.N = 64
self.max_iter = 10
self.num_channels = 10
self.num_channels = [1, 10]

def test_Wavelet2D_ISAP(self):
"""Test the adjoint operator for the 2D Wavelet transform
"""
for i in range(self.max_iter):
print("Process Wavelet2D_ISAP test '{0}'...", i)
wavelet_op_adj = WaveletN(wavelet_name="HaarWaveletTransform",
nb_scale=4)
Img = (np.random.randn(self.N, self.N) +
1j * np.random.randn(self.N, self.N))
f_p = wavelet_op_adj.op(Img)
f = (np.random.randn(*f_p.shape) +
1j * np.random.randn(*f_p.shape))
I_p = wavelet_op_adj.adj_op(f)
x_d = np.vdot(Img, I_p)
x_ad = np.vdot(f_p, f)
np.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
for ch in self.num_channels:
print("Testing with Num Channels : " + str(ch))
for i in range(self.max_iter):
print("Process Wavelet2D_ISAP test '{0}'...", i)
wavelet_op_adj = WaveletN(
wavelet_name="HaarWaveletTransform",
nb_scale=4,
n_coils=ch,
n_jobs=2
)
Img = np.squeeze(np.random.randn(ch, self.N, self.N) +
1j * np.random.randn(ch, self.N, self.N))
f_p = wavelet_op_adj.op(Img)
f = (np.random.randn(*f_p.shape) +
1j * np.random.randn(*f_p.shape))
I_p = wavelet_op_adj.adj_op(f)
x_d = np.vdot(Img, I_p)
x_ad = np.vdot(f_p, f)
np.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
print(" Wavelet2 adjoint test passes")

def test_Wavelet2D_PyWt(self):
"""Test the adjoint operator for the 2D Wavelet transform
"""
for i in range(self.max_iter):
print("Process Wavelet2D PyWt test '{0}'...", i)
wavelet_op_adj = WaveletN(wavelet_name="sym8",
nb_scale=4)
Img = (np.random.randn(self.N, self.N) +
1j * np.random.randn(self.N, self.N))
f_p = wavelet_op_adj.op(Img)
f = (np.random.randn(*f_p.shape) +
1j * np.random.randn(*f_p.shape))
I_p = wavelet_op_adj.adj_op(f)
x_d = np.vdot(Img, I_p)
x_ad = np.vdot(f_p, f)
np.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
for ch in self.num_channels:
print("Testing with Num Channels : " + str(ch))
for i in range(self.max_iter):
print("Process Wavelet2D PyWt test '{0}'...", i)
wavelet_op_adj = WaveletN(
wavelet_name="sym8",
nb_scale=4,
n_coils=ch,
n_jobs=2
)
Img = np.squeeze(
np.random.randn(ch, self.N, self.N) +
1j * np.random.randn(ch, self.N, self.N)
)
f_p = wavelet_op_adj.op(Img)
f = (np.random.randn(*f_p.shape) +
1j * np.random.randn(*f_p.shape))
I_p = wavelet_op_adj.adj_op(f)
x_d = np.vdot(Img, I_p)
x_ad = np.vdot(f_p, f)
np.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
print(" Wavelet2 adjoint test passes")

def test_Wavelet3D_PyWt(self):
"""Test the adjoint operator for the 3D Wavelet transform
"""
for i in range(self.max_iter):
print("Process Wavelet3D PyWt test '{0}'...", i)
wavelet_op_adj = WaveletN(wavelet_name="sym8",
nb_scale=4, dim=3,
padding_mode='periodization')
Img = (np.random.randn(self.N, self.N, self.N) +
1j * np.random.randn(self.N, self.N, self.N))
f_p = wavelet_op_adj.op(Img)
f = (np.random.randn(*f_p.shape) +
1j * np.random.randn(*f_p.shape))
I_p = wavelet_op_adj.adj_op(f)
x_d = np.vdot(Img, I_p)
x_ad = np.vdot(f_p, f)
np.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
for ch in self.num_channels:
print("Testing with Num Channels : " + str(ch))
for i in range(self.max_iter):
print("Process Wavelet3D PyWt test '{0}'...", i)
wavelet_op_adj = WaveletN(
wavelet_name="sym8",
nb_scale=4,
dim=3,
padding_mode='periodization',
n_coils=ch,
n_jobs=-1,
)
Img = np.squeeze(
np.random.randn(ch, self.N, self.N, self.N) +
1j * np.random.randn(ch, self.N, self.N, self.N)
)
f_p = wavelet_op_adj.op(Img)
f = (np.random.randn(*f_p.shape) +
1j * np.random.randn(*f_p.shape))
I_p = wavelet_op_adj.adj_op(f)
x_d = np.vdot(Img, I_p)
x_ad = np.vdot(f_p, f)
np.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
print(" Wavelet3 adjoint test passes")

def test_Wavelet_UD_2D(self):
"""Test the adjoint operation for Undecimated wavelet
"""
for i in range(self.max_iter):
print("Process Wavelet Undecimated test '{0}'...", i)
wavelet_op = WaveletUD2(nb_scale=4)
img = (np.random.randn(self.N, self.N) +
1j * np.random.randn(self.N, self.N))
f_p = wavelet_op.op(img)
f = (np.random.randn(*f_p.shape) +
1j * np.random.randn(*f_p.shape))
i_p = wavelet_op.adj_op(f)
x_d = np.vdot(img, i_p)
x_ad = np.vdot(f_p, f)
np.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
for ch in self.num_channels:
print("Testing with Num Channels : " + str(ch))
for i in range(self.max_iter):
print("Process Wavelet Undecimated test '{0}'...", i)
wavelet_op = WaveletUD2(
nb_scale=4,
n_coils=ch,
n_jobs=2,
)
img = np.squeeze(np.random.randn(ch, self.N, self.N) +
1j * np.random.randn(ch, self.N, self.N))
f_p = wavelet_op.op(img)
f = (np.random.randn(*f_p.shape) +
1j * np.random.randn(*f_p.shape))
i_p = wavelet_op.adj_op(f)
x_d = np.vdot(img, i_p)
x_ad = np.vdot(f_p, f)
np.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
print("Undecimated Wavelet 2D adjoint test passes")

def test_Wavelet_UD_2D_Multichannel(self):
"""Test the adjoint operation for Undecmated wavelet Transform in
multichannel case"""
for i in range(self.max_iter):
print("Process Wavelet Undecimated test '{0}'...", i)
wavelet_op = WaveletUD2(
nb_scale=4,
n_coils=self.num_channels,
n_jobs=2
)
img = (np.random.randn(self.num_channels, self.N, self.N) +
1j * np.random.randn(self.num_channels, self.N, self.N))
f_p = wavelet_op.op(img)
f = (np.random.randn(*f_p.shape) +
1j * np.random.randn(*f_p.shape))
i_p = wavelet_op.adj_op(f)
x_d = np.vdot(img, i_p)
x_ad = np.vdot(f_p, f)
np.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
print("Undecimated Wavelet 2D adjoint test passes for multichannel")


if __name__ == "__main__":
unittest.main()

0 comments on commit 3bccf77

Please sign in to comment.