Skip to content

Commit

Permalink
Cwt: Since PyWavelets#570 has not been merged yet, this is to quickly…
Browse files Browse the repository at this point in the history
… implement the precision option
  • Loading branch information
tien-vo committed Aug 13, 2024
1 parent d95a01f commit 1b451e5
Showing 1 changed file with 31 additions and 24 deletions.
55 changes: 31 additions & 24 deletions pywt/_cwt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from math import ceil, floor

from ._extensions._pywt import (
ContinuousWavelet,
DiscreteContinuousWavelet,
Wavelet,
_check_dtype,
)
from ._extensions._pywt import (ContinuousWavelet, DiscreteContinuousWavelet,
Wavelet, _check_dtype)
from ._functions import integrate_wavelet, scale2frequency
from ._utils import AxisError

Expand All @@ -16,6 +12,7 @@

try:
import scipy

fftmodule = scipy.fft
next_fast_len = fftmodule.next_fast_len
except ImportError:
Expand All @@ -31,10 +28,19 @@ def next_fast_len(n):
following this number to take advantage of FFT speedup.
This fallback is less efficient than `scipy.fftpack.next_fast_len`
"""
return 2**ceil(np.log2(n))


def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
return 2 ** ceil(np.log2(n))


def cwt(
data,
scales,
wavelet,
sampling_period=1.0,
method="conv",
axis=-1,
*,
precision=12,
):
"""
cwt(data, scales, wavelet)
Expand Down Expand Up @@ -70,6 +76,11 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
axis: int, optional
Axis over which to compute the CWT. If not given, the last axis is
used.
precision: int, optional
Length of wavelet (2 ** precision) used to compute the CWT. Greater
will increase resolution, especially for lower and higher scales,
but compute a bit slower. Too low will distort coefficients
and their norms, with a zipper-like effect; recommended >= 12.
Returns
-------
Expand Down Expand Up @@ -125,16 +136,15 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):

dt_out = dt_cplx if wavelet.complex_cwt else dt
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
precision = 10
int_psi, x = integrate_wavelet(wavelet, precision=precision)
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi

# convert int_psi, x to the same precision as the data
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt
int_psi = np.asarray(int_psi, dtype=dt_psi)
x = np.asarray(x, dtype=data.real.dtype)

if method == 'fft':
if method == "fft":
size_scale0 = -1
fft_data = None
elif method != "conv":
Expand All @@ -156,7 +166,7 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
j = np.extract(j < int_psi.size, j)
int_psi_scale = int_psi[j][::-1]

if method == 'conv':
if method == "conv":
if data.ndim == 1:
conv = np.convolve(data, int_psi_scale)
else:
Expand All @@ -172,27 +182,24 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
# - optimal FFT complexity
# - to be larger than the two signals length to avoid circular
# convolution
size_scale = next_fast_len(
data.shape[-1] + int_psi_scale.size - 1
)
size_scale = next_fast_len(data.shape[-1] + int_psi_scale.size - 1)
if size_scale != size_scale0:
# Must recompute fft_data when the padding size changes.
fft_data = fftmodule.fft(data, size_scale, axis=-1)
size_scale0 = size_scale
fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]
conv = conv[..., : data.shape[-1] + int_psi_scale.size - 1]

coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
if out.dtype.kind != 'c':
coef = -np.sqrt(scale) * np.diff(conv, axis=-1)
if out.dtype.kind != "c":
coef = coef.real
# transform axis is always -1 due to the data reshape above
d = (coef.shape[-1] - data.shape[-1]) / 2.
d = (coef.shape[-1] - data.shape[-1]) / 2.0
if d > 0:
coef = coef[..., floor(d):-ceil(d)]
coef = coef[..., floor(d) : -ceil(d)]
elif d < 0:
raise ValueError(
f"Selected scale of {scale} too small.")
raise ValueError(f"Selected scale of {scale} too small.")
if data.ndim > 1:
# restore original data shape and axis position
coef = coef.reshape(data_shape_pre)
Expand Down

0 comments on commit 1b451e5

Please sign in to comment.