diff --git a/mlx/backend/cpu/simd/base_simd.h b/mlx/backend/cpu/simd/base_simd.h index 775f5dfd10..9529312c82 100644 --- a/mlx/backend/cpu/simd/base_simd.h +++ b/mlx/backend/cpu/simd/base_simd.h @@ -99,7 +99,6 @@ DEFAULT_UNARY(atanh, std::atanh) DEFAULT_UNARY(ceil, std::ceil) DEFAULT_UNARY(conj, std::conj) DEFAULT_UNARY(cosh, std::cosh) -DEFAULT_UNARY(expm1, std::expm1) DEFAULT_UNARY(floor, std::floor) DEFAULT_UNARY(log, std::log) DEFAULT_UNARY(log10, std::log10) @@ -130,6 +129,16 @@ Simd log1p(Simd in) { } } +template +Simd expm1(Simd in) { + if constexpr (is_complex) { + // std::expm1 has no complex overload; use the defining identity. + return Simd{std::exp(in.value) - T(1)}; + } else { + return Simd{std::expm1(in.value)}; + } +} + template Simd log2(Simd in) { if constexpr (is_complex) { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 660ea76c8d..476173b2b7 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1038,6 +1038,14 @@ def test_expm1(self): np.seterr(over=errs["over"]) self.assertTrue(np.allclose(result, expected, rtol=1e-3, atol=1e-4)) + # Complex inputs: expm1(z) = exp(z) - 1, not expm1(Re(z)). + c = np.array( + [0.5 + 0.7j, -1.0 + 2.0j, 0.0 + 1.0j, 2.0 - 0.5j, 0.01 + 0.02j], + dtype=np.complex64, + ) + result = mx.expm1(mx.array(c)) + self.assertTrue(np.allclose(result, np.exp(c) - 1, rtol=1e-4, atol=1e-5)) + def test_erf(self): inputs = [-5, 0.0, 0.5, 1.0, 2.0, 10.0] x = mx.array(inputs)