From 1b451e5e969ee2f3c5055f5ff42fa6d9e3ce8957 Mon Sep 17 00:00:00 2001
From: Tien Vo
Date: Tue, 13 Aug 2024 09:20:28 -0600
Subject: [PATCH] Cwt: Since #570 has not been merged yet, this is to quickly
 implement the precision option

---
 pywt/_cwt.py | 55 +++++++++++++++++++++++++++++-----------------------
 1 file changed, 31 insertions(+), 24 deletions(-)

diff --git a/pywt/_cwt.py b/pywt/_cwt.py
index 5239e0e0..1e44cf30 100644
--- a/pywt/_cwt.py
+++ b/pywt/_cwt.py
@@ -1,11 +1,7 @@
 from math import ceil, floor
 
-from ._extensions._pywt import (
-    ContinuousWavelet,
-    DiscreteContinuousWavelet,
-    Wavelet,
-    _check_dtype,
-)
+from ._extensions._pywt import (ContinuousWavelet, DiscreteContinuousWavelet,
+                                Wavelet, _check_dtype)
 from ._functions import integrate_wavelet, scale2frequency
 from ._utils import AxisError
 
@@ -16,6 +12,7 @@
 
 try:
     import scipy
+
     fftmodule = scipy.fft
     next_fast_len = fftmodule.next_fast_len
 except ImportError:
@@ -31,10 +28,19 @@ def next_fast_len(n):
         following this number to take advantage of FFT speedup.
         This fallback is less efficient than `scipy.fftpack.next_fast_len`
         """
-        return 2**ceil(np.log2(n))
-
-
-def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
+        return 2 ** ceil(np.log2(n))
+
+
+def cwt(
+    data,
+    scales,
+    wavelet,
+    sampling_period=1.0,
+    method="conv",
+    axis=-1,
+    *,
+    precision=12,
+):
     """
     cwt(data, scales, wavelet)
 
@@ -70,6 +76,11 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
     axis: int, optional
         Axis over which to compute the CWT. If not given, the last axis is
         used.
+    precision: int, optional
+        Length of the wavelet (2 ** precision) used to compute the CWT.
+        Higher values increase resolution, especially at low and high scales,
+        at a small cost in speed. Too low a value distorts the coefficients
+        and their norms (a zipper-like effect); values >= 12 are recommended.
 
     Returns
     -------
@@ -125,16 +136,15 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
     dt_out = dt_cplx if wavelet.complex_cwt else dt
     out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
 
-    precision = 10
     int_psi, x = integrate_wavelet(wavelet, precision=precision)
     int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
 
     # convert int_psi, x to the same precision as the data
-    dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
+    dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt
     int_psi = np.asarray(int_psi, dtype=dt_psi)
     x = np.asarray(x, dtype=data.real.dtype)
 
-    if method == 'fft':
+    if method == "fft":
         size_scale0 = -1
         fft_data = None
     elif method != "conv":
@@ -156,7 +166,7 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
         j = np.extract(j < int_psi.size, j)
         int_psi_scale = int_psi[j][::-1]
 
-        if method == 'conv':
+        if method == "conv":
             if data.ndim == 1:
                 conv = np.convolve(data, int_psi_scale)
             else:
@@ -172,27 +182,24 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
             # - optimal FFT complexity
             # - to be larger than the two signals length to avoid circular
             #   convolution
-            size_scale = next_fast_len(
-                data.shape[-1] + int_psi_scale.size - 1
-            )
+            size_scale = next_fast_len(data.shape[-1] + int_psi_scale.size - 1)
             if size_scale != size_scale0:
                 # Must recompute fft_data when the padding size changes.
                fft_data = fftmodule.fft(data, size_scale, axis=-1)
                size_scale0 = size_scale
            fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
            conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
-            conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]
+            conv = conv[..., : data.shape[-1] + int_psi_scale.size - 1]
 
-        coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
-        if out.dtype.kind != 'c':
+        coef = -np.sqrt(scale) * np.diff(conv, axis=-1)
+        if out.dtype.kind != "c":
             coef = coef.real
         # transform axis is always -1 due to the data reshape above
-        d = (coef.shape[-1] - data.shape[-1]) / 2.
+        d = (coef.shape[-1] - data.shape[-1]) / 2.0
         if d > 0:
-            coef = coef[..., floor(d):-ceil(d)]
+            coef = coef[..., floor(d) : -ceil(d)]
         elif d < 0:
-            raise ValueError(
-                f"Selected scale of {scale} too small.")
+            raise ValueError(f"Selected scale of {scale} too small.")
         if data.ndim > 1:
             # restore original data shape and axis position
             coef = coef.reshape(data_shape_pre)
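
A minimal usage sketch of the patched API (not part of the patch itself): the
signal, scales, and wavelet name below are only illustrative; the keyword-only
`precision` argument is the one thing this change adds.

import numpy as np
import pywt

# Illustrative input: a 32 Hz test tone sampled at 1 kHz, and a range of scales.
t = np.linspace(0, 1, 1000, endpoint=False)
signal = np.cos(2 * np.pi * 32 * t)
scales = np.arange(1, 128)

# With the patch applied, the wavelet is integrated over 2**precision points;
# raising it above the previously hard-coded 10 reduces the zipper-like
# distortion at the smallest and largest scales.
coefs, freqs = pywt.cwt(signal, scales, "morl",
                        sampling_period=t[1] - t[0],
                        method="fft", precision=14)
print(coefs.shape, freqs.shape)  # (127, 1000) (127,)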