Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions pywt/_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,12 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1):
a, ds = coeffs[0], coeffs[1:]

for d in ds:
if d is not None and not isinstance(d, np.ndarray):
raise ValueError((
"Unexpected detail coefficient type: {}. Detail coefficients "
"must be arrays as returned by wavedec. If you are using "
"pywt.array_to_coeffs or pywt.unravel_coeffs, please specify "
"output_format='wavedec'").format(type(d)))
if (a is not None) and (d is not None):
try:
if a.shape[axis] == d.shape[axis] + 1:
Expand All @@ -164,10 +170,6 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1):
raise ValueError("coefficient shape mismatch")
except IndexError:
raise ValueError("Axis greater than coefficient dimensions")
except AttributeError:
raise AttributeError(
"Wrong coefficient format, if using 'array_to_coeffs' "
"please specify the 'output_format' parameter")
a = idwt(a, d, wavelet, mode, axis)

return a
Expand Down Expand Up @@ -310,6 +312,12 @@ def waverec2(coeffs, wavelet, mode='symmetric', axes=(-2, -1)):
a = np.asarray(a)

for d in ds:
if not isinstance(d, (list, tuple)) or len(d) != 3:
raise ValueError((
"Unexpected detail coefficient type: {}. Detail coefficients "
"must be a 3-tuple of arrays as returned by wavedec2. If you "
"are using pywt.array_to_coeffs or pywt.unravel_coeffs, "
"please specify output_format='wavedec2'").format(type(d)))
d = tuple(np.asarray(coeff) if coeff is not None else None
for coeff in d)
d_shapes = (coeff.shape for coeff in d if coeff is not None)
Expand Down Expand Up @@ -511,6 +519,14 @@ def waverecn(coeffs, wavelet, mode='symmetric', axes=None):

a, ds = coeffs[0], coeffs[1:]

# this dictionary check must be prior to the call to _fix_coeffs
if len(ds) > 0 and not all([isinstance(d, dict) for d in ds]):
raise ValueError((
"Unexpected detail coefficient type: {}. Detail coefficients "
"must be a dicionary of arrays as returned by wavedecn. If "
"you are using pywt.array_to_coeffs or pywt.unravel_coeffs, "
"please specify output_format='wavedecn'").format(type(ds[0])))

# Raise error for invalid key combinations
ds = list(map(_fix_coeffs, ds))

Expand Down Expand Up @@ -827,7 +843,8 @@ def array_to_coeffs(arr, coeff_slices, output_format='wavedecn'):
>>> cam = pywt.data.camera()
>>> coeffs = pywt.wavedecn(cam, wavelet='db2', level=3)
>>> arr, coeff_slices = pywt.coeffs_to_array(coeffs)
>>> coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices)
>>> coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices,
... output_format='wavedecn')
>>> cam_recon = pywt.waverecn(coeffs_from_arr, wavelet='db2')
>>> assert_array_almost_equal(cam, cam_recon)

Expand Down Expand Up @@ -1120,7 +1137,8 @@ def unravel_coeffs(arr, coeff_slices, coeff_shapes, output_format='wavedecn'):
>>> cam = pywt.data.camera()
>>> coeffs = pywt.wavedecn(cam, wavelet='db2', level=3)
>>> arr, coeff_slices, coeff_shapes = pywt.ravel_coeffs(coeffs)
>>> coeffs_from_arr = pywt.unravel_coeffs(arr, coeff_slices, coeff_shapes)
>>> coeffs_from_arr = pywt.unravel_coeffs(arr, coeff_slices, coeff_shapes,
... output_format='wavedecn')
>>> cam_recon = pywt.waverecn(coeffs_from_arr, wavelet='db2')
>>> assert_array_almost_equal(cam, cam_recon)

Expand Down
22 changes: 20 additions & 2 deletions pywt/tests/test_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ def test_waverec_invalid_inputs():
coeffs = pywt.wavedec(x, 'db1')
arr, coeff_slices = pywt.coeffs_to_array(coeffs)
coeffs_from_arr = pywt.array_to_coeffs(arr, coeff_slices)
message = "Wrong coefficient format, if using 'array_to_coeffs' please specify the 'output_format' parameter"
assert_raises_regex(AttributeError, message, pywt.waverec, coeffs_from_arr, 'haar')
message = "Unexpected detail coefficient type"
assert_raises_regex(ValueError, message, pywt.waverec, coeffs_from_arr,
'haar')


def test_waverec_accuracies():
Expand Down Expand Up @@ -208,6 +209,13 @@ def test_waverec2_invalid_inputs():
# input list cannot be empty
assert_raises(ValueError, pywt.waverec2, [], 'haar')

# coefficients from a difference decomposition used as input
for dec_func in [pywt.wavedec, pywt.wavedecn]:
coeffs = dec_func(np.ones((8, 8)), 'haar')
message = "Unexpected detail coefficient type"
assert_raises_regex(ValueError, message, pywt.waverec2, coeffs,
'haar')


def test_waverec2_coeff_shape_mismatch():
x = np.ones((8, 8))
Expand Down Expand Up @@ -285,6 +293,16 @@ def test_waverecn_invalid_coeffs():
assert_raises(ValueError, pywt.waverecn, [], 'haar')


def test_waverecn_invalid_inputs():

# coefficients from a difference decomposition used as input
for dec_func in [pywt.wavedec, pywt.wavedec2]:
coeffs = dec_func(np.ones((8, 8)), 'haar')
message = "Unexpected detail coefficient type"
assert_raises_regex(ValueError, message, pywt.waverecn, coeffs,
'haar')


def test_waverecn_lists():
# support coefficient arrays specified as lists instead of arrays
coeffs = [[[1.0]], {'ad': [[0.0]], 'da': [[0.0]], 'dd': [[0.0]]}]
Expand Down