diff --git a/.travis.yml b/.travis.yml index 0965345f5..11df85f61 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,12 +19,15 @@ matrix: - CYTHONSPEC=cython - USE_WHEEL=1 - os: linux - python: 3.7-dev + python: 3.7 + dist: xenial # travis-ci/travis-ci/issues/9815 + sudo: true env: - NUMPYSPEC=numpy - MATPLOTLIBSPEC=matplotlib - CYTHONSPEC=cython - USE_SDIST=1 + - USE_SCIPY=1 - os: linux python: 3.5 env: @@ -59,6 +62,7 @@ before_install: - pip install pytest pytest-cov coverage codecov futures - set -o pipefail - if [ "${USE_WHEEL}" == "1" ]; then pip install wheel; fi + - if [ "${USE_SCIPY}" == "1" ]; then pip install scipy; fi - | if [ "${REFGUIDE_CHECK}" == "1" ]; then pip install sphinx numpydoc diff --git a/README.rst b/README.rst index 3cde6e1b6..d6d1cb2c0 100644 --- a/README.rst +++ b/README.rst @@ -65,9 +65,11 @@ For more usage examples see the `demo`_ directory in the source package. Installation ------------ -PyWavelets supports `Python`_ >=3.5, and is only dependent on `Numpy`_ +PyWavelets supports `Python`_ >=3.5, and is only dependent on `NumPy`_ (supported versions are currently ``>= 1.13.3``). To pass all of the tests, -`Matplotlib`_ is also required. +`Matplotlib`_ is also required. `SciPy`_ is also an optional dependency. When +present, FFT-based continuous wavelet transforms will use FFTs from SciPy +rather than NumPy. There are binary wheels for Intel Linux, Windows and macOS / OSX on PyPi. If you are on one of these platforms, you should get a binary (precompiled) @@ -138,7 +140,8 @@ If you wish to cite PyWavelets in a publication, you may use the following DOI. .. _Anaconda: https://www.continuum.io .. _GitHub: https://github.com/PyWavelets/pywt .. _GitHub Issues: https://github.com/PyWavelets/pywt/issues -.. _Numpy: http://www.numpy.org +.. _NumPy: https://www.numpy.org +.. _SciPy: https://www.scipy.org .. _original developer: http://en.ig.ma .. _Python: http://python.org/ .. _Python Package Index: http://pypi.python.org/pypi/PyWavelets/ diff --git a/benchmarks/benchmarks/cwt_benchmarks.py b/benchmarks/benchmarks/cwt_benchmarks.py index cf9cacd35..0c063b3bf 100644 --- a/benchmarks/benchmarks/cwt_benchmarks.py +++ b/benchmarks/benchmarks/cwt_benchmarks.py @@ -6,12 +6,13 @@ class CwtTimeSuiteBase(object): """ Set-up for CWT timing. """ - params = ([32, 128, 512], + params = ([32, 128, 512, 2048], ['cmor', 'cgau4', 'fbsp', 'gaus4', 'mexh', 'morl', 'shan'], - [16, 64, 256]) - param_names = ('n', 'wavelet', 'max_scale') + [16, 64, 256], + ['conv', 'fft']) + param_names = ('n', 'wavelet', 'max_scale', 'method') - def setup(self, n, wavelet, max_scale): + def setup(self, n, wavelet, max_scale, method): try: from pywt import cwt except ImportError: @@ -21,5 +22,12 @@ def setup(self, n, wavelet, max_scale): class CwtTimeSuite(CwtTimeSuiteBase): - def time_cwt(self, n, wavelet, max_scale): - pywt.cwt(self.data, self.scales, wavelet) + def time_cwt(self, n, wavelet, max_scale, method): + try: + pywt.cwt(self.data, self.scales, wavelet, method=method) + except TypeError: + # older PyWavelets does not support use of the method argument + if method == 'fft': + raise NotImplementedError( + "fft-based convolution not available.") + pywt.cwt(self.data, self.scales, wavelet) diff --git a/doc/source/common_refs.rst b/doc/source/common_refs.rst index a8331b221..8fe04b957 100644 --- a/doc/source/common_refs.rst +++ b/doc/source/common_refs.rst @@ -5,7 +5,8 @@ .. _GitHub: https://github.com/PyWavelets/pywt .. _GitHub repository: https://github.com/PyWavelets/pywt .. _GitHub Issues: https://github.com/PyWavelets/pywt/issues -.. _Numpy: http://www.numpy.org +.. _NumPy: https://www.numpy.org +.. _SciPy: https://www.scipy.org .. _original developer: http://en.ig.ma .. _Python: http://python.org/ .. _Python Package Index: http://pypi.python.org/pypi/PyWavelets/ diff --git a/doc/source/install.rst b/doc/source/install.rst index 7a82d00aa..f880a8ea8 100644 --- a/doc/source/install.rst +++ b/doc/source/install.rst @@ -39,11 +39,12 @@ PyWavelets source code directory (containing ``setup.py``) and type:: The requirements needed to build from source are: - Python_ 2.7 or >=3.4 - - Numpy_ >= 1.13.3 + - NumPy_ >= 1.13.3 - Cython_ >= 0.23.5 (if installing from git, not from a PyPI source release) To run all the tests for PyWavelets, you will also need to install the -Matplotlib_ package. +Matplotlib_ package. If SciPy_ is available, FFT-based continuous wavelet +transforms will use the FFT implementation from SciPy instead of NumPy. .. seealso:: :ref:`Development guide ` section contains more information on building and installing from source code. diff --git a/pywt/_cwt.py b/pywt/_cwt.py index a4e6ca536..de6c4b24d 100644 --- a/pywt/_cwt.py +++ b/pywt/_cwt.py @@ -1,13 +1,40 @@ -import numpy as np +from math import floor, ceil from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet, Wavelet, _check_dtype) from ._functions import integrate_wavelet, scale2frequency + __all__ = ["cwt"] -def cwt(data, scales, wavelet, sampling_period=1.): +import numpy as np + +try: + # Prefer scipy.fft (new in SciPy 1.4) + import scipy.fft + fftmodule = scipy.fft + next_fast_len = fftmodule.next_fast_len +except ImportError: + try: + import scipy.fftpack + fftmodule = scipy.fftpack + next_fast_len = fftmodule.next_fast_len + except ImportError: + fftmodule = np.fft + + # provide a fallback so scipy is an optional requirement + def next_fast_len(n): + """Round up size to the nearest power of two. + + Given a number of samples `n`, returns the next power of two + following this number to take advantage of FFT speedup. + This fallback is less efficient than `scipy.fftpack.next_fast_len` + """ + return 2**ceil(np.log2(n)) + + +def cwt(data, scales, wavelet, sampling_period=1., method='conv'): """ cwt(data, scales, wavelet) @@ -29,6 +56,16 @@ def cwt(data, scales, wavelet, sampling_period=1.): The values computed for ``coefs`` are independent of the choice of ``sampling_period`` (i.e. ``scales`` is not scaled by the sampling period). + method : {'conv', 'fft'}, optional + The method used to compute the CWT. Can be any of: + - ``conv`` uses ``numpy.convolve``. + - ``fft`` uses frequency domain convolution. + - ``auto`` uses automatic selection based on an estimate of the + computational complexity at each scale. + The ``conv`` method complexity is ``O(len(scale) * len(data))``. + The ``fft`` method is ``O(N * log2(N))`` with + ``N = len(scale) + len(data) - 1``. It is well suited for large size + signals but slightly slower than ``conv`` on small ones. Returns ------- @@ -74,34 +111,59 @@ def cwt(data, scales, wavelet, sampling_period=1.): wavelet = DiscreteContinuousWavelet(wavelet) if np.isscalar(scales): scales = np.array([scales]) + dt_out = None # TODO: fix in/out dtype consistency in a subsequent PR if data.ndim == 1: if wavelet.complex_cwt: - out = np.zeros((np.size(scales), data.size), dtype=complex) - else: - out = np.zeros((np.size(scales), data.size)) + dt_out = complex + out = np.empty((np.size(scales), data.size), dtype=dt_out) precision = 10 int_psi, x = integrate_wavelet(wavelet, precision=precision) - for i in np.arange(np.size(scales)): + + if method == 'fft': + size_scale0 = -1 + fft_data = None + elif not method == 'conv': + raise ValueError("method must be 'conv' or 'fft'") + + for i, scale in enumerate(scales): step = x[1] - x[0] - j = np.floor( - np.arange(scales[i] * (x[-1] - x[0]) + 1) / (scales[i] * step)) - if np.max(j) >= np.size(int_psi): - j = np.delete(j, np.where((j >= np.size(int_psi)))[0]) - coef = - np.sqrt(scales[i]) * np.diff( - np.convolve(data, int_psi[j.astype(np.int)][::-1])) + j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step) + j = j.astype(int) # floor + if j[-1] >= int_psi.size: + j = np.extract(j < int_psi.size, j) + int_psi_scale = int_psi[j][::-1] + + if method == 'conv': + conv = np.convolve(data, int_psi_scale) + else: + # The padding is selected for: + # - optimal FFT complexity + # - to be larger than the two signals length to avoid circular + # convolution + size_scale = next_fast_len(data.size + int_psi_scale.size - 1) + if size_scale != size_scale0: + # Must recompute fft_data when the padding size changes. + fft_data = fftmodule.fft(data, size_scale) + size_scale0 = size_scale + fft_wav = fftmodule.fft(int_psi_scale, size_scale) + conv = fftmodule.ifft(fft_wav * fft_data) + conv = conv[:data.size + int_psi_scale.size - 1] + + coef = - np.sqrt(scale) * np.diff(conv) + if not np.iscomplexobj(out): + coef = np.real(coef) d = (coef.size - data.size) / 2. if d > 0: - out[i, :] = coef[int(np.floor(d)):int(-np.ceil(d))] + out[i, :] = coef[floor(d):-ceil(d)] elif d == 0.: out[i, :] = coef else: raise ValueError( - "Selected scale of {} too small.".format(scales[i])) + "Selected scale of {} too small.".format(scale)) frequencies = scale2frequency(wavelet, scales, precision) if np.isscalar(frequencies): frequencies = np.array([frequencies]) - for i in np.arange(len(frequencies)): - frequencies[i] /= sampling_period + frequencies /= sampling_period return out, frequencies else: raise ValueError("Only dim == 1 supported") diff --git a/pywt/tests/test_cwt_wavelets.py b/pywt/tests/test_cwt_wavelets.py index 4372efc6a..e2cafcb2a 100644 --- a/pywt/tests/test_cwt_wavelets.py +++ b/pywt/tests/test_cwt_wavelets.py @@ -371,3 +371,18 @@ def test_cwt_small_scales(): # extremely short scale factors raise a ValueError assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh') + + +def test_cwt_method_fft(): + rstate = np.random.RandomState(1) + data = rstate.randn(50) + data[15] = 1. + scales = np.arange(1, 64) + wavelet = 'cmor1.5-1.0' + + # build a reference cwt with the legacy np.conv() method + cfs_conv, _ = pywt.cwt(data, scales, wavelet, method='conv') + + # compare with the fft based convolution + cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft') + assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13)