From 5fdd6bda8fd279f542ac6d309f8f91f0348f1c22 Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Thu, 31 Jan 2019 17:10:35 -0500 Subject: [PATCH 1/7] use _check_dtype on the coarsest level approximation coeffs output array dtype is upcast if any detail coefficients have higher precision than the approximation coefficients. --- pywt/_swt.py | 38 +++++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/pywt/_swt.py b/pywt/_swt.py index 1fe2c79e4..ea29f3ff6 100644 --- a/pywt/_swt.py +++ b/pywt/_swt.py @@ -113,7 +113,10 @@ def iswt(coeffs, wavelet): >>> pywt.iswt(coeffs, 'db2') array([ 1., 2., 3., 4., 5., 6., 7., 8.]) """ - output = coeffs[0][0].copy() # Avoid modification of input data + # copy to avoid modification of input data + dt = _check_dtype(coeffs[0][0]) + output = np.array(coeffs[0][0], dtype=dt, copy=True) + if not _have_c99_complex and np.iscomplexobj(output): # compute real and imaginary separately then combine coeffs_real = [(cA.real, cD.real) for (cA, cD) in coeffs] @@ -128,8 +131,15 @@ def iswt(coeffs, wavelet): step_size = int(pow(2, j-1)) last_index = step_size _, cD = coeffs[num_levels - j] - dt = _check_dtype(cD) - cD = np.asarray(cD, dtype=dt) # doesn't copy if dtype matches + cD = np.asarray(cD, dtype=_check_dtype(cD)) + if cD.dtype != output.dtype: + # upcast to a common dtype (float64 or complex128) + if output.dtype.kind == 'c' or cD.dtype.kind == 'c': + dtype = np.complex128 + else: + dtype = np.float64 + output = np.asarray(output, dtype=dtype) + cD = np.asarray(cD, dtype=dtype) for first in range(last_index): # 0 to last_index - 1 # Getting the indices that we will transform @@ -271,7 +281,10 @@ def iswt2(coeffs, wavelet): """ - output = coeffs[0][0].copy() # Avoid modification of input data + # copy to avoid modification of input data + dt = _check_dtype(coeffs[0][0]) + output = np.array(coeffs[0][0], dtype=dt, copy=True) + if output.ndim != 2: raise ValueError( "iswt2 only supports 2D arrays. see iswtn for a general " @@ -288,6 +301,14 @@ def iswt2(coeffs, wavelet): if (cH.shape != cV.shape) or (cH.shape != cD.shape): raise RuntimeError( "Mismatch in shape of intermediate coefficient arrays") + + # make sure output shares the common dtype + # (conversion of dtype for individual coeffs is handled within idwt2 ) + common_dtype = np.result_type(*( + [dt, ] + [_check_dtype(c) for c in [cH, cV, cD]])) + if output.dtype != common_dtype: + output = output.astype(common_dtype) + for first_h in range(last_index): # 0 to last_index - 1 for first_w in range(last_index): # 0 to last_index - 1 # Getting the indices that we will transform @@ -460,7 +481,9 @@ def iswtn(coeffs, wavelet, axes=None): # key length matches the number of axes transformed ndim_transform = max(len(key) for key in coeffs[0].keys()) - output = coeffs[0]['a'*ndim_transform].copy() # Avoid modifying input data + # copy to avoid modification of input data + dt = _check_dtype(coeffs[0]['a'*ndim_transform]) + output = np.array(coeffs[0]['a'*ndim_transform], dtype=dt, copy=True) ndim = output.ndim if axes is None: @@ -488,6 +511,11 @@ def iswtn(coeffs, wavelet, axes=None): last_index = step_size a = coeffs[j].pop('a'*ndim_transform) # will restore later details = coeffs[j] + # make sure dtype matches the coarsest level approximation coefficients + common_dtype = np.result_type(*( + [dt, ] + [v.dtype for v in details.values()])) + if output.dtype != common_dtype: + output = output.astype(common_dtype) # We assume all coefficient arrays are of equal size shapes = [v.shape for k, v in details.items()] dshape = shapes[0] From 8fc0e701974941e14d36055ad357e208c71c2fe1 Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Thu, 31 Jan 2019 18:01:29 -0500 Subject: [PATCH 2/7] fix: make dtype upcast in idwt respect complex dtypes --- pywt/_dwt.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pywt/_dwt.py b/pywt/_dwt.py index a780f7d65..56114566a 100644 --- a/pywt/_dwt.py +++ b/pywt/_dwt.py @@ -264,8 +264,12 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1): if cA is not None and cD is not None: if cA.dtype != cD.dtype: # need to upcast to common type - cA = cA.astype(np.float64) - cD = cD.astype(np.float64) + if cA.dtype.kind == 'c' or cD.dtype.kind == 'c': + dtype = np.complex128 + else: + dtype = np.float64 + cA = cA.astype(dtype) + cD = cD.astype(dtype) elif cA is None: cA = np.zeros_like(cD) elif cD is None: From afa2eb196691b23012ca9c853a1d1cfa49a066da Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Thu, 31 Jan 2019 18:02:25 -0500 Subject: [PATCH 3/7] add a dtype upcast to idwtn to match the behavior of idwt --- pywt/_multidim.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pywt/_multidim.py b/pywt/_multidim.py index ae70e4994..783da8a54 100644 --- a/pywt/_multidim.py +++ b/pywt/_multidim.py @@ -296,7 +296,14 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None): for key in new_keys: L = coeffs.get(key + 'a', None) H = coeffs.get(key + 'd', None) - + if L.dtype != H.dtype: + # upcast to a common dtype (float64 or complex128) + if L.dtype.kind == 'c' or H.dtype.kind == 'c': + dtype = np.complex128 + else: + dtype = np.float64 + L = np.asarray(L, dtype=dtype) + H = np.asarray(H, dtype=dtype) new_coeffs[key] = idwt_axis(L, H, wav, mode, axis) coeffs = new_coeffs From 0395eccd614130766ec0818a0a53f6db9b81548c Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Thu, 31 Jan 2019 17:11:03 -0500 Subject: [PATCH 4/7] test inverse SWT functions with mixed dtype coefficients --- pywt/tests/test_swt.py | 86 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/pywt/tests/test_swt.py b/pywt/tests/test_swt.py index 478d28482..cf9fc32e8 100644 --- a/pywt/tests/test_swt.py +++ b/pywt/tests/test_swt.py @@ -12,6 +12,7 @@ import pywt from pywt._extensions._swt import swt_axis +from pywt._extensions._pywt import _check_dtype # Check that float32 and complex64 are preserved. Other real types get # converted to float64. @@ -427,5 +428,90 @@ def test_error_on_continuous_wavelet(): assert_raises(ValueError, rec_func, c, wavelet=cwave) +def test_iswt_mixed_dtypes(): + # Mixed precision inputs give double precision output + x_real = np.arange(16).astype(np.float64) + x_complex = x_real + 1j*x_real + wav = 'sym2' + for dtype1, dtype2 in [(np.float64, np.float32), + (np.float32, np.float64), + (np.float16, np.float64), + (np.complex128, np.complex64), + (np.complex64, np.complex128)]: + + if dtype1 in [np.complex64, np.complex128]: + x = x_complex + output_dtype = np.complex128 + else: + x = x_real + output_dtype = np.float64 + + coeffs = pywt.swt(x, wav, 2) + # different precision for the approximation coefficients + coeffs[0] = [coeffs[0][0].astype(dtype1), + coeffs[0][1].astype(dtype2)] + y = pywt.iswt(coeffs, wav) + assert_equal(output_dtype, y.dtype) + assert_allclose(y, x, rtol=1e-3, atol=1e-3) + + +def test_iswt2_mixed_dtypes(): + # Mixed precision inputs give double precision output + rstate = np.random.RandomState(0) + x_real = rstate.randn(8, 8) + x_complex = x_real + 1j*x_real + wav = 'sym2' + for dtype1, dtype2 in [(np.float64, np.float32), + (np.float32, np.float64), + (np.float16, np.float64), + (np.complex128, np.complex64), + (np.complex64, np.complex128)]: + + if dtype1 in [np.complex64, np.complex128]: + x = x_complex + output_dtype = np.complex128 + else: + x = x_real + output_dtype = np.float64 + + coeffs = pywt.swt2(x, wav, 2) + # different precision for the approximation coefficients + coeffs[0] = [coeffs[0][0].astype(dtype1), + tuple([c.astype(dtype2) for c in coeffs[0][1]])] + y = pywt.iswt2(coeffs, wav) + assert_equal(output_dtype, y.dtype) + assert_allclose(y, x, rtol=1e-3, atol=1e-3) + + +def test_iswtn_mixed_dtypes(): + # Mixed precision inputs give double precision output + rstate = np.random.RandomState(0) + x_real = rstate.randn(8, 8, 8) + x_complex = x_real + 1j*x_real + wav = 'sym2' + for dtype1, dtype2 in [(np.float64, np.float32), + (np.float32, np.float64), + (np.float16, np.float64), + (np.complex128, np.complex64), + (np.complex64, np.complex128)]: + + if dtype1 in [np.complex64, np.complex128]: + x = x_complex + output_dtype = np.complex128 + else: + x = x_real + output_dtype = np.float64 + + coeffs = pywt.swtn(x, wav, 2) + # different precision for the approximation coefficients + a = coeffs[0].pop('a' * x.ndim) + a = a.astype(dtype1) + coeffs[0] = {k: c.astype(dtype2) for k, c in coeffs[0].items()} + coeffs[0]['a' * x.ndim] = a + y = pywt.iswtn(coeffs, wav) + assert_equal(output_dtype, y.dtype) + assert_allclose(y, x, rtol=1e-3, atol=1e-3) + + if __name__ == '__main__': run_module_suite() From b897d66781c089683eb78abdbe0f46887b7d663c Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Thu, 31 Jan 2019 18:12:44 -0500 Subject: [PATCH 5/7] test idwt and idwtn with mixed complex dtypes as well --- pywt/tests/test_dwt_idwt.py | 17 ++++++++++++++++- pywt/tests/test_multidim.py | 16 ++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/pywt/tests/test_dwt_idwt.py b/pywt/tests/test_dwt_idwt.py index fe33621d7..4d0de63a4 100644 --- a/pywt/tests/test_dwt_idwt.py +++ b/pywt/tests/test_dwt_idwt.py @@ -39,7 +39,22 @@ def test_dwt_idwt_basic(): x_roundtrip2 = pywt.idwt(cA.astype(np.float64), cD.astype(np.float32), 'db2') assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7) - assert_(x_roundtrip.dtype == np.float64) + assert_(x_roundtrip2.dtype == np.float64) + + +def test_idwt_mixed_complex_dtype(): + x = np.arange(8).astype(float) + x = x + 1j*x[::-1] + cA, cD = pywt.dwt(x, 'db2') + + x_roundtrip = pywt.idwt(cA, cD, 'db2') + assert_allclose(x_roundtrip, x, rtol=1e-10) + + # mismatched dtypes OK + x_roundtrip2 = pywt.idwt(cA.astype(np.complex128), cD.astype(np.complex64), + 'db2') + assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7) + assert_(x_roundtrip2.dtype == np.complex128) def test_dwt_idwt_dtypes(): diff --git a/pywt/tests/test_multidim.py b/pywt/tests/test_multidim.py index 4baf2934d..5594933c9 100644 --- a/pywt/tests/test_multidim.py +++ b/pywt/tests/test_multidim.py @@ -363,6 +363,22 @@ def test_dwtn_idwtn_dtypes(): assert_(x_roundtrip.dtype == dt_out, "idwtn: " + errmsg) +def test_idwtn_mixed_complex_dtype(): + rstate = np.random.RandomState(0) + x = rstate.randn(8, 8, 8) + x = x + 1j*x + coeffs = pywt.dwtn(x, 'db2') + + x_roundtrip = pywt.idwtn(coeffs, 'db2') + assert_allclose(x_roundtrip, x, rtol=1e-10) + + # mismatched dtypes OK + coeffs['a' * x.ndim] = coeffs['a' * x.ndim].astype(np.complex64) + x_roundtrip2 = pywt.idwtn(coeffs, 'db2') + assert_allclose(x_roundtrip2, x, rtol=1e-7, atol=1e-7) + assert_(x_roundtrip2.dtype == np.complex128) + + def test_idwt2_size_mismatch_error(): LL = np.zeros((6, 6)) LH = HL = HH = np.zeros((5, 5)) From cc6ea9dbe1bee69bdd0ceb27269d622359af3493 Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Thu, 31 Jan 2019 18:39:46 -0500 Subject: [PATCH 6/7] test mixed precision with wavedec/waverec as well --- pywt/tests/test_multilevel.py | 39 +++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/pywt/tests/test_multilevel.py b/pywt/tests/test_multilevel.py index feb398244..7bf5d9df9 100644 --- a/pywt/tests/test_multilevel.py +++ b/pywt/tests/test_multilevel.py @@ -976,5 +976,44 @@ def test_default_level(): pywt.dwt_max_level(data.shape[ax], wavelet[ax])) +def test_waverec_mixed_precision(): + rstate = np.random.RandomState(0) + for func, ifunc, shape in [(pywt.wavedec, pywt.waverec, (8, )), + (pywt.wavedec2, pywt.waverec2, (8, 8)), + (pywt.wavedecn, pywt.waverecn, (8, 8, 8))]: + x = rstate.randn(*shape) + coeffs_real = func(x, 'db1') + + # real: single precision approx, double precision details + coeffs_real[0] = coeffs_real[0].astype(np.float32) + r = ifunc(coeffs_real, 'db1') + assert_allclose(r, x, rtol=1e-7, atol=1e-7) + assert_equal(r.dtype, np.float64) + + x = x + 1j*x + coeffs = func(x, 'db1') + + # complex: single precision approx, double precision details + coeffs[0] = coeffs[0].astype(np.complex64) + r = ifunc(coeffs, 'db1') + assert_allclose(r, x, rtol=1e-7, atol=1e-7) + assert_equal(r.dtype, np.complex128) + + # complex: double precision approx, single precision details + if x.ndim == 1: + coeffs[0] = coeffs[0].astype(np.complex128) + coeffs[1] = coeffs[1].astype(np.complex64) + if x.ndim == 2: + coeffs[0] = coeffs[0].astype(np.complex128) + coeffs[1] = tuple([v.astype(np.complex64) for v in coeffs[1]]) + if x.ndim == 3: + coeffs[0] = coeffs[0].astype(np.complex128) + coeffs[1] = {k: v.astype(np.complex64) + for k, v in coeffs[1].items()} + r = ifunc(coeffs, 'db1') + assert_allclose(r, x, rtol=1e-7, atol=1e-7) + assert_equal(r.dtype, np.complex128) + + if __name__ == '__main__': run_module_suite() From fb0fa8c70ae1df4778d0b8a38d716ceba7f67b4a Mon Sep 17 00:00:00 2001 From: "Gregory R. Lee" Date: Thu, 31 Jan 2019 18:41:55 -0500 Subject: [PATCH 7/7] idwtn: only compare dtypes if both coeffs are not None --- pywt/_multidim.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pywt/_multidim.py b/pywt/_multidim.py index 783da8a54..39d9dc2bf 100644 --- a/pywt/_multidim.py +++ b/pywt/_multidim.py @@ -296,14 +296,15 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None): for key in new_keys: L = coeffs.get(key + 'a', None) H = coeffs.get(key + 'd', None) - if L.dtype != H.dtype: - # upcast to a common dtype (float64 or complex128) - if L.dtype.kind == 'c' or H.dtype.kind == 'c': - dtype = np.complex128 - else: - dtype = np.float64 - L = np.asarray(L, dtype=dtype) - H = np.asarray(H, dtype=dtype) + if L is not None and H is not None: + if L.dtype != H.dtype: + # upcast to a common dtype (float64 or complex128) + if L.dtype.kind == 'c' or H.dtype.kind == 'c': + dtype = np.complex128 + else: + dtype = np.float64 + L = np.asarray(L, dtype=dtype) + H = np.asarray(H, dtype=dtype) new_coeffs[key] = idwt_axis(L, H, wav, mode, axis) coeffs = new_coeffs