diff --git a/pywt/_multilevel.py b/pywt/_multilevel.py index 1d847107a..423737e05 100644 --- a/pywt/_multilevel.py +++ b/pywt/_multilevel.py @@ -156,6 +156,12 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1): a, ds = coeffs[0], coeffs[1:] for d in ds: + if d is not None and not isinstance(d, np.ndarray): + raise ValueError(( + "Unexpected detail coefficient type: {}. Detail coefficients " + "must be arrays as returned by wavedec. If you are using " + "pywt.array_to_coeffs or pywt.unravel_coeffs, please specify " + "output_format='wavedec'").format(type(d))) if (a is not None) and (d is not None): try: if a.shape[axis] == d.shape[axis] + 1: @@ -164,10 +170,6 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1): raise ValueError("coefficient shape mismatch") except IndexError: raise ValueError("Axis greater than coefficient dimensions") - except AttributeError: - raise AttributeError( - "Wrong coefficient format, if using 'array_to_coeffs' " - "please specify the 'output_format' parameter") a = idwt(a, d, wavelet, mode, axis) return a @@ -310,6 +312,12 @@ def waverec2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)): a = np.asarray(a) for d in ds: + if not isinstance(d, (list, tuple)) or len(d) != 3: + raise ValueError(( + "Unexpected detail coefficient type: {}. Detail coefficients " + "must be a 3-tuple of arrays as returned by wavedec2. If you " + "are using pywt.array_to_coeffs or pywt.unravel_coeffs, " + "please specify output_format='wavedec2'").format(type(d))) d = tuple(np.asarray(coeff) if coeff is not None else None for coeff in d) d_shapes = (coeff.shape for coeff in d if coeff is not None) @@ -511,6 +519,14 @@ def waverecn(coeffs, wavelet, mode='symmetric', axes=None): a, ds = coeffs[0], coeffs[1:] + # this dictionary check must be prior to the call to _fix_coeffs + if len(ds) > 0 and not all([isinstance(d, dict) for d in ds]): + raise ValueError(( + "Unexpected detail coefficient type: {}. Detail coefficients " + "must be a dicionary of arrays as returned by wavedecn. If " + "you are using pywt.array_to_coeffs or pywt.unravel_coeffs, " + "please specify output_format='wavedecn'").format(type(ds[0]))) + # Raise error for invalid key combinations ds = list(map(_fix_coeffs, ds)) @@ -827,7 +843,8 @@ def array_to_coeffs(arr, coeff_slices, output_format='wavedecn'): >>> cam = pywt.data.camera() >>> coeffs = pywt.wavedecn(cam, wavelet='db2', level=3) >>> arr, coeff_slices = pywt.coeffs_to_array(coeffs) - >>> coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices) + >>> coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices, + ... output_format='wavedecn') >>> cam_recon = pywt.waverecn(coeffs_from_arr, wavelet='db2') >>> assert_array_almost_equal(cam, cam_recon) @@ -1120,7 +1137,8 @@ def unravel_coeffs(arr, coeff_slices, coeff_shapes, output_format='wavedecn'): >>> cam = pywt.data.camera() >>> coeffs = pywt.wavedecn(cam, wavelet='db2', level=3) >>> arr, coeff_slices, coeff_shapes = pywt.ravel_coeffs(coeffs) - >>> coeffs_from_arr = pywt.unravel_coeffs(arr, coeff_slices, coeff_shapes) + >>> coeffs_from_arr = pywt.unravel_coeffs(arr, coeff_slices, coeff_shapes, + ... output_format='wavedecn') >>> cam_recon = pywt.waverecn(coeffs_from_arr, wavelet='db2') >>> assert_array_almost_equal(cam, cam_recon) diff --git a/pywt/tests/test_multilevel.py b/pywt/tests/test_multilevel.py index 7bf5d9df9..f6fb9fa02 100644 --- a/pywt/tests/test_multilevel.py +++ b/pywt/tests/test_multilevel.py @@ -80,8 +80,9 @@ def test_waverec_invalid_inputs(): coeffs = pywt.wavedec(x, 'db1') arr, coeff_slices = pywt.coeffs_to_array(coeffs) coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices) - message = "Wrong coefficient format, if using 'array_to_coeffs' please specify the 'output_format' parameter" - assert_raises_regex(AttributeError, message, pywt.waverec, coeffs_from_arr, 'haar') + message = "Unexpected detail coefficient type" + assert_raises_regex(ValueError, message, pywt.waverec, coeffs_from_arr, + 'haar') def test_waverec_accuracies(): @@ -208,6 +209,13 @@ def test_waverec2_invalid_inputs(): # input list cannot be empty assert_raises(ValueError, pywt.waverec2, [], 'haar') + # coefficients from a difference decomposition used as input + for dec_func in [pywt.wavedec, pywt.wavedecn]: + coeffs = dec_func(np.ones((8, 8)), 'haar') + message = "Unexpected detail coefficient type" + assert_raises_regex(ValueError, message, pywt.waverec2, coeffs, + 'haar') + def test_waverec2_coeff_shape_mismatch(): x = np.ones((8, 8)) @@ -285,6 +293,16 @@ def test_waverecn_invalid_coeffs(): assert_raises(ValueError, pywt.waverecn, [], 'haar') +def test_waverecn_invalid_inputs(): + + # coefficients from a difference decomposition used as input + for dec_func in [pywt.wavedec, pywt.wavedec2]: + coeffs = dec_func(np.ones((8, 8)), 'haar') + message = "Unexpected detail coefficient type" + assert_raises_regex(ValueError, message, pywt.waverecn, coeffs, + 'haar') + + def test_waverecn_lists(): # support coefficient arrays specified as lists instead of arrays coeffs = [[[1.0]], {'ad': [[0.0]], 'da': [[0.0]], 'dd': [[0.0]]}]