diff --git a/pywt/_wavelet_packets.py b/pywt/_wavelet_packets.py index 755179c41..a8d60cc34 100644 --- a/pywt/_wavelet_packets.py +++ b/pywt/_wavelet_packets.py @@ -68,6 +68,12 @@ def __init__(self, parent, data, node_name): # data - signal on level 0, coeffs on higher levels self.data = data + # Need to retain original data size/shape so we can trim any excess + # boundary coefficients from the inverse transform. + if self.data is None: + self._data_shape = None + else: + self._data_shape = np.asarray(data).shape self._init_subnodes() @@ -436,6 +442,9 @@ def _reconstruct(self, update): " from subnodes.") else: rec = idwt(data_a, data_d, self.wavelet, self.mode) + if self._data_shape is not None and ( + rec.shape != self._data_shape): + rec = rec[tuple([slice(sz) for sz in self._data_shape])] if update: self.data = rec return rec @@ -504,6 +513,9 @@ def _reconstruct(self, update): else: coeffs = data_ll, (data_hl, data_lh, data_hh) rec = idwt2(coeffs, self.wavelet, self.mode) + if self._data_shape is not None and ( + rec.shape != self._data_shape): + rec = rec[tuple([slice(sz) for sz in self._data_shape])] if update: self.data = rec return rec @@ -568,8 +580,6 @@ def reconstruct(self, update=True): """ if self.has_any_subnode: data = super(WaveletPacket, self).reconstruct(update) - if self.data_size is not None and len(data) > self.data_size: - data = data[:self.data_size] if update: self.data = data return data @@ -669,8 +679,6 @@ def reconstruct(self, update=True): """ if self.has_any_subnode: data = super(WaveletPacket2D, self).reconstruct(update) - if self.data_size is not None and (data.shape != self.data_size): - data = data[:self.data_size[0], :self.data_size[1]] if update: self.data = data return data diff --git a/pywt/tests/test_wp.py b/pywt/tests/test_wp.py index 9025b00e1..50fae06ec 100644 --- a/pywt/tests/test_wp.py +++ b/pywt/tests/test_wp.py @@ -189,5 +189,13 @@ def test_wavelet_packet_dtypes(): assert_allclose(r, x.astype(transform_dtype), atol=1e-5, rtol=1e-5) +def test_db3_roundtrip(): + original = np.arange(512) + wp = pywt.WaveletPacket(data=original, wavelet='db3', mode='smooth', + maxlevel=3) + r = wp.reconstruct() + assert_allclose(original, r, atol=1e-12, rtol=1e-12) + + if __name__ == '__main__': run_module_suite() diff --git a/pywt/tests/test_wp2d.py b/pywt/tests/test_wp2d.py index ee6c1536f..de1305de2 100644 --- a/pywt/tests/test_wp2d.py +++ b/pywt/tests/test_wp2d.py @@ -168,5 +168,14 @@ def test_wavelet_packet_dtypes(): assert_allclose(r, x, atol=1e-5, rtol=1e-5) +def test_2d_roundtrip(): + # test case corresponding to PyWavelets issue 447 + original = pywt.data.camera() + wp = pywt.WaveletPacket2D(data=original, wavelet='db3', mode='smooth', + maxlevel=3) + r = wp.reconstruct() + assert_allclose(original, r, atol=1e-12, rtol=1e-12) + + if __name__ == '__main__': run_module_suite()