diff --git a/.coveragerc b/.coveragerc
index fbc396796..b6fef2d7e 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -5,6 +5,7 @@ include = */pywt/*
omit =
*/version.py
*/pywt/tests/*
+ */pywt/_doc_utils.py*
*/pywt/data/create_dat.py
*.pxd
stringsource
diff --git a/.travis.yml b/.travis.yml
index d95198efc..0965345f5 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -56,7 +56,7 @@ before_install:
- pip install $NUMPYSPEC
- pip install $MATPLOTLIBSPEC
- pip install $CYTHONSPEC
- - pip install nose coverage codecov futures
+ - pip install pytest pytest-cov coverage codecov futures
- set -o pipefail
- if [ "${USE_WHEEL}" == "1" ]; then pip install wheel; fi
- |
@@ -72,7 +72,7 @@ script:
pip wheel . -v
pip install PyWavelets*.whl -v
pushd demo
- nosetests pywt
+ pytest --pyargs pywt
python ../pywt/tests/test_doc.py
popd
elif [ "${USE_SDIST}" == "1" ]; then
@@ -80,7 +80,7 @@ script:
# Move out of source directory to avoid finding local pywt
pushd dist
pip install PyWavelets* -v
- nosetests pywt
+ pytest --pyargs pywt
python ../pywt/tests/test_doc.py
popd
elif [ "${REFGUIDE_CHECK}" == "1" ]; then
@@ -88,7 +88,10 @@ script:
python util/refguide_check.py --doctests
else
CFLAGS="--coverage" python setup.py build --build-lib build/lib/ --build-temp build/tmp/
- nosetests build/lib/ --tests pywt/tests
+ CFLAGS="--coverage" pip install -e . -v
+ pushd demo
+ pytest --pyargs pywt --cov=pywt
+ popd
fi
after_success:
diff --git a/LICENSE b/LICENSE
index 60a926c44..47b60f4b6 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,5 +1,5 @@
Copyright (c) 2006-2012 Filip Wasilewski
-Copyright (c) 2012-2017 The PyWavelets Developers
+Copyright (c) 2012-2019 The PyWavelets Developers
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
@@ -18,3 +18,15 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+
+
+The PyWavelets repository and source distributions bundle some code that is
+adapted from compatibly licensed projects. We list these here.
+
+Name: NumPy
+Files: pywt/_pytesttester.py
+License: 3-clause BSD
+
+Name: SciPy
+Files: setup.py, util/*
+License: 3-clause BSD
diff --git a/appveyor.yml b/appveyor.yml
index 137296493..4d508e345 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -23,11 +23,13 @@ install:
- "util\\appveyor\\build.cmd %PYTHON%\\python.exe -m pip install
numpy --cache-dir c:\\tmp\\pip-cache"
- "util\\appveyor\\build.cmd %PYTHON%\\python.exe -m pip install
- Cython nose coverage matplotlib futures --cache-dir c:\\tmp\\pip-cache"
+ Cython pytest coverage matplotlib futures --cache-dir c:\\tmp\\pip-cache"
test_script:
- - "util\\appveyor\\build.cmd %PYTHON%\\python.exe setup.py build --build-lib build\\lib\\"
- - "%PYTHON%\\Scripts\\nosetests build\\lib --tests pywt\\tests"
+ - "util\\appveyor\\build.cmd %PYTHON%\\python.exe -m pip install -e . -v"
+ - "cd demo"
+ - "%PYTHON%\\Scripts\\pytest --pyargs pywt"
+ - "cd .."
after_test:
- "util\\appveyor\\build.cmd %PYTHON%\\python.exe setup.py bdist_wheel"
diff --git a/doc/source/dev/testing.rst b/doc/source/dev/testing.rst
index 532f7327f..12b090634 100644
--- a/doc/source/dev/testing.rst
+++ b/doc/source/dev/testing.rst
@@ -23,13 +23,22 @@ does not break the build.
Running tests locally
---------------------
-Tests are implemented with `nose`_, so use one of:
+Tests are implemented with `pytest`_, so use one of:
- $ nosetests pywt
+ $ pytest --pyargs pywt -v
- >>> pywt.test() # doctest: +SKIP
+There are also older doctests that can be run by performing the following from
+the root of the project source.
-Note doctests require `Matplotlib`_ in addition to the usual dependencies.
+ $ python pywt/tests/test_doc.py
+ $ cd doc
+ $ make doctest
+
+Additionally the examples in the demo subfolder can be checked by running:
+
+ $ python util/refguide_check.py
+
+Note: doctests require `Matplotlib`_ in addition to the usual dependencies.
Running tests with Tox
@@ -43,6 +52,6 @@ To for example run tests for Python 3.5 and 3.6 use::
For more information see the `Tox`_ documentation.
-.. _nose: http://nose.readthedocs.org/en/latest/
-.. _Tox: http://tox.testrun.org/
-.. _Matplotlib: http://matplotlib.org
+.. _pytest: https://pytest.org
+.. _Tox: https://tox.readthedocs.io/en/latest/
+.. _Matplotlib: https://matplotlib.org
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 000000000..7c6a703d9
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,15 @@
+[pytest]
+addopts = -l
+norecursedirs = doc tools pywt/_extensions
+doctest_optionflags = NORMALIZE_WHITESPACE ELLIPSIS ALLOW_UNICODE ALLOW_BYTES
+
+filterwarnings =
+ error
+# Filter out annoying import messages.
+ ignore:Not importing directory
+ ignore:numpy.dtype size changed
+ ignore:numpy.ufunc size changed
+ ignore::UserWarning:cpuinfo,
+
+env =
+ PYTHONHASHSEED=0
diff --git a/pywt/__init__.py b/pywt/__init__.py
index 13f196f01..618bd55bc 100644
--- a/pywt/__init__.py
+++ b/pywt/__init__.py
@@ -35,11 +35,6 @@
from pywt.version import version as __version__
-import numpy as np
-if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
- from ._utils import is_nose_running
- if is_nose_running():
- np.set_printoptions(legacy='1.13')
-
-from numpy.testing import Tester
-test = Tester().test
+from ._pytesttester import PytestTester
+test = PytestTester(__name__)
+del PytestTester
diff --git a/pywt/_pytest.py b/pywt/_pytest.py
new file mode 100644
index 000000000..cfc9f0590
--- /dev/null
+++ b/pywt/_pytest.py
@@ -0,0 +1,68 @@
+"""common test-related code."""
+import os
+import sys
+import multiprocessing
+import numpy as np
+import pytest
+
+
+__all__ = ['uses_matlab', # skip if pymatbridge and Matlab unavailable
+ 'uses_futures', # skip if futures unavailable
+ 'uses_pymatbridge', # skip if no PYWT_XSLOW environment variable
+ 'uses_precomputed', # skip if PYWT_XSLOW environment variable found
+ 'matlab_result_dict_cwt', # dict with precomputed Matlab dwt data
+ 'matlab_result_dict_dwt', # dict with precomputed Matlab cwt data
+ 'futures', # the futures module or None
+ 'max_workers', # the number of workers available to futures
+ 'size_set', # the set of Matlab tests to run
+ ]
+
+try:
+ if sys.version_info[0] == 2:
+ import futures
+ else:
+ from concurrent import futures
+ max_workers = multiprocessing.cpu_count()
+ futures_available = True
+except ImportError:
+ futures_available = False
+ futures = None
+
+# check if pymatbridge + MATLAB tests should be run
+matlab_result_dict_dwt = None
+matlab_result_dict_cwt = None
+matlab_missing = True
+use_precomputed = True
+size_set = 'reduced'
+if 'PYWT_XSLOW' in os.environ:
+ try:
+ from pymatbridge import Matlab
+ mlab = Matlab()
+ matlab_missing = False
+ use_precomputed = False
+ size_set = 'full'
+ except ImportError:
+ print("To run Matlab compatibility tests you need to have MathWorks "
+ "MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
+ "package installed.")
+if use_precomputed:
+ # load dictionaries of precomputed results
+ data_dir = os.path.join(os.path.dirname(__file__), 'tests', 'data')
+ matlab_data_file_cwt = os.path.join(
+ data_dir, 'cwt_matlabR2015b_result.npz')
+ matlab_result_dict_cwt = np.load(matlab_data_file_cwt)
+
+ matlab_data_file_dwt = os.path.join(
+ data_dir, 'dwt_matlabR2012a_result.npz')
+ matlab_result_dict_dwt = np.load(matlab_data_file_dwt)
+
+uses_futures = pytest.mark.skipif(
+ not futures_available, reason='futures not available')
+uses_matlab = pytest.mark.skipif(
+ matlab_missing, reason='pymatbridge and/or Matlab not available')
+uses_pymatbridge = pytest.mark.skipif(
+ use_precomputed,
+ reason='PYWT_XSLOW set: skipping tests against precomputed Matlab results')
+uses_precomputed = pytest.mark.skipif(
+ not use_precomputed,
+ reason='PYWT_XSLOW not set: test against precomputed matlab tests')
diff --git a/pywt/_pytesttester.py b/pywt/_pytesttester.py
new file mode 100644
index 000000000..426047a71
--- /dev/null
+++ b/pywt/_pytesttester.py
@@ -0,0 +1,164 @@
+"""
+Pytest test running.
+
+This module implements the ``test()`` function for NumPy modules. The usual
+boiler plate for doing that is to put the following in the module
+``__init__.py`` file::
+
+ from pywt._pytesttester import PytestTester
+ test = PytestTester(__name__).test
+ del PytestTester
+
+
+Warnings filtering and other runtime settings should be dealt with in the
+``pytest.ini`` file in the pywt repo root. The behavior of the test depends on
+whether or not that file is found as follows:
+
+* ``pytest.ini`` is present (develop mode)
+ All warnings except those explicily filtered out are raised as error.
+* ``pytest.ini`` is absent (release mode)
+ DeprecationWarnings and PendingDeprecationWarnings are ignored, other
+ warnings are passed through.
+
+In practice, tests run from the PyWavelets repo are run in develop mode. That
+includes the standard ``python runtests.py`` invocation.
+
+"""
+from __future__ import division, absolute_import, print_function
+
+import sys
+import os
+
+__all__ = ['PytestTester']
+
+
+def _show_pywt_info():
+ import pywt
+ from pywt._c99_config import _have_c99_complex
+ print("PyWavelets version %s" % pywt.__version__)
+ if _have_c99_complex:
+ print("Compiled with C99 complex support.")
+ else:
+ print("Compiled without C99 complex support.")
+
+
+class PytestTester(object):
+ """
+ Pytest test runner.
+
+ This class is made available in ``pywt.testing``, and a test function
+ is typically added to a package's __init__.py like so::
+
+ from pywt.testing import PytestTester
+ test = PytestTester(__name__).test
+ del PytestTester
+
+ Calling this test function finds and runs all tests associated with the
+ module and all its sub-modules.
+
+ Attributes
+ ----------
+ module_name : str
+ Full path to the package to test.
+
+ Parameters
+ ----------
+ module_name : module name
+ The name of the module to test.
+
+ """
+ def __init__(self, module_name):
+ self.module_name = module_name
+
+ def __call__(self, label='fast', verbose=1, extra_argv=None,
+ doctests=False, coverage=False, durations=-1, tests=None):
+ """
+ Run tests for module using pytest.
+
+ Parameters
+ ----------
+ label : {'fast', 'full'}, optional
+ Identifies the tests to run. When set to 'fast', tests decorated
+ with `pytest.mark.slow` are skipped, when 'full', the slow marker
+ is ignored.
+ verbose : int, optional
+ Verbosity value for test outputs, in the range 1-3. Default is 1.
+ extra_argv : list, optional
+ List with any extra arguments to pass to pytests.
+ doctests : bool, optional
+ .. note:: Not supported
+ coverage : bool, optional
+ If True, report coverage of NumPy code. Default is False.
+ Requires installation of (pip) pytest-cov.
+ durations : int, optional
+ If < 0, do nothing, If 0, report time of all tests, if > 0,
+ report the time of the slowest `timer` tests. Default is -1.
+ tests : test or list of tests
+ Tests to be executed with pytest '--pyargs'
+
+ Returns
+ -------
+ result : bool
+ Return True on success, false otherwise.
+
+ Examples
+ --------
+ >>> result = np.lib.test() #doctest: +SKIP
+ ...
+ 1023 passed, 2 skipped, 6 deselected, 1 xfailed in 10.39 seconds
+ >>> result
+ True
+
+ """
+ import pytest
+
+ module = sys.modules[self.module_name]
+ module_path = os.path.abspath(module.__path__[0])
+
+ # setup the pytest arguments
+ pytest_args = ["-l"]
+
+ # offset verbosity. The "-q" cancels a "-v".
+ pytest_args += ["-q"]
+
+ # Filter out annoying import messages. Want these in both develop and
+ # release mode.
+ pytest_args += [
+ "-W ignore:Not importing directory",
+ "-W ignore:numpy.dtype size changed",
+ "-W ignore:numpy.ufunc size changed", ]
+
+ if doctests:
+ raise ValueError("Doctests not supported")
+
+ if extra_argv:
+ pytest_args += list(extra_argv)
+
+ if verbose > 1:
+ pytest_args += ["-" + "v"*(verbose - 1)]
+
+ if coverage:
+ pytest_args += ["--cov=" + module_path]
+
+ if label == "fast":
+ pytest_args += ["-m", "not slow"]
+ elif label != "full":
+ pytest_args += ["-m", label]
+
+ if durations >= 0:
+ pytest_args += ["--durations=%s" % durations]
+
+ if tests is None:
+ tests = [self.module_name]
+
+ pytest_args += ["--pyargs"] + list(tests)
+
+ # run tests.
+ _show_pywt_info()
+
+ try:
+ code = pytest.main(pytest_args)
+ except SystemExit as exc:
+ code = exc.code
+
+ return code == 0
diff --git a/pywt/_utils.py b/pywt/_utils.py
index 9ba45b104..48f814e21 100644
--- a/pywt/_utils.py
+++ b/pywt/_utils.py
@@ -98,23 +98,3 @@ def _modes_per_axis(modes, axes):
else:
raise ValueError("modes must be a str, Mode enum or iterable")
return modes
-
-
-def is_nose_running():
- """Returns whether we are running the nose test loader
- """
- if 'nose' not in sys.modules:
- return False
- try:
- import nose
- except ImportError:
- return False
- # Now check that we have the loader in the call stask
- stack = inspect.stack()
- loader_file_name = nose.loader.__file__
- if loader_file_name.endswith('.pyc'):
- loader_file_name = loader_file_name[:-1]
- for _, file_name, _, _, _, _ in stack:
- if file_name == loader_file_name:
- return True
- return False
diff --git a/pywt/tests/test__pywt.py b/pywt/tests/test__pywt.py
index 594125e3d..d17c7582c 100644
--- a/pywt/tests/test__pywt.py
+++ b/pywt/tests/test__pywt.py
@@ -3,8 +3,7 @@
from __future__ import division, print_function, absolute_import
import numpy as np
-from numpy.testing import (run_module_suite, assert_allclose, assert_,
- assert_raises)
+from numpy.testing import assert_allclose, assert_, assert_raises
import pywt
@@ -169,7 +168,3 @@ def test_wavelet_errormsgs():
pywt.Wavelet('cmord')
except ValueError as e:
assert_(e.args[0] == "Invalid wavelet name 'cmord'.")
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_concurrent.py b/pywt/tests/test_concurrent.py
index ebc739ab1..041171fd8 100644
--- a/pywt/tests/test_concurrent.py
+++ b/pywt/tests/test_concurrent.py
@@ -5,26 +5,14 @@
from __future__ import division, print_function, absolute_import
-import sys
import warnings
-import multiprocessing
import numpy as np
from functools import partial
-from numpy.testing import (dec, run_module_suite, assert_array_equal,
- assert_allclose)
+from numpy.testing import assert_array_equal, assert_allclose
+from pywt._pytest import uses_futures, futures, max_workers
import pywt
-try:
- if sys.version_info[0] == 2:
- import futures
- else:
- from concurrent import futures
- max_workers = multiprocessing.cpu_count()
- futures_available = True
-except ImportError:
- futures_available = False
-
def _assert_all_coeffs_equal(coefs1, coefs2):
# return True only if all coefficients of SWT or DWT match over all levels
@@ -44,7 +32,7 @@ def _assert_all_coeffs_equal(coefs1, coefs2):
return True
-@dec.skipif(not futures_available)
+@uses_futures
def test_concurrent_swt():
# tests error-free concurrent operation (see gh-288)
# swt on 1D data calls the Cython swt
@@ -65,7 +53,7 @@ def test_concurrent_swt():
_assert_all_coeffs_equal(expected_result, results[-1])
-@dec.skipif(not futures_available)
+@uses_futures
def test_concurrent_wavedec():
# wavedec on 1D data calls the Cython dwt_single
# other cases call dwt_axis
@@ -82,7 +70,7 @@ def test_concurrent_wavedec():
_assert_all_coeffs_equal(expected_result, results[-1])
-@dec.skipif(not futures_available)
+@uses_futures
def test_concurrent_dwt():
# dwt on 1D data calls the Cython dwt_single
# other cases call dwt_axis
@@ -99,7 +87,7 @@ def test_concurrent_dwt():
_assert_all_coeffs_equal([expected_result, ], [results[-1], ])
-@dec.skipif(not futures_available)
+@uses_futures
def test_concurrent_cwt():
atol = rtol = 1e-14
time, sst = pywt.data.nino()
@@ -115,7 +103,3 @@ def test_concurrent_cwt():
expected_result = transform(sst)
for a1, a2 in zip(expected_result, results[-1]):
assert_allclose(a1, a2, atol=atol, rtol=rtol)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_cwt_wavelets.py b/pywt/tests/test_cwt_wavelets.py
index 017c9e081..4372efc6a 100644
--- a/pywt/tests/test_cwt_wavelets.py
+++ b/pywt/tests/test_cwt_wavelets.py
@@ -1,8 +1,8 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
-from numpy.testing import (run_module_suite, assert_allclose, assert_warns,
- assert_almost_equal, assert_raises)
+from numpy.testing import (assert_allclose, assert_warns, assert_almost_equal,
+ assert_raises)
import numpy as np
import pywt
@@ -371,6 +371,3 @@ def test_cwt_small_scales():
# extremely short scale factors raise a ValueError
assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh')
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_data.py b/pywt/tests/test_data.py
index 36bef07a1..6415c1c2e 100644
--- a/pywt/tests/test_data.py
+++ b/pywt/tests/test_data.py
@@ -1,7 +1,6 @@
import os
import numpy as np
-from numpy.testing import (assert_allclose, assert_raises, assert_,
- run_module_suite)
+from numpy.testing import assert_allclose, assert_raises, assert_
import pywt.data
@@ -76,7 +75,3 @@ def test_wavelab_signals():
# ValueError on invalid length
assert_raises(ValueError, pywt.data.demo_signal, 'Doppler', 0)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_deprecations.py b/pywt/tests/test_deprecations.py
index b4868efa6..5a6cbad9e 100644
--- a/pywt/tests/test_deprecations.py
+++ b/pywt/tests/test_deprecations.py
@@ -1,7 +1,7 @@
import warnings
import numpy as np
-from numpy.testing import assert_warns, run_module_suite, assert_array_equal
+from numpy.testing import assert_warns, assert_array_equal
import pywt
@@ -87,7 +87,3 @@ def test_mode_equivalence():
for old, new in old_new:
assert_array_equal(pywt.dwt(x, 'db2', mode=old),
pywt.dwt(x, 'db2', mode=new))
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_dwt_idwt.py b/pywt/tests/test_dwt_idwt.py
index 4d0de63a4..1fd17e042 100644
--- a/pywt/tests/test_dwt_idwt.py
+++ b/pywt/tests/test_dwt_idwt.py
@@ -2,8 +2,7 @@
from __future__ import division, print_function, absolute_import
import numpy as np
-from numpy.testing import (run_module_suite, assert_allclose, assert_,
- assert_raises)
+from numpy.testing import assert_allclose, assert_, assert_raises
import pywt
@@ -224,7 +223,3 @@ def test_error_on_continuous_wavelet():
cA, cD = pywt.dwt(data, 'db1')
assert_raises(ValueError, pywt.idwt, cA, cD, cwave)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_functions.py b/pywt/tests/test_functions.py
index f60905b65..3f4a46973 100644
--- a/pywt/tests/test_functions.py
+++ b/pywt/tests/test_functions.py
@@ -1,8 +1,7 @@
#!/usr/bin/env python
from __future__ import division, print_function, absolute_import
-from numpy.testing import (run_module_suite, assert_almost_equal,
- assert_allclose)
+from numpy.testing import assert_almost_equal, assert_allclose
import pywt
@@ -37,7 +36,3 @@ def test_intwave_orthogonal():
# For x > 0.5, the integral is equal to (1 - x)
# Ignore last point here, there x > 1 and something goes wrong
assert_allclose(int_psi[~ix][:-1], 1 - x[~ix][:-1], atol=1e-10)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_matlab_compatibility.py b/pywt/tests/test_matlab_compatibility.py
index b28c56e6b..58c366aed 100644
--- a/pywt/tests/test_matlab_compatibility.py
+++ b/pywt/tests/test_matlab_compatibility.py
@@ -5,35 +5,13 @@
from __future__ import division, print_function, absolute_import
-import os
import numpy as np
-from numpy.testing import assert_, dec, run_module_suite
+import pytest
+from numpy.testing import assert_
import pywt
-
-if 'PYWT_XSLOW' in os.environ:
- # Run a more comprehensive set of problem sizes. This could take more than
- # an hour to complete.
- size_set = 'full'
- use_precomputed = False
-else:
- size_set = 'reduced'
- use_precomputed = True
-
-if use_precomputed:
- data_dir = os.path.join(os.path.dirname(__file__), 'data')
- matlab_data_file = os.path.join(data_dir, 'dwt_matlabR2012a_result.npz')
- matlab_result_dict = np.load(matlab_data_file)
-else:
- try:
- from pymatbridge import Matlab
- mlab = Matlab()
- _matlab_missing = False
- except ImportError:
- print("To run Matlab compatibility tests you need to have MathWorks "
- "MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
- "package installed.")
- _matlab_missing = True
+from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set)
+from pywt._pytest import matlab_result_dict_dwt as matlab_result_dict
# list of mode names in pywt and matlab
modes = [('zero', 'zpd'),
@@ -63,9 +41,12 @@ def _get_data_sizes(w):
return data_sizes
-@dec.skipif(use_precomputed or _matlab_missing)
-@dec.slow
+@uses_pymatbridge
+@pytest.mark.slow
def test_accuracy_pymatbridge():
+ Matlab = pytest.importorskip("pymatbridge.Matlab")
+ mlab = Matlab()
+
rstate = np.random.RandomState(1234)
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficents)
epsilon = 5.0e-5
@@ -79,17 +60,17 @@ def test_accuracy_pymatbridge():
data = rstate.randn(N)
mlab.set_variable('data', data)
for pmode, mmode in modes:
- ma, md = _compute_matlab_result(data, wavelet, mmode)
- yield _check_accuracy, data, w, pmode, ma, md, wavelet, epsilon
+ ma, md = _compute_matlab_result(data, wavelet, mmode, mlab)
+ _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
- yield _check_accuracy, data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs
+ _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
finally:
mlab.stop()
-@dec.skipif(not use_precomputed)
-@dec.slow
+@uses_precomputed
+@pytest.mark.slow
def test_accuracy_precomputed():
# Keep this specific random seed to match the precomputed Matlab result.
rstate = np.random.RandomState(1234)
@@ -102,12 +83,12 @@ def test_accuracy_precomputed():
data = rstate.randn(N)
for pmode, mmode in modes:
ma, md = _load_matlab_result(data, wavelet, mmode)
- yield _check_accuracy, data, w, pmode, ma, md, wavelet, epsilon
+ _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon)
ma, md = _load_matlab_result_pywt_coeffs(data, wavelet, mmode)
- yield _check_accuracy, data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs
+ _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon_pywt_coeffs)
-def _compute_matlab_result(data, wavelet, mmode):
+def _compute_matlab_result(data, wavelet, mmode, mlab):
""" Compute the result using MATLAB.
This function assumes that the Matlab variables `wavelet` and `data` have
@@ -177,7 +158,3 @@ def _check_accuracy(data, w, pmode, ma, md, wavelet, epsilon):
msg = ('[RMS_D > EPSILON] for Mode: %s, Wavelet: %s, '
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_d))
assert_(rms_d < epsilon, msg=msg)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_matlab_compatibility_cwt.py b/pywt/tests/test_matlab_compatibility_cwt.py
index 06ee180c4..9dc9e35bb 100644
--- a/pywt/tests/test_matlab_compatibility_cwt.py
+++ b/pywt/tests/test_matlab_compatibility_cwt.py
@@ -6,35 +6,13 @@
from __future__ import division, print_function, absolute_import
import warnings
-import os
import numpy as np
-from numpy.testing import assert_, dec, run_module_suite
+import pytest
+from numpy.testing import assert_
import pywt
-
-if 'PYWT_XSLOW' in os.environ:
- # Run a more comprehensive set of problem sizes. This could take more than
- # an hour to complete.
- size_set = 'full'
- use_precomputed = False
-else:
- size_set = 'reduced'
- use_precomputed = True
-
-if use_precomputed:
- data_dir = os.path.join(os.path.dirname(__file__), 'data')
- matlab_data_file = os.path.join(data_dir, 'cwt_matlabR2015b_result.npz')
- matlab_result_dict = np.load(matlab_data_file)
-else:
- try:
- from pymatbridge import Matlab
- mlab = Matlab()
- _matlab_missing = False
- except ImportError:
- print("To run Matlab compatibility tests you need to have MathWorks "
- "MATLAB, MathWorks Wavelet Toolbox and the pymatbridge Python "
- "package installed.")
- _matlab_missing = True
+from pywt._pytest import (uses_pymatbridge, uses_precomputed, size_set,
+ matlab_result_dict_cwt)
families = ('gaus', 'mexh', 'morl', 'cgau', 'shan', 'fbsp', 'cmor')
wavelets = sum([pywt.wavelist(name) for name in families], [])
@@ -53,15 +31,17 @@ def _get_data_sizes(w):
def _get_scales(w):
""" Return the scales to test for wavelet w. """
if size_set == 'full':
- Scales = (1,np.arange(1,3),np.arange(1,4),np.arange(1,5))
+ scales = (1, np.arange(1, 3), np.arange(1, 4), np.arange(1, 5))
else:
- Scales = (1,np.arange(1,3))
- return Scales
+ scales = (1, np.arange(1, 3))
+ return scales
-@dec.skipif(use_precomputed or _matlab_missing)
-@dec.slow
+@uses_pymatbridge # skip this case if precomputed results are used instead
+@pytest.mark.slow
def test_accuracy_pymatbridge_cwt():
+ Matlab = pytest.importorskip("pymatbridge.Matlab")
+ mlab = Matlab()
rstate = np.random.RandomState(1234)
# max RMSE (was 1.0e-10, is reduced to 5.0e-5 due to different coefficents)
epsilon = 1e-15
@@ -81,20 +61,20 @@ def test_accuracy_pymatbridge_cwt():
mlab_code = ("psi = wavefun(wavelet,10)")
res = mlab.run_code(mlab_code)
psi = np.asarray(mlab.get_variable('psi'))
- yield _check_accuracy_psi, w, psi, wavelet, epsilon_psi
+ _check_accuracy_psi(w, psi, wavelet, epsilon_psi)
for N in _get_data_sizes(w):
data = rstate.randn(N)
mlab.set_variable('data', data)
for scales in _get_scales(w):
- coefs = _compute_matlab_result(data, wavelet, scales)
- yield _check_accuracy, data, w, scales, coefs, wavelet, epsilon
+ coefs = _compute_matlab_result(data, wavelet, scales, mlab)
+ _check_accuracy(data, w, scales, coefs, wavelet, epsilon)
finally:
mlab.stop()
-@dec.skipif(not use_precomputed)
-@dec.slow
+@uses_precomputed # skip this case if pymatbridge + Matlab are being used
+@pytest.mark.slow
def test_accuracy_precomputed_cwt():
# Keep this specific random seed to match the precomputed Matlab result.
rstate = np.random.RandomState(1234)
@@ -108,7 +88,7 @@ def test_accuracy_precomputed_cwt():
w = pywt.ContinuousWavelet(wavelet)
w32 = pywt.ContinuousWavelet(wavelet,dtype=np.float32)
psi = _load_matlab_result_psi(wavelet)
- yield _check_accuracy_psi, w, psi, wavelet, epsilon_psi
+ _check_accuracy_psi(w, psi, wavelet, epsilon_psi)
for N in _get_data_sizes(w):
data = rstate.randn(N)
@@ -117,11 +97,11 @@ def test_accuracy_precomputed_cwt():
for scales in _get_scales(w):
scales_count += 1
coefs = _load_matlab_result(data, wavelet, scales_count)
- yield _check_accuracy, data, w, scales, coefs, wavelet, epsilon
- yield _check_accuracy, data32, w32, scales, coefs, wavelet, epsilon32
+ _check_accuracy(data, w, scales, coefs, wavelet, epsilon)
+ _check_accuracy(data32, w32, scales, coefs, wavelet, epsilon32)
-def _compute_matlab_result(data, wavelet, scales):
+def _compute_matlab_result(data, wavelet, scales, mlab):
""" Compute the result using MATLAB.
This function assumes that the Matlab variables `wavelet` and `data` have
@@ -143,11 +123,11 @@ def _load_matlab_result(data, wavelet, scales):
"""
N = len(data)
coefs_key = '_'.join([str(scales), wavelet, str(N), 'coefs'])
- if (coefs_key not in matlab_result_dict):
+ if (coefs_key not in matlab_result_dict_cwt):
raise KeyError(
"Precompted Matlab result not found for wavelet: "
"{0}, mode: {1}, size: {2}".format(wavelet, scales, N))
- coefs = matlab_result_dict[coefs_key]
+ coefs = matlab_result_dict_cwt[coefs_key]
return coefs
@@ -155,11 +135,11 @@ def _load_matlab_result_psi(wavelet):
""" Load the precomputed result.
"""
psi_key = '_'.join([wavelet, 'psi'])
- if (psi_key not in matlab_result_dict):
+ if (psi_key not in matlab_result_dict_cwt):
raise KeyError(
"Precompted Matlab psi result not found for wavelet: "
"{0}}".format(wavelet))
- psi = matlab_result_dict[psi_key]
+ psi = matlab_result_dict_cwt[psi_key]
return psi
@@ -187,6 +167,3 @@ def _check_accuracy_psi(w, psi, wavelet, epsilon):
msg = ('[RMS > EPSILON] for Wavelet: %s, '
'rms=%.3g' % (wavelet, rms))
assert_(rms < epsilon, msg=msg)
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_modes.py b/pywt/tests/test_modes.py
index 80a490b7d..31ea95321 100644
--- a/pywt/tests/test_modes.py
+++ b/pywt/tests/test_modes.py
@@ -2,8 +2,7 @@
from __future__ import division, print_function, absolute_import
import numpy as np
-from numpy.testing import (assert_raises, run_module_suite,
- assert_equal, assert_allclose)
+from numpy.testing import assert_raises, assert_equal, assert_allclose
import pywt
@@ -108,7 +107,3 @@ def test_default_mode():
assert_allclose(cA, cA2)
assert_allclose(cD, cD2)
assert_allclose(pywt.idwt(cA, cD, 'db2'), x)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_multidim.py b/pywt/tests/test_multidim.py
index 5594933c9..c04c9a57f 100644
--- a/pywt/tests/test_multidim.py
+++ b/pywt/tests/test_multidim.py
@@ -4,8 +4,7 @@
import numpy as np
from itertools import combinations
-from numpy.testing import (run_module_suite, assert_allclose, assert_,
- assert_raises, assert_equal)
+from numpy.testing import assert_allclose, assert_, assert_raises, assert_equal
import pywt
# Check that float32, float64, complex64, complex128 are preserved.
@@ -442,7 +441,3 @@ def test_error_on_continuous_wavelet():
c = dec_fun(data, 'db1')
assert_raises(ValueError, rec_fun, c, wavelet=cwave)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_multilevel.py b/pywt/tests/test_multilevel.py
index 7bf5d9df9..4223a56a7 100644
--- a/pywt/tests/test_multilevel.py
+++ b/pywt/tests/test_multilevel.py
@@ -5,9 +5,9 @@
import warnings
from itertools import combinations
import numpy as np
-from numpy.testing import (run_module_suite, assert_almost_equal,
- assert_allclose, assert_, assert_equal,
- assert_raises, assert_raises_regex, dec,
+import pytest
+from numpy.testing import (assert_almost_equal, assert_allclose, assert_,
+ assert_equal, assert_raises, assert_raises_regex,
assert_array_equal, assert_warns)
import pywt
# Check that float32, float64, complex64, complex128 are preserved.
@@ -176,7 +176,7 @@ def test_multilevel_dtypes_2d():
assert_(x_roundtrip.dtype == dt_out, "waverec2: " + errmsg)
-@dec.slow
+@pytest.mark.slow
def test_waverec2_all_wavelets_modes():
# test 2D case using all wavelets and modes
rstate = np.random.RandomState(1234)
@@ -363,7 +363,7 @@ def test_waverecn_dtypes():
assert_allclose(pywt.waverecn(coeffs, 'db1'), x, atol=tol, rtol=tol)
-@dec.slow
+@pytest.mark.slow
def test_waverecn_all_wavelets_modes():
# test 2D case using all wavelets and modes
rstate = np.random.RandomState(1234)
@@ -1013,7 +1013,3 @@ def test_waverec_mixed_precision():
r = ifunc(coeffs, 'db1')
assert_allclose(r, x, rtol=1e-7, atol=1e-7)
assert_equal(r.dtype, np.complex128)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_perfect_reconstruction.py b/pywt/tests/test_perfect_reconstruction.py
index 0466b8394..e57fa5532 100644
--- a/pywt/tests/test_perfect_reconstruction.py
+++ b/pywt/tests/test_perfect_reconstruction.py
@@ -7,7 +7,7 @@
from __future__ import division, print_function, absolute_import
import numpy as np
-from numpy.testing import assert_, run_module_suite
+from numpy.testing import assert_
import pywt
@@ -28,7 +28,7 @@ def test_perfect_reconstruction():
for wavelet in wavelets:
for pmode, mmode in modes:
for dt in dtypes:
- yield check_reconstruction, pmode, mmode, wavelet, dt
+ check_reconstruction(pmode, mmode, wavelet, dt)
def check_reconstruction(pmode, mmode, wavelet, dtype):
@@ -59,7 +59,3 @@ def check_reconstruction(pmode, mmode, wavelet, dtype):
msg = ('[RMS_REC > EPSILON] for Mode: %s, Wavelet: %s, '
'Length: %d, rms=%.3g' % (pmode, wavelet, len(data), rms_rec))
assert_(rms_rec < epsilon, msg=msg)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_swt.py b/pywt/tests/test_swt.py
index 0dfd9e01e..a0e4d2956 100644
--- a/pywt/tests/test_swt.py
+++ b/pywt/tests/test_swt.py
@@ -6,9 +6,9 @@
from copy import deepcopy
from itertools import combinations, permutations
import numpy as np
-from numpy.testing import (run_module_suite, dec, assert_allclose, assert_,
- assert_equal, assert_raises, assert_array_equal,
- assert_warns)
+import pytest
+from numpy.testing import (assert_allclose, assert_, assert_equal,
+ assert_raises, assert_array_equal, assert_warns)
import pywt
from pywt._extensions._swt import swt_axis
@@ -210,7 +210,7 @@ def test_swt2_ndim_error():
assert_raises(ValueError, pywt.swt2, x, 'haar', level=1)
-@dec.slow
+@pytest.mark.slow
def test_swt2_iswt2_integration(wavelets=None):
# This function performs a round-trip swt2/iswt2 transform test on
# all available types of wavelets in PyWavelets - except the
@@ -325,7 +325,7 @@ def test_swtn_axes():
start_level=0, axis=0)
-@dec.slow
+@pytest.mark.slow
def test_swtn_iswtn_integration(wavelets=None):
# This function performs a round-trip swtn/iswtn transform for various
# possible combinations of:
@@ -525,7 +525,3 @@ def test_iswtn_mixed_dtypes():
y = pywt.iswtn(coeffs, wav)
assert_equal(output_dtype, y.dtype)
assert_allclose(y, x, rtol=1e-3, atol=1e-3)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_thresholding.py b/pywt/tests/test_thresholding.py
index b618b94ae..abe69fadf 100644
--- a/pywt/tests/test_thresholding.py
+++ b/pywt/tests/test_thresholding.py
@@ -1,7 +1,6 @@
from __future__ import division, print_function, absolute_import
import numpy as np
-from numpy.testing import (assert_allclose, run_module_suite, assert_raises,
- assert_, assert_equal)
+from numpy.testing import assert_allclose, assert_raises, assert_, assert_equal
import pywt
@@ -168,7 +167,3 @@ def test_threshold_firm():
mt_abs_firm = np.abs(d_firm[mt])
assert_(np.all(mt_abs_firm < np.abs(d_hard[mt])))
assert_(np.all(mt_abs_firm > np.abs(d_soft[mt])))
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_wavelet.py b/pywt/tests/test_wavelet.py
index 99187f243..ba5a9265a 100644
--- a/pywt/tests/test_wavelet.py
+++ b/pywt/tests/test_wavelet.py
@@ -2,7 +2,7 @@
from __future__ import division, print_function, absolute_import
import numpy as np
-from numpy.testing import run_module_suite, assert_allclose, assert_
+from numpy.testing import assert_allclose, assert_
import pywt
@@ -54,11 +54,11 @@ def test_wavelet_coefficients():
wavelets = sum([pywt.wavelist(name) for name in families], [])
for wavelet in wavelets:
if (pywt.Wavelet(wavelet).orthogonal):
- yield check_coefficients_orthogonal, wavelet
+ check_coefficients_orthogonal(wavelet)
elif(pywt.Wavelet(wavelet).biorthogonal):
- yield check_coefficients_biorthogonal, wavelet
+ check_coefficients_biorthogonal(wavelet)
else:
- yield check_coefficients, wavelet
+ check_coefficients(wavelet)
def check_coefficients_orthogonal(wavelet):
@@ -264,7 +264,3 @@ def test_wavefun_bior13():
assert_allclose(phi_r, phi_r_expect, rtol=1e-10, atol=1e-12)
assert_allclose(psi_d, psi_d_expect, rtol=1e-10, atol=1e-12)
assert_allclose(psi_r, psi_r_expect, rtol=1e-10, atol=1e-12)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_wp.py b/pywt/tests/test_wp.py
index 50fae06ec..3c39e705e 100644
--- a/pywt/tests/test_wp.py
+++ b/pywt/tests/test_wp.py
@@ -3,8 +3,8 @@
from __future__ import division, print_function, absolute_import
import numpy as np
-from numpy.testing import (run_module_suite, assert_allclose, assert_,
- assert_raises, assert_equal)
+from numpy.testing import (assert_allclose, assert_, assert_raises,
+ assert_equal)
import pywt
@@ -195,7 +195,3 @@ def test_db3_roundtrip():
maxlevel=3)
r = wp.reconstruct()
assert_allclose(original, r, atol=1e-12, rtol=1e-12)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/pywt/tests/test_wp2d.py b/pywt/tests/test_wp2d.py
index de1305de2..96b6379a2 100644
--- a/pywt/tests/test_wp2d.py
+++ b/pywt/tests/test_wp2d.py
@@ -3,8 +3,8 @@
from __future__ import division, print_function, absolute_import
import numpy as np
-from numpy.testing import (run_module_suite, assert_allclose, assert_,
- assert_raises, assert_equal)
+from numpy.testing import (assert_allclose, assert_, assert_raises,
+ assert_equal)
import pywt
@@ -175,7 +175,3 @@ def test_2d_roundtrip():
maxlevel=3)
r = wp.reconstruct()
assert_allclose(original, r, atol=1e-12, rtol=1e-12)
-
-
-if __name__ == '__main__':
- run_module_suite()
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 24d5546cd..000000000
--- a/setup.cfg
+++ /dev/null
@@ -1,5 +0,0 @@
-[nosetests]
-with-coverage = true
-with-doctest = true
-# Don't ignore files starting with '_'
-ignore-files = (?:^\.|^util)
diff --git a/setup.py b/setup.py
index e965b8398..5a29b73e1 100755
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
import setuptools
from setuptools import setup, Extension
-
+from setuptools.command.test import test as TestCommand
MAJOR = 1
MINOR = 1
@@ -360,6 +360,24 @@ def parse_setuppy_commands():
return True
+class PyTest(TestCommand):
+ user_options = [('pytest-args=', 'a', "Arguments to pass to py.test")]
+
+ def initialize_options(self):
+ TestCommand.initialize_options(self)
+ self.pytest_args = []
+
+ def finalize_options(self):
+ TestCommand.finalize_options(self)
+ self.test_args = []
+ self.test_suite = True
+
+ def run_tests(self):
+ #import here, cause outside the eggs aren't loaded
+ import pytest
+ errno = pytest.main(self.pytest_args)
+ sys.exit(errno)
+
def setup_package():
# Rewrite the version file everytime
@@ -410,8 +428,8 @@ def setup_package():
'pywt': ['tests/*.py', 'tests/data/*.npz',
'tests/data/*.py']},
libraries=[c_lib],
- cmdclass={'develop': develop_build_clib},
- test_suite='nose.collector',
+ cmdclass={'develop': develop_build_clib, 'test': PyTest},
+ tests_require=['pytest'],
install_requires=["numpy>=1.13.3"],
setup_requires=["numpy>=1.13.3"],
diff --git a/tox.ini b/tox.ini
index 33a7528ad..b7996ea5a 100644
--- a/tox.ini
+++ b/tox.ini
@@ -25,14 +25,14 @@ envlist = py35, py36, py37
[testenv]
deps =
flake8
- nose
+ pytest
coverage
cython
numpy
matplotlib
changedir = {envdir}
commands =
- nosetests --tests {toxinidir}/pywt/tests
+ pytest {toxinidir}/pywt/tests -v
# flake8 --exit-zero pywt
[pep8]
diff --git a/util/readthedocs/requirements.txt b/util/readthedocs/requirements.txt
index db25531e3..490c91afa 100644
--- a/util/readthedocs/requirements.txt
+++ b/util/readthedocs/requirements.txt
@@ -1,6 +1,6 @@
numpy
cython
-nose
+pytest
wheel
numpydoc
matplotlib