diff --git a/mlx/backend/cpu/simd/math.h b/mlx/backend/cpu/simd/math.h index f9fc8317a5..78107ad1a5 100644 --- a/mlx/backend/cpu/simd/math.h +++ b/mlx/backend/cpu/simd/math.h @@ -7,6 +7,7 @@ namespace mlx::core::simd { constexpr float inf = std::numeric_limits::infinity(); +constexpr double inf_double = std::numeric_limits::infinity(); /** * Compute exp(x) in an optimizer friendly way as follows: @@ -28,6 +29,35 @@ template Simd exp(Simd in) { if constexpr (is_complex) { return Simd{std::exp(in.value)}; + } else if constexpr (std::is_same_v) { + auto x_init = in; + auto x = x_init * 1.4426950408889634; // multiply with log_2(e) + Simd ipart, fpart; + ipart = floor(x + 0.5); + fpart = x - ipart; + + x = 6.77872635482254254e-14; + x = fma(x, fpart, 1.36914888539041241e-12); + x = fma(x, fpart, 2.56784359934881958e-11); + x = fma(x, fpart, 4.44553827187081007e-10); + x = fma(x, fpart, 7.05491162080112088e-09); + x = fma(x, fpart, 1.01780860092396960e-07); + x = fma(x, fpart, 1.32154867901443053e-06); + x = fma(x, fpart, 1.52527338040598377e-05); + x = fma(x, fpart, 1.54035303933816061e-04); + x = fma(x, fpart, 1.33335581464284411e-03); + x = fma(x, fpart, 9.61812910762847688e-03); + x = fma(x, fpart, 5.55041086648215762e-02); + x = fma(x, fpart, 2.40226506959100694e-01); + x = fma(x, fpart, 6.93147180559945286e-01); + x = fma(x, fpart, 1.00000000000000000e+00); + + Simd epart = (Simd(ipart) + 1023) << 52; + + auto result = select(isnan(x_init), x_init, (*(Simd*)&epart) * x); + result = select(x_init > 709.0, Simd(inf_double), result); + result = select(x_init < -708.0, Simd(0), result); + return result; } else { Simd x_init = in; auto x = x_init * 1.442695f; // multiply with log_2(e) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 81b10578cb..0095ed2f18 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1521,6 +1521,15 @@ TEST_CASE("test arithmetic unary ops") { x = array(2.0); CHECK_EQ(exp(x).item(), doctest::Approx(std::exp(2.0f))); + x = array({0.1, 0.5, 0.9, 1.25}, float64); + auto exp_x = exp(x); + CHECK_EQ(exp_x.dtype(), float64); + for (int i = 0; i < x.size(); ++i) { + auto actual = take(exp_x, array(i)).item(); + auto expected = std::exp(take(x, array(i)).item()); + CHECK(actual == doctest::Approx(expected).epsilon(1e-12)); + } + CHECK(array_equal(exp(array({})), array({})).item()); x = array(neginf);