Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ matrix:
- CYTHONSPEC=cython
- USE_WHEEL=1
- os: linux
python: 3.7-dev
python: 3.7
dist: xenial # travis-ci/travis-ci/issues/9815
sudo: true
env:
- NUMPYSPEC=numpy
- MATPLOTLIBSPEC=matplotlib
- CYTHONSPEC=cython
- USE_SDIST=1
- USE_SCIPY=1
- os: linux
python: 3.5
env:
Expand Down Expand Up @@ -59,6 +62,7 @@ before_install:
- pip install pytest pytest-cov coverage codecov futures
- set -o pipefail
- if [ "${USE_WHEEL}" == "1" ]; then pip install wheel; fi
- if [ "${USE_SCIPY}" == "1" ]; then pip install scipy; fi
- |
if [ "${REFGUIDE_CHECK}" == "1" ]; then
pip install sphinx numpydoc
Expand Down
9 changes: 6 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ For more usage examples see the `demo`_ directory in the source package.
Installation
------------

PyWavelets supports `Python`_ >=3.5, and is only dependent on `Numpy`_
PyWavelets supports `Python`_ >=3.5, and is only dependent on `NumPy`_
(supported versions are currently ``>= 1.13.3``). To pass all of the tests,
`Matplotlib`_ is also required.
`Matplotlib`_ is also required. `SciPy`_ is also an optional dependency. When
present, FFT-based continuous wavelet transforms will use FFTs from SciPy
rather than NumPy.

There are binary wheels for Intel Linux, Windows and macOS / OSX on PyPi. If
you are on one of these platforms, you should get a binary (precompiled)
Expand Down Expand Up @@ -138,7 +140,8 @@ If you wish to cite PyWavelets in a publication, you may use the following DOI.
.. _Anaconda: https://www.continuum.io
.. _GitHub: https://github.com/PyWavelets/pywt
.. _GitHub Issues: https://github.com/PyWavelets/pywt/issues
.. _Numpy: http://www.numpy.org
.. _NumPy: https://www.numpy.org
.. _SciPy: https://www.scipy.org
.. _original developer: http://en.ig.ma
.. _Python: http://python.org/
.. _Python Package Index: http://pypi.python.org/pypi/PyWavelets/
Expand Down
20 changes: 14 additions & 6 deletions benchmarks/benchmarks/cwt_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ class CwtTimeSuiteBase(object):
"""
Set-up for CWT timing.
"""
params = ([32, 128, 512],
params = ([32, 128, 512, 2048],
['cmor', 'cgau4', 'fbsp', 'gaus4', 'mexh', 'morl', 'shan'],
[16, 64, 256])
param_names = ('n', 'wavelet', 'max_scale')
[16, 64, 256],
['conv', 'fft'])
param_names = ('n', 'wavelet', 'max_scale', 'method')

def setup(self, n, wavelet, max_scale):
def setup(self, n, wavelet, max_scale, method):
try:
from pywt import cwt
except ImportError:
Expand All @@ -21,5 +22,12 @@ def setup(self, n, wavelet, max_scale):


class CwtTimeSuite(CwtTimeSuiteBase):
def time_cwt(self, n, wavelet, max_scale):
pywt.cwt(self.data, self.scales, wavelet)
def time_cwt(self, n, wavelet, max_scale, method):
try:
pywt.cwt(self.data, self.scales, wavelet, method=method)
except TypeError:
# older PyWavelets does not support use of the method argument
if method == 'fft':
raise NotImplementedError(
"fft-based convolution not available.")
pywt.cwt(self.data, self.scales, wavelet)
3 changes: 2 additions & 1 deletion doc/source/common_refs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
.. _GitHub: https://github.com/PyWavelets/pywt
.. _GitHub repository: https://github.com/PyWavelets/pywt
.. _GitHub Issues: https://github.com/PyWavelets/pywt/issues
.. _Numpy: http://www.numpy.org
.. _NumPy: https://www.numpy.org
.. _SciPy: https://www.scipy.org
.. _original developer: http://en.ig.ma
.. _Python: http://python.org/
.. _Python Package Index: http://pypi.python.org/pypi/PyWavelets/
Expand Down
5 changes: 3 additions & 2 deletions doc/source/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ PyWavelets source code directory (containing ``setup.py``) and type::
The requirements needed to build from source are:

- Python_ 2.7 or >=3.4
- Numpy_ >= 1.13.3
- NumPy_ >= 1.13.3
- Cython_ >= 0.23.5 (if installing from git, not from a PyPI source release)

To run all the tests for PyWavelets, you will also need to install the
Matplotlib_ package.
Matplotlib_ package. If SciPy_ is available, FFT-based continuous wavelet
transforms will use the FFT implementation from SciPy instead of NumPy.

.. seealso:: :ref:`Development guide <dev-index>` section contains more
information on building and installing from source code.
Expand Down
94 changes: 78 additions & 16 deletions pywt/_cwt.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,40 @@
import numpy as np
from math import floor, ceil

from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet,
Wavelet, _check_dtype)
from ._functions import integrate_wavelet, scale2frequency


__all__ = ["cwt"]


def cwt(data, scales, wavelet, sampling_period=1.):
import numpy as np

try:
# Prefer scipy.fft (new in SciPy 1.4)
import scipy.fft
Comment thread
grlee77 marked this conversation as resolved.
fftmodule = scipy.fft
next_fast_len = fftmodule.next_fast_len
except ImportError:
try:
import scipy.fftpack
fftmodule = scipy.fftpack
next_fast_len = fftmodule.next_fast_len
except ImportError:
fftmodule = np.fft

# provide a fallback so scipy is an optional requirement
def next_fast_len(n):
"""Round up size to the nearest power of two.

Given a number of samples `n`, returns the next power of two
following this number to take advantage of FFT speedup.
This fallback is less efficient than `scipy.fftpack.next_fast_len`
"""
return 2**ceil(np.log2(n))


def cwt(data, scales, wavelet, sampling_period=1., method='conv'):
"""
cwt(data, scales, wavelet)

Expand All @@ -29,6 +56,16 @@ def cwt(data, scales, wavelet, sampling_period=1.):
The values computed for ``coefs`` are independent of the choice of
``sampling_period`` (i.e. ``scales`` is not scaled by the sampling
period).
method : {'conv', 'fft'}, optional
The method used to compute the CWT. Can be any of:
- ``conv`` uses ``numpy.convolve``.
- ``fft`` uses frequency domain convolution.
- ``auto`` uses automatic selection based on an estimate of the
computational complexity at each scale.
The ``conv`` method complexity is ``O(len(scale) * len(data))``.
The ``fft`` method is ``O(N * log2(N))`` with
``N = len(scale) + len(data) - 1``. It is well suited for large size
signals but slightly slower than ``conv`` on small ones.

Returns
-------
Expand Down Expand Up @@ -74,34 +111,59 @@ def cwt(data, scales, wavelet, sampling_period=1.):
wavelet = DiscreteContinuousWavelet(wavelet)
if np.isscalar(scales):
scales = np.array([scales])
dt_out = None # TODO: fix in/out dtype consistency in a subsequent PR
if data.ndim == 1:
if wavelet.complex_cwt:
out = np.zeros((np.size(scales), data.size), dtype=complex)
else:
out = np.zeros((np.size(scales), data.size))
dt_out = complex
out = np.empty((np.size(scales), data.size), dtype=dt_out)
precision = 10
int_psi, x = integrate_wavelet(wavelet, precision=precision)
for i in np.arange(np.size(scales)):

if method == 'fft':
size_scale0 = -1
fft_data = None
elif not method == 'conv':
raise ValueError("method must be 'conv' or 'fft'")

for i, scale in enumerate(scales):
step = x[1] - x[0]
j = np.floor(
np.arange(scales[i] * (x[-1] - x[0]) + 1) / (scales[i] * step))
if np.max(j) >= np.size(int_psi):
j = np.delete(j, np.where((j >= np.size(int_psi)))[0])
coef = - np.sqrt(scales[i]) * np.diff(
np.convolve(data, int_psi[j.astype(np.int)][::-1]))
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
j = j.astype(int) # floor
if j[-1] >= int_psi.size:
j = np.extract(j < int_psi.size, j)
int_psi_scale = int_psi[j][::-1]

if method == 'conv':
conv = np.convolve(data, int_psi_scale)
else:
# The padding is selected for:
# - optimal FFT complexity
# - to be larger than the two signals length to avoid circular
# convolution
size_scale = next_fast_len(data.size + int_psi_scale.size - 1)
if size_scale != size_scale0:
# Must recompute fft_data when the padding size changes.
fft_data = fftmodule.fft(data, size_scale)
size_scale0 = size_scale
fft_wav = fftmodule.fft(int_psi_scale, size_scale)
conv = fftmodule.ifft(fft_wav * fft_data)
conv = conv[:data.size + int_psi_scale.size - 1]

coef = - np.sqrt(scale) * np.diff(conv)
if not np.iscomplexobj(out):
coef = np.real(coef)
d = (coef.size - data.size) / 2.
if d > 0:
out[i, :] = coef[int(np.floor(d)):int(-np.ceil(d))]
out[i, :] = coef[floor(d):-ceil(d)]
elif d == 0.:
out[i, :] = coef
else:
raise ValueError(
"Selected scale of {} too small.".format(scales[i]))
"Selected scale of {} too small.".format(scale))
frequencies = scale2frequency(wavelet, scales, precision)
if np.isscalar(frequencies):
frequencies = np.array([frequencies])
for i in np.arange(len(frequencies)):
frequencies[i] /= sampling_period
frequencies /= sampling_period
return out, frequencies
else:
raise ValueError("Only dim == 1 supported")
15 changes: 15 additions & 0 deletions pywt/tests/test_cwt_wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,18 @@ def test_cwt_small_scales():

# extremely short scale factors raise a ValueError
assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh')


def test_cwt_method_fft():
Comment thread
alsauve marked this conversation as resolved.
rstate = np.random.RandomState(1)
data = rstate.randn(50)
data[15] = 1.
scales = np.arange(1, 64)
wavelet = 'cmor1.5-1.0'

# build a reference cwt with the legacy np.conv() method
cfs_conv, _ = pywt.cwt(data, scales, wavelet, method='conv')

# compare with the fft based convolution
cfs_fft, _ = pywt.cwt(data, scales, wavelet, method='fft')
assert_allclose(cfs_conv, cfs_fft, rtol=0, atol=1e-13)