diff --git a/benchmarks/benchmarks/cwt_benchmarks.py b/benchmarks/benchmarks/cwt_benchmarks.py index f02b63334..eda4e4f2e 100644 --- a/benchmarks/benchmarks/cwt_benchmarks.py +++ b/benchmarks/benchmarks/cwt_benchmarks.py @@ -20,6 +20,7 @@ def setup(self, n, wavelet, max_scale, dtype, method): except ImportError: raise NotImplementedError("cwt not available") self.data = np.ones(n, dtype=dtype) + self.batch_data = np.ones((5, n), dtype=dtype) self.scales = np.arange(1, max_scale + 1) @@ -33,3 +34,12 @@ def time_cwt(self, n, wavelet, max_scale, dtype, method): raise NotImplementedError( "fft-based convolution not available.") pywt.cwt(self.data, self.scales, wavelet) + + def time_cwt_batch(self, n, wavelet, max_scale, dtype, method): + try: + pywt.cwt(self.batch_data, self.scales, wavelet, method=method, + axis=-1) + except TypeError: + # older PyWavelets does not support the axis argument + raise NotImplementedError( + "axis argument not available.") diff --git a/pywt/_cwt.py b/pywt/_cwt.py index 58d1aceb7..743c3f4a2 100644 --- a/pywt/_cwt.py +++ b/pywt/_cwt.py @@ -34,7 +34,7 @@ def next_fast_len(n): return 2**ceil(np.log2(n)) -def cwt(data, scales, wavelet, sampling_period=1., method='conv'): +def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1): """ cwt(data, scales, wavelet) @@ -66,12 +66,16 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv'): The ``fft`` method is ``O(N * log2(N))`` with ``N = len(scale) + len(data) - 1``. It is well suited for large size signals but slightly slower than ``conv`` on small ones. + axis : int, optional + Axis over which to compute the CWT. If not given, the last axis is + used. Returns ------- coefs : array_like Continuous wavelet transform of the input signal for the given scales - and wavelet + and wavelet. The first axis of ``coefs`` corresponds to the scales. + The remaining axes match the shape of ``data``. 
frequencies : array_like If the unit of sampling period are seconds and given, than frequencies are in hertz. Otherwise, a sampling period of 1 is assumed. @@ -112,62 +116,86 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv'): wavelet = DiscreteContinuousWavelet(wavelet) if np.isscalar(scales): scales = np.array([scales]) - if data.ndim == 1: - dt_out = dt_cplx if wavelet.complex_cwt else dt - out = np.empty((np.size(scales), data.size), dtype=dt_out) - precision = 10 - int_psi, x = integrate_wavelet(wavelet, precision=precision) - - # convert int_psi, x to the same precision as the data - dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt - int_psi = np.asarray(int_psi, dtype=dt_psi) - x = np.asarray(x, dtype=data.real.dtype) - - if method == 'fft': - size_scale0 = -1 - fft_data = None - elif not method == 'conv': - raise ValueError("method must be 'conv' or 'fft'") - - for i, scale in enumerate(scales): - step = x[1] - x[0] - j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step) - j = j.astype(int) # floor - if j[-1] >= int_psi.size: - j = np.extract(j < int_psi.size, j) - int_psi_scale = int_psi[j][::-1] - - if method == 'conv': + if not np.isscalar(axis): + raise ValueError("axis must be a scalar.") + + dt_out = dt_cplx if wavelet.complex_cwt else dt + out = np.empty((np.size(scales),) + data.shape, dtype=dt_out) + precision = 10 + int_psi, x = integrate_wavelet(wavelet, precision=precision) + + # convert int_psi, x to the same precision as the data + dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt + int_psi = np.asarray(int_psi, dtype=dt_psi) + x = np.asarray(x, dtype=data.real.dtype) + + if method == 'fft': + size_scale0 = -1 + fft_data = None + elif not method == 'conv': + raise ValueError("method must be 'conv' or 'fft'") + + if data.ndim > 1: + # move axis to be transformed last (so it is contiguous) + data = data.swapaxes(-1, axis) + + # reshape to (n_batch, data.shape[-1]) + data_shape_pre = data.shape + data = 
data.reshape((-1, data.shape[-1])) + + for i, scale in enumerate(scales): + step = x[1] - x[0] + j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step) + j = j.astype(int) # floor + if j[-1] >= int_psi.size: + j = np.extract(j < int_psi.size, j) + int_psi_scale = int_psi[j][::-1] + + if method == 'conv': + if data.ndim == 1: conv = np.convolve(data, int_psi_scale) else: - # The padding is selected for: - # - optimal FFT complexity - # - to be larger than the two signals length to avoid circular - # convolution - size_scale = next_fast_len(data.size + int_psi_scale.size - 1) - if size_scale != size_scale0: - # Must recompute fft_data when the padding size changes. - fft_data = fftmodule.fft(data, size_scale) - size_scale0 = size_scale - fft_wav = fftmodule.fft(int_psi_scale, size_scale) - conv = fftmodule.ifft(fft_wav * fft_data) - conv = conv[:data.size + int_psi_scale.size - 1] - - coef = - np.sqrt(scale) * np.diff(conv) - if out.dtype.kind != 'c': - coef = coef.real - d = (coef.size - data.size) / 2. 
- if d > 0: - out[i, :] = coef[floor(d):-ceil(d)] - elif d == 0.: - out[i, :] = coef - else: - raise ValueError( - "Selected scale of {} too small.".format(scale)) - frequencies = scale2frequency(wavelet, scales, precision) - if np.isscalar(frequencies): - frequencies = np.array([frequencies]) - frequencies /= sampling_period - return out, frequencies - else: - raise ValueError("Only dim == 1 supported") + # batch convolution via loop + conv_shape = list(data.shape) + conv_shape[-1] += int_psi_scale.size - 1 + conv_shape = tuple(conv_shape) + conv = np.empty(conv_shape, dtype=dt_out) + for n in range(data.shape[0]): + conv[n, :] = np.convolve(data[n], int_psi_scale) + else: + # The padding is selected for: + # - optimal FFT complexity + # - to be larger than the two signals length to avoid circular + # convolution + size_scale = next_fast_len( + data.shape[-1] + int_psi_scale.size - 1 + ) + if size_scale != size_scale0: + # Must recompute fft_data when the padding size changes. + fft_data = fftmodule.fft(data, size_scale, axis=-1) + size_scale0 = size_scale + fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1) + conv = fftmodule.ifft(fft_wav * fft_data, axis=-1) + conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1] + + coef = - np.sqrt(scale) * np.diff(conv, axis=-1) + if out.dtype.kind != 'c': + coef = coef.real + # transform axis is always -1 due to the data reshape above + d = (coef.shape[-1] - data.shape[-1]) / 2. + if d > 0: + coef = coef[..., floor(d):-ceil(d)] + elif d < 0: + raise ValueError( + "Selected scale of {} too small.".format(scale)) + if data.ndim > 1: + # restore original data shape and axis position + coef = coef.reshape(data_shape_pre) + coef = coef.swapaxes(axis, -1) + out[i, ...] 
= coef + + frequencies = scale2frequency(wavelet, scales, precision) + if np.isscalar(frequencies): + frequencies = np.array([frequencies]) + frequencies /= sampling_period + return out, frequencies diff --git a/pywt/tests/test_cwt_wavelets.py b/pywt/tests/test_cwt_wavelets.py index acdc8653a..9dcb65162 100644 --- a/pywt/tests/test_cwt_wavelets.py +++ b/pywt/tests/test_cwt_wavelets.py @@ -1,8 +1,10 @@ #!/usr/bin/env python from __future__ import division, print_function, absolute_import +from itertools import product from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal, assert_raises, assert_equal) +import pytest import numpy as np import pywt @@ -344,29 +346,65 @@ def test_cwt_parameters_in_names(): assert_raises(ValueError, func, 'fbsp1-1-1-1') -def test_cwt_complex(): - for dtype, tol in [(np.float32, 1e-5), (np.float64, 1e-13)]: - time, sst = pywt.data.nino() - sst = np.asarray(sst, dtype=dtype) - dt = time[1] - time[0] - wavelet = 'cmor1.5-1.0' - scales = np.arange(1, 32) - - for method in ['conv', 'fft']: - # real-valued tranfsorm as a reference - [cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method) - - # verify same precision - assert_equal(cfs.real.dtype, sst.dtype) - - # complex-valued transform equals sum of the transforms of the real - # and imaginary components - sst_complex = sst + 1j*sst - [cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt, - method=method) - assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol) - # verify dtype is preserved - assert_equal(cfs_complex.dtype, sst_complex.dtype) +@pytest.mark.parametrize('dtype, tol, method', + [(np.float32, 1e-5, 'conv'), + (np.float32, 1e-5, 'fft'), + (np.float64, 1e-13, 'conv'), + (np.float64, 1e-13, 'fft')]) +def test_cwt_complex(dtype, tol, method): + time, sst = pywt.data.nino() + sst = np.asarray(sst, dtype=dtype) + dt = time[1] - time[0] + wavelet = 'cmor1.5-1.0' + scales = np.arange(1, 32) + + # real-valued transform as a reference + [cfs, f] 
= pywt.cwt(sst, scales, wavelet, dt, method=method) + + # verify same precision + assert_equal(cfs.real.dtype, sst.dtype) + + # complex-valued transform equals sum of the transforms of the real + # and imaginary components + sst_complex = sst + 1j*sst + [cfs_complex, f] = pywt.cwt(sst_complex, scales, wavelet, dt, + method=method) + assert_allclose(cfs + 1j*cfs, cfs_complex, atol=tol, rtol=tol) + # verify dtype is preserved + assert_equal(cfs_complex.dtype, sst_complex.dtype) + + +@pytest.mark.parametrize('axis, method', product([0, 1], ['conv', 'fft'])) +def test_cwt_batch(axis, method): + dtype = np.float64 + time, sst = pywt.data.nino() + n_batch = 8 + batch_axis = 1 - axis + sst1 = np.asarray(sst, dtype=dtype) + sst = np.stack((sst1, ) * n_batch, axis=batch_axis) + dt = time[1] - time[0] + wavelet = 'cmor1.5-1.0' + scales = np.arange(1, 32) + + # non-batch transform as reference + [cfs1, f] = pywt.cwt(sst1, scales, wavelet, dt, method=method, axis=axis) + + shape_in = sst.shape + [cfs, f] = pywt.cwt(sst, scales, wavelet, dt, method=method, axis=axis) + + # shape of input is not modified + assert_equal(shape_in, sst.shape) + + # verify same precision + assert_equal(cfs.real.dtype, sst.dtype) + + # verify expected shape + assert_equal(cfs.shape[0], len(scales)) + assert_equal(cfs.shape[1 + batch_axis], n_batch) + assert_equal(cfs.shape[1 + axis], sst.shape[axis]) + + # batch result on stacked input is the same as stacked 1d result + assert_equal(cfs, np.stack((cfs1,) * n_batch, axis=batch_axis + 1)) def test_cwt_small_scales():