diff --git a/sigpy/wavelet.py b/sigpy/wavelet.py index ea1cc60f..82fcf26d 100644 --- a/sigpy/wavelet.py +++ b/sigpy/wavelet.py @@ -147,22 +147,25 @@ def iwt(input, oshape, wave_name='db4'): pad_shape = [(oshape[k] + rec_hi.size - 1 + (oshape[k] + rec_hi.size - 1) % 2) if k in axes else oshape[k] \ for k in range(len(oshape))] + inputdct = {} approxdct = {} for key in list(input.keys()): if (int(key[:4]) < max_level): - approxdct[key] = input.pop(key) + approxdct[key] = input[key] + else: + inputdct[key] = input[key] if approxdct != {}: approxkey = [elm for elm in approxdct.keys() if 'H' not in elm].pop() approxkey = "%04d%s" % (max_level, approxkey[4:]) ashape = [pad_shape[k]//2 if k in axes else oshape[k] for k in range(len(oshape))] - input[approxkey] = iwt(approxdct, ashape, wave_name=wave_name) + inputdct[approxkey] = iwt(approxdct, ashape, wave_name=wave_name) res = xp.zeros(oshape) X = xp.zeros(pad_shape).astype(xp.complex64) f = xp.ones(pad_shape).astype(xp.complex64) - for key in input.keys(): - X[sampleidx] = input[key]; + for key in inputdct.keys(): + X[sampleidx] = inputdct[key]; X = xp.fft.fftn(X, axes=axes) subkeys = list(key)[4:] for k in range(len(subkeys)):