Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pywt/_dwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,12 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
if cA is not None and cD is not None:
if cA.dtype != cD.dtype:
# need to upcast to common type
cA = cA.astype(np.float64)
cD = cD.astype(np.float64)
if cA.dtype.kind == 'c' or cD.dtype.kind == 'c':
dtype = np.complex128
else:
dtype = np.float64
cA = cA.astype(dtype)
cD = cD.astype(dtype)
elif cA is None:
cA = np.zeros_like(cD)
elif cD is None:
Expand Down
10 changes: 9 additions & 1 deletion pywt/_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,15 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
for key in new_keys:
L = coeffs.get(key + 'a', None)
H = coeffs.get(key + 'd', None)

if L is not None and H is not None:
if L.dtype != H.dtype:
# upcast to a common dtype (float64 or complex128)
if L.dtype.kind == 'c' or H.dtype.kind == 'c':
dtype = np.complex128
else:
dtype = np.float64
L = np.asarray(L, dtype=dtype)
H = np.asarray(H, dtype=dtype)
new_coeffs[key] = idwt_axis(L, H, wav, mode, axis)
coeffs = new_coeffs

Expand Down
38 changes: 33 additions & 5 deletions pywt/_swt.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def iswt(coeffs, wavelet):
>>> pywt.iswt(coeffs, 'db2')
array([ 1., 2., 3., 4., 5., 6., 7., 8.])
"""
output = coeffs[0][0].copy() # Avoid modification of input data
# copy to avoid modification of input data
dt = _check_dtype(coeffs[0][0])
output = np.array(coeffs[0][0], dtype=dt, copy=True)

if not _have_c99_complex and np.iscomplexobj(output):
# compute real and imaginary separately then combine
coeffs_real = [(cA.real, cD.real) for (cA, cD) in coeffs]
Expand All @@ -128,8 +131,15 @@ def iswt(coeffs, wavelet):
step_size = int(pow(2, j-1))
last_index = step_size
_, cD = coeffs[num_levels - j]
dt = _check_dtype(cD)
cD = np.asarray(cD, dtype=dt) # doesn't copy if dtype matches
cD = np.asarray(cD, dtype=_check_dtype(cD))
if cD.dtype != output.dtype:
# upcast to a common dtype (float64 or complex128)
if output.dtype.kind == 'c' or cD.dtype.kind == 'c':
dtype = np.complex128
else:
dtype = np.float64
output = np.asarray(output, dtype=dtype)
cD = np.asarray(cD, dtype=dtype)
for first in range(last_index): # 0 to last_index - 1

# Getting the indices that we will transform
Expand Down Expand Up @@ -271,7 +281,10 @@ def iswt2(coeffs, wavelet):

"""

output = coeffs[0][0].copy() # Avoid modification of input data
# copy to avoid modification of input data
dt = _check_dtype(coeffs[0][0])
output = np.array(coeffs[0][0], dtype=dt, copy=True)

if output.ndim != 2:
raise ValueError(
"iswt2 only supports 2D arrays. see iswtn for a general "
Expand All @@ -288,6 +301,14 @@ def iswt2(coeffs, wavelet):
if (cH.shape != cV.shape) or (cH.shape != cD.shape):
raise RuntimeError(
"Mismatch in shape of intermediate coefficient arrays")

# make sure output shares the common dtype
# (conversion of dtype for individual coeffs is handled within idwt2 )
common_dtype = np.result_type(*(
[dt, ] + [_check_dtype(c) for c in [cH, cV, cD]]))
if output.dtype != common_dtype:
output = output.astype(common_dtype)

for first_h in range(last_index): # 0 to last_index - 1
for first_w in range(last_index): # 0 to last_index - 1
# Getting the indices that we will transform
Expand Down Expand Up @@ -460,7 +481,9 @@ def iswtn(coeffs, wavelet, axes=None):
# key length matches the number of axes transformed
ndim_transform = max(len(key) for key in coeffs[0].keys())

output = coeffs[0]['a'*ndim_transform].copy() # Avoid modifying input data
# copy to avoid modification of input data
dt = _check_dtype(coeffs[0]['a'*ndim_transform])
output = np.array(coeffs[0]['a'*ndim_transform], dtype=dt, copy=True)
ndim = output.ndim

if axes is None:
Expand Down Expand Up @@ -488,6 +511,11 @@ def iswtn(coeffs, wavelet, axes=None):
last_index = step_size
a = coeffs[j].pop('a'*ndim_transform) # will restore later
details = coeffs[j]
# make sure dtype matches the coarsest level approximation coefficients
common_dtype = np.result_type(*(
[dt, ] + [v.dtype for v in details.values()]))
if output.dtype != common_dtype:
output = output.astype(common_dtype)
# We assume all coefficient arrays are of equal size
shapes = [v.shape for k, v in details.items()]
dshape = shapes[0]
Expand Down
17 changes: 16 additions & 1 deletion pywt/tests/test_dwt_idwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,22 @@ def test_dwt_idwt_basic():
x_roundtrip2 = pywt.idwt(cA.astype(np.float64), cD.astype(np.float32),
'db2')
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
assert_(x_roundtrip.dtype == np.float64)
assert_(x_roundtrip2.dtype == np.float64)


def test_idwt_mixed_complex_dtype():
x = np.arange(8).astype(float)
x = x + 1j*x[::-1]
cA, cD = pywt.dwt(x, 'db2')

x_roundtrip = pywt.idwt(cA, cD, 'db2')
assert_allclose(x_roundtrip, x, rtol=1e-10)

# mismatched dtypes OK
x_roundtrip2 = pywt.idwt(cA.astype(np.complex128), cD.astype(np.complex64),
'db2')
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
assert_(x_roundtrip2.dtype == np.complex128)


def test_dwt_idwt_dtypes():
Expand Down
16 changes: 16 additions & 0 deletions pywt/tests/test_multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,22 @@ def test_dwtn_idwtn_dtypes():
assert_(x_roundtrip.dtype == dt_out, "idwtn: " + errmsg)


def test_idwtn_mixed_complex_dtype():
rstate = np.random.RandomState(0)
x = rstate.randn(8, 8, 8)
x = x + 1j*x
coeffs = pywt.dwtn(x, 'db2')

x_roundtrip = pywt.idwtn(coeffs, 'db2')
assert_allclose(x_roundtrip, x, rtol=1e-10)

# mismatched dtypes OK
coeffs['a' * x.ndim] = coeffs['a' * x.ndim].astype(np.complex64)
x_roundtrip2 = pywt.idwtn(coeffs, 'db2')
assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7)
assert_(x_roundtrip2.dtype == np.complex128)


def test_idwt2_size_mismatch_error():
LL = np.zeros((6, 6))
LH = HL = HH = np.zeros((5, 5))
Expand Down
39 changes: 39 additions & 0 deletions pywt/tests/test_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,5 +976,44 @@ def test_default_level():
pywt.dwt_max_level(data.shape[ax], wavelet[ax]))


def test_waverec_mixed_precision():
rstate = np.random.RandomState(0)
for func, ifunc, shape in [(pywt.wavedec, pywt.waverec, (8, )),
(pywt.wavedec2, pywt.waverec2, (8, 8)),
(pywt.wavedecn, pywt.waverecn, (8, 8, 8))]:
x = rstate.randn(*shape)
coeffs_real = func(x, 'db1')

# real: single precision approx, double precision details
coeffs_real[0] = coeffs_real[0].astype(np.float32)
r = ifunc(coeffs_real, 'db1')
assert_allclose(r, x, rtol=1e-7, atol=1e-7)
assert_equal(r.dtype, np.float64)

x = x + 1j*x
coeffs = func(x, 'db1')

# complex: single precision approx, double precision details
coeffs[0] = coeffs[0].astype(np.complex64)
r = ifunc(coeffs, 'db1')
assert_allclose(r, x, rtol=1e-7, atol=1e-7)
assert_equal(r.dtype, np.complex128)

# complex: double precision approx, single precision details
if x.ndim == 1:
coeffs[0] = coeffs[0].astype(np.complex128)
coeffs[1] = coeffs[1].astype(np.complex64)
if x.ndim == 2:
coeffs[0] = coeffs[0].astype(np.complex128)
coeffs[1] = tuple([v.astype(np.complex64) for v in coeffs[1]])
if x.ndim == 3:
coeffs[0] = coeffs[0].astype(np.complex128)
coeffs[1] = {k: v.astype(np.complex64)
for k, v in coeffs[1].items()}
r = ifunc(coeffs, 'db1')
assert_allclose(r, x, rtol=1e-7, atol=1e-7)
assert_equal(r.dtype, np.complex128)


if __name__ == '__main__':
run_module_suite()
86 changes: 86 additions & 0 deletions pywt/tests/test_swt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import pywt
from pywt._extensions._swt import swt_axis
from pywt._extensions._pywt import _check_dtype

# Check that float32 and complex64 are preserved. Other real types get
# converted to float64.
Expand Down Expand Up @@ -427,5 +428,90 @@ def test_error_on_continuous_wavelet():
assert_raises(ValueError, rec_func, c, wavelet=cwave)


def test_iswt_mixed_dtypes():
# Mixed precision inputs give double precision output
x_real = np.arange(16).astype(np.float64)
x_complex = x_real + 1j*x_real
wav = 'sym2'
for dtype1, dtype2 in [(np.float64, np.float32),
(np.float32, np.float64),
(np.float16, np.float64),
(np.complex128, np.complex64),
(np.complex64, np.complex128)]:

if dtype1 in [np.complex64, np.complex128]:
x = x_complex
output_dtype = np.complex128
else:
x = x_real
output_dtype = np.float64

coeffs = pywt.swt(x, wav, 2)
# different precision for the approximation coefficients
coeffs[0] = [coeffs[0][0].astype(dtype1),
coeffs[0][1].astype(dtype2)]
y = pywt.iswt(coeffs, wav)
assert_equal(output_dtype, y.dtype)
assert_allclose(y, x, rtol=1e-3, atol=1e-3)


def test_iswt2_mixed_dtypes():
# Mixed precision inputs give double precision output
rstate = np.random.RandomState(0)
x_real = rstate.randn(8, 8)
x_complex = x_real + 1j*x_real
wav = 'sym2'
for dtype1, dtype2 in [(np.float64, np.float32),
(np.float32, np.float64),
(np.float16, np.float64),
(np.complex128, np.complex64),
(np.complex64, np.complex128)]:

if dtype1 in [np.complex64, np.complex128]:
x = x_complex
output_dtype = np.complex128
else:
x = x_real
output_dtype = np.float64

coeffs = pywt.swt2(x, wav, 2)
# different precision for the approximation coefficients
coeffs[0] = [coeffs[0][0].astype(dtype1),
tuple([c.astype(dtype2) for c in coeffs[0][1]])]
y = pywt.iswt2(coeffs, wav)
assert_equal(output_dtype, y.dtype)
assert_allclose(y, x, rtol=1e-3, atol=1e-3)


def test_iswtn_mixed_dtypes():
# Mixed precision inputs give double precision output
rstate = np.random.RandomState(0)
x_real = rstate.randn(8, 8, 8)
x_complex = x_real + 1j*x_real
wav = 'sym2'
for dtype1, dtype2 in [(np.float64, np.float32),
(np.float32, np.float64),
(np.float16, np.float64),
(np.complex128, np.complex64),
(np.complex64, np.complex128)]:

if dtype1 in [np.complex64, np.complex128]:
x = x_complex
output_dtype = np.complex128
else:
x = x_real
output_dtype = np.float64

coeffs = pywt.swtn(x, wav, 2)
# different precision for the approximation coefficients
a = coeffs[0].pop('a' * x.ndim)
a = a.astype(dtype1)
coeffs[0] = {k: c.astype(dtype2) for k, c in coeffs[0].items()}
coeffs[0]['a' * x.ndim] = a
y = pywt.iswtn(coeffs, wav)
assert_equal(output_dtype, y.dtype)
assert_allclose(y, x, rtol=1e-3, atol=1e-3)


if __name__ == '__main__':
run_module_suite()