forked from PtyLab/PtyLab.py
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_base_engine.py
More file actions
140 lines (106 loc) · 5.77 KB
/
Copy pathtest_base_engine.py
File metadata and controls
140 lines (106 loc) · 5.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Tests for PtyLabX.Engines.BaseEngine — core engine functionality."""
import jax
import jax.numpy as jnp
import numpy as np
from numpy.testing import assert_allclose
from scipy.ndimage import fourier_gaussian
from PtyLabX.Engines.BaseEngine import _fourier_gaussian_jax
class TestSpectralPowerCorrection:
    """Tests for vectorized spectral power correction in applyConstraints."""

    @staticmethod
    def _correct(probe, spectral_power, maxProbePower):
        """Rescale each leading-axis (wavelength) slice of ``probe`` to a target L2 norm.

        Slice ``probe[wl]`` is divided by its own L2 norm (taken over every
        non-leading axis) and multiplied by ``maxProbePower * spectral_power[wl]``.

        Args:
            probe: complex array whose axis 0 indexes wavelengths.
            spectral_power: 1-D array with one weight per wavelength.
            maxProbePower: global scale applied to every wavelength.

        Returns:
            The rescaled probe, same shape as ``probe``.
        """
        reduce_axes = tuple(range(1, probe.ndim))
        # |p|**2 keeps the norm real-valued (p * conj(p) is complex with a zero imaginary part).
        norms = jnp.sqrt(jnp.sum(jnp.abs(probe) ** 2, axis=reduce_axes, keepdims=True))
        # Broadcast one weight per wavelength across all remaining axes.
        weights = spectral_power.reshape((-1,) + (1,) * (probe.ndim - 1))
        return probe * (maxProbePower * weights / norms)

    def test_spectral_power_correction_preserves_shape(self):
        """Vectorized spectral power correction should preserve probe shape and phase."""
        probe = jnp.ones((3, 1, 1, 1, 32, 32), dtype=jnp.complex64) * (1 + 0.5j)
        spectral_power = jnp.array([0.5, 0.3, 0.2])
        result = self._correct(probe, spectral_power, maxProbePower=1.0)
        assert result.shape == probe.shape
        # Scaling is by a positive real factor, so the phase must be untouched.
        assert_allclose(np.angle(np.asarray(result)), np.angle(np.asarray(probe)), atol=1e-6)

    def test_spectral_power_sets_correct_power(self):
        """After correction, each wavelength should have the target spectral power (L2 norm)."""
        rng = np.random.default_rng(0)
        shape = (3, 1, 1, 1, 32, 32)
        probe = jnp.array(
            rng.standard_normal(shape) + 1j * rng.standard_normal(shape),
            dtype=jnp.complex64,
        )
        spectral_power = jnp.array([0.5, 0.3, 0.2])
        maxProbePower = 10.0
        result = self._correct(probe, spectral_power, maxProbePower)
        # L2 norm of every wavelength slice at once (flatten all non-wavelength axes).
        per_wl_norm = np.linalg.norm(np.asarray(result).reshape(result.shape[0], -1), axis=1)
        assert_allclose(per_wl_norm, maxProbePower * np.asarray(spectral_power), rtol=1e-5)
class TestWavelengthCoupling:
    """Tests for vectorized wavelength coupling in applyConstraints."""

    def test_coupling_boundary_conditions(self):
        """Boundary wavelengths couple only with their one neighbor; interior ones with both.

        ``jnp.roll`` wraps around the wavelength axis, so the first and last
        slices are overwritten with one-sided coupling. Interior slices keep the
        symmetric average of their two neighbors. Both cases are checked over
        every pixel.
        """
        n_wl = 4
        # Wavelength i holds the constant value i + 1 everywhere.
        values = jnp.arange(1, n_wl + 1, dtype=jnp.float32).reshape(n_wl, 1, 1, 1, 1, 1)
        probe = jnp.broadcast_to(values, (n_wl, 1, 1, 1, 8, 8)).astype(jnp.complex64)
        a = 0.5
        shifted_up = jnp.roll(probe, -1, axis=0)
        shifted_down = jnp.roll(probe, 1, axis=0)
        coupled = (1 - a) * probe + a * (shifted_up + shifted_down) / 2
        # Replace the wrapped-around boundary results with one-sided coupling.
        coupled = coupled.at[0].set((1 - a) * probe[0] + a * probe[1])
        coupled = coupled.at[-1].set((1 - a) * probe[-1] + a * probe[-2])
        # First:  (1-0.5)*1 + 0.5*2       = 1.5
        # Second: (1-0.5)*2 + 0.5*(1+3)/2 = 2.0
        # Third:  (1-0.5)*3 + 0.5*(2+4)/2 = 3.0
        # Last:   (1-0.5)*4 + 0.5*3       = 3.5
        expected = np.broadcast_to(
            np.array([1.5, 2.0, 3.0, 3.5]).reshape(n_wl, 1, 1, 1, 1, 1), coupled.shape
        )
        assert_allclose(np.asarray(coupled), expected, atol=1e-5)
class TestPositionCorrectionVmap:
    """Tests for vmapped position correction."""

    def test_vmap_cc_matches_loop(self):
        """Vmapped cross-correlation should match the sequential per-shift loop.

        Inputs are complex so that ``conj()`` in the correlation is actually
        exercised. On real inputs it is a no-op, and a missing conjugate would go
        unnoticed.
        """
        rng = np.random.default_rng(42)
        shape = (32, 32)
        Opatch = jnp.array(
            rng.standard_normal(shape) + 1j * rng.standard_normal(shape), dtype=jnp.complex64
        )
        O_slice = jnp.array(
            rng.standard_normal(shape) + 1j * rng.standard_normal(shape), dtype=jnp.complex64
        )
        # 3x3 neighbourhood of integer (row, col) shifts; index 4 is the (0, 0) shift.
        rowShifts = jnp.array([-1, -1, -1, 0, 0, 0, 1, 1, 1])
        colShifts = jnp.array([-1, 0, 1, -1, 0, 1, -1, 0, 1])
        n_shifts = rowShifts.shape[0]

        # Loop version. The accumulator must share the inputs' complex dtype;
        # storing complex correlations into a real array would discard the imaginary part.
        cc_loop = jnp.zeros((n_shifts, 1), dtype=Opatch.dtype)
        for i in range(n_shifts):
            shifted = jnp.roll(jnp.roll(Opatch, rowShifts[i], axis=-2), colShifts[i], axis=-1)
            cc_loop = cc_loop.at[i].set(jnp.squeeze(jnp.sum(shifted.conj() * O_slice, axis=(-2, -1))))

        # Vmap version: same correlation, with the shift index traced.
        def _cc_at_shift(i):
            shifted = jnp.roll(jnp.roll(Opatch, rowShifts[i], axis=-2), colShifts[i], axis=-1)
            return jnp.squeeze(jnp.sum(shifted.conj() * O_slice, axis=(-2, -1)))

        cc_vmap = jax.vmap(_cc_at_shift)(jnp.arange(n_shifts)).reshape(-1, 1)

        # Each entry sums ~1000 float32 products (magnitude O(50)), so allow a relative tolerance.
        assert_allclose(np.asarray(cc_loop), np.asarray(cc_vmap), rtol=1e-5, atol=1e-5)
        # Independent reference at zero shift: sum(conj(Opatch) * O_slice) is np.vdot.
        assert_allclose(
            np.asarray(cc_vmap[4, 0]),
            np.vdot(np.asarray(Opatch), np.asarray(O_slice)),
            rtol=1e-4,
            atol=1e-4,
        )
class TestFourierGaussianJax:
    """Tests for the pure-JAX Fourier Gaussian replacement."""

    def test_matches_scipy_2d(self):
        """Pure JAX implementation should match per-axis scipy fourier_gaussian on 2D input."""
        rng = np.random.default_rng(42)
        data = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))
        data = data.astype(np.complex64)
        sigma = 3.0
        F_field = np.fft.fft2(data)
        # scipy reference: apply per-axis (sigma only on spatial dims)
        F_scipy = fourier_gaussian(F_field, [sigma, sigma])
        # JAX implementation
        F_jax = _fourier_gaussian_jax(jnp.array(F_field), sigma)
        assert_allclose(np.asarray(F_jax), F_scipy, atol=1e-5)

    def test_matches_scipy_6d(self):
        """Should work on 6D arrays, applying only along last 2 (spatial) axes.

        Some batch axes are deliberately longer than 1. On a length-1 axis the
        only frequency is 0, so the Gaussian factor there is 1 for any sigma.
        With all-singleton batch axes the test could not distinguish filtering
        the last two axes from filtering every axis.
        """
        rng = np.random.default_rng(0)
        shape = (2, 1, 3, 1, 32, 32)
        data = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
        data = data.astype(np.complex64)
        sigma = 2.0
        # fft2 transforms only the last two axes.
        F_field = np.fft.fft2(data)
        # scipy reference: sigma only on last 2 axes, 0 on batch dims
        sigmas = [0] * (len(shape) - 2) + [sigma, sigma]
        F_scipy = fourier_gaussian(F_field, sigmas)
        F_jax = _fourier_gaussian_jax(jnp.array(F_field), sigma)
        assert_allclose(np.asarray(F_jax), F_scipy, atol=1e-5)

    def test_identity_at_zero_sigma(self):
        """sigma=0 should return the input unchanged."""
        rng = np.random.default_rng(1)
        data = jnp.array(
            rng.standard_normal((16, 16)) + 1j * rng.standard_normal((16, 16)),
            dtype=jnp.complex64,
        )
        result = _fourier_gaussian_jax(data, 0.0)
        assert_allclose(np.asarray(result), np.asarray(data), atol=1e-7)