Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pywt/_swt.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,16 +516,19 @@ def iswtn(coeffs, wavelet, axes=None):
[dt, ] + [v.dtype for v in details.values()]))
if output.dtype != common_dtype:
output = output.astype(common_dtype)

# We assume all coefficient arrays are of equal size
shapes = [v.shape for k, v in details.items()]
dshape = shapes[0]
if len(set(shapes)) != 1:
raise RuntimeError(
"Mismatch in shape of intermediate coefficient arrays")

# shape of a single coefficient array, excluding non-transformed axes
coeff_trans_shape = tuple([shapes[0][ax] for ax in axes])

# nested loop over all combinations of axis offsets at this level
for firsts in product(*([range(last_index), ]*ndim_transform)):
for first, sh, ax in zip(firsts, dshape, axes):
for first, sh, ax in zip(firsts, coeff_trans_shape, axes):
indices[ax] = slice(first, sh, step_size)
even_indices[ax] = slice(first, sh, 2*step_size)
odd_indices[ax] = slice(first+step_size, sh, 2*step_size)
Expand Down
18 changes: 16 additions & 2 deletions pywt/tests/test_swt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@

import warnings
from copy import deepcopy
from itertools import combinations
from itertools import combinations, permutations
import numpy as np
from numpy.testing import (run_module_suite, dec, assert_allclose, assert_,
assert_equal, assert_raises, assert_array_equal,
assert_warns)

import pywt
from pywt._extensions._swt import swt_axis
from pywt._extensions._pywt import _check_dtype

# Check that float32 and complex64 are preserved. Other real types get
# converted to float64.
Expand Down Expand Up @@ -387,6 +386,21 @@ def test_iswtn_errors():
assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)


def test_swtn_iswtn_unique_shape_per_axis():
# test case for gh-460
_shape = (1, 48, 32) # unique shape per axis
wav = 'sym2'
max_level = 3
rstate = np.random.RandomState(0)
for shape in permutations(_shape):
# transform only along the non-singleton axes
axes = [ax for ax, s in enumerate(shape) if s != 1]
x = rstate.standard_normal(shape)
c = pywt.swtn(x, wav, max_level, axes=axes)
r = pywt.iswtn(c, wav, axes=axes)
assert_allclose(x, r, rtol=1e-10, atol=1e-10)


def test_per_axis_wavelets():
# tests seperate wavelet for each axis.
rstate = np.random.RandomState(1234)
Expand Down