From c77b70ae60eb941f9d648de5aa2366d48a2590df Mon Sep 17 00:00:00 2001 From: Shubh Date: Sat, 20 Jun 2026 04:30:24 +0530 Subject: [PATCH 1/4] Add math mode option for custom Metal kernels --- docs/src/dev/custom_metal_kernels.rst | 21 +++++++++++++++++ mlx/backend/common/metal_kernel.cpp | 6 +++-- mlx/backend/metal/custom_kernel.cpp | 32 ++++++++++++++++++++++---- mlx/backend/metal/device.cpp | 27 +++++++++++++++++++--- mlx/backend/metal/device.h | 10 +++++++- mlx/fast.h | 9 +++++++- mlx/fast_primitives.h | 10 +++++--- python/src/fast.cpp | 24 +++++++++++++++++-- python/tests/test_export_import.py | 32 ++++++++++++++++++++++++++ python/tests/test_fast.py | 33 +++++++++++++++++++++++++++ 10 files changed, 187 insertions(+), 17 deletions(-) diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst index f5881b5c17..9629475619 100644 --- a/docs/src/dev/custom_metal_kernels.rst +++ b/docs/src/dev/custom_metal_kernels.rst @@ -50,6 +50,27 @@ JIT compiled. To reduce the overhead from that, build the kernel once with Only pass the body of the Metal kernel in ``source``. The function signature is generated automatically. +Math Mode +--------- + +By default :func:`fast.metal_kernel` compiles kernels with ``math_mode="safe"`` +so special values follow IEEE behavior, for example ``exp(-inf) == 0``. This is +important for kernels such as masked softmax where causal or sliding-window +masks depend on exponentiating ``-inf``. + +If your kernel does not rely on these edge cases, you can opt in to less strict +math with ``math_mode="relaxed"`` or ``math_mode="fast"``: + +.. code-block:: python + + kernel = mx.fast.metal_kernel( + name="my_kernel", + input_names=["x"], + output_names=["y"], + source=source, + math_mode="relaxed", + ) + The full function signature will be generated using: * The shapes/dtypes of ``inputs`` diff --git a/mlx/backend/common/metal_kernel.cpp b/mlx/backend/common/metal_kernel.cpp index 691feb554a..af09cacd3b 100644 --- a/mlx/backend/common/metal_kernel.cpp +++ b/mlx/backend/common/metal_kernel.cpp @@ -206,7 +206,8 @@ CustomKernelFunction metal_kernel( const std::string& source, const std::string& header /* = "" */, bool ensure_row_contiguous /* = true */, - bool atomic_outputs /* = false */) { + bool atomic_outputs /* = false */, + MetalKernelMathMode math_mode /* = MetalKernelMathMode::Safe */) { if (output_names.empty()) { throw std::invalid_argument( "[metal_kernel] Must specify at least one output."); @@ -360,7 +361,8 @@ CustomKernelFunction metal_kernel( init_value, std::vector{}, false, - 0), + 0, + static_cast(math_mode)), std::move(inputs)); }; } diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 0ba491f4ff..23bb8891bd 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -8,7 +8,8 @@ namespace mlx::core::fast { struct CustomKernelCache { - std::unordered_map libraries; + std::unordered_map>> + libraries; }; static CustomKernelCache& cache() { @@ -16,6 +17,22 @@ static CustomKernelCache& cache() { return cache_; }; +std::optional to_mtl_math_mode(std::optional math_mode) { + if (!math_mode) { + return std::nullopt; + } + switch (*math_mode) { + case static_cast(MetalKernelMathMode::Safe): + return MTL::MathModeSafe; + case static_cast(MetalKernelMathMode::Relaxed): + return MTL::MathModeRelaxed; + case static_cast(MetalKernelMathMode::Fast): + return MTL::MathModeFast; + default: + throw std::invalid_argument("[metal_kernel] Invalid Metal math mode."); + } +} + void CustomKernel::eval_gpu( const std::vector& inputs, std::vector& outputs) { @@ -58,17 +75,22 @@ void CustomKernel::eval_gpu( auto& kernel_cache = cache(); if (auto it = kernel_cache.libraries.find(name_); it != kernel_cache.libraries.end()) { - if (it->second != source_) { + if (it->second.first != source_ || + it->second.second != metal_math_mode_) { auto& d = metal::device(s.device); d.clear_library(name_); - it->second = source_; + it->second = {source_, metal_math_mode_}; } } else { - kernel_cache.libraries.emplace(name_, source_); + kernel_cache.libraries.emplace( + name_, std::make_pair(source_, metal_math_mode_)); } } - auto lib = d.get_library(name_, [this] { return metal::utils() + source_; }); + auto lib = d.get_library( + name_, + [this] { return metal::utils() + source_; }, + to_mtl_math_mode(metal_math_mode_)); auto kernel = d.get_kernel(name_, lib); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 7658ce5f5c..e915bbec8d 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -606,7 +606,8 @@ MTL::Library* Device::get_library( } NS::SharedPtr Device::build_library_( - const std::string& source_string) { + const std::string& source_string, + std::optional math_mode) { auto pool = new_scoped_memory_pool(); auto ns_code = @@ -614,7 +615,20 @@ NS::SharedPtr Device::build_library_( NS::Error* error = nullptr; auto options = MTL::CompileOptions::alloc()->init()->autorelease(); - options->setFastMathEnabled(false); + if (math_mode) { + if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) { + options->setMathMode(*math_mode); + } else { + if (*math_mode == MTL::MathModeRelaxed) { + throw std::runtime_error( + "[metal::Device] Metal math mode `relaxed` requires macOS 15, " + "iOS 18, tvOS 18, or visionOS 2."); + } + options->setFastMathEnabled(*math_mode == MTL::MathModeFast); + } + } else { + options->setFastMathEnabled(false); + } options->setLanguageVersion(get_metal_version()); #ifndef NDEBUG if (options->languageVersion() >= MTL::LanguageVersion3_2) { @@ -756,6 +770,13 @@ NS::SharedPtr Device::get_kernel_( MTL::Library* Device::get_library( const std::string& name, const std::function& builder) { + return get_library(name, builder, std::nullopt); +} + +MTL::Library* Device::get_library( + const std::string& name, + const std::function& builder, + std::optional math_mode) { { std::shared_lock rlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { @@ -768,7 +789,7 @@ MTL::Library* Device::get_library( return it->second.get(); } - auto mtl_lib = build_library_(builder()); + auto mtl_lib = build_library_(builder(), math_mode); library_map_.insert({name, mtl_lib}); return mtl_lib.get(); } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index bed0cd636e..d9cb910c4e 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -170,6 +171,11 @@ class MLX_API Device { const std::string& name, const std::function& builder); + MTL::Library* get_library( + const std::string& name, + const std::function& builder, + std::optional math_mode); + void clear_library(const std::string& name); MTL::ComputePipelineState* get_kernel( @@ -190,7 +196,9 @@ class MLX_API Device { } private: - NS::SharedPtr build_library_(const std::string& source_string); + NS::SharedPtr build_library_( + const std::string& source_string, + std::optional math_mode = std::nullopt); NS::SharedPtr get_function_( const std::string& name, diff --git a/mlx/fast.h b/mlx/fast.h index 1183aba8fe..58b4ad8a7d 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -57,6 +57,12 @@ MLX_API array scaled_dot_product_attention( using TemplateArg = std::variant; using ScalarArg = std::variant; +enum class MetalKernelMathMode { + Safe = 0, + Relaxed = 1, + Fast = 2, +}; + using CustomKernelFunction = std::function( const std::vector&, const std::vector&, @@ -75,7 +81,8 @@ MLX_API CustomKernelFunction metal_kernel( const std::string& source, const std::string& header = "", bool ensure_row_contiguous = true, - bool atomic_outputs = false); + bool atomic_outputs = false, + MetalKernelMathMode math_mode = MetalKernelMathMode::Safe); MLX_API CustomKernelFunction cuda_kernel( const std::string& name, diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..4ffcfd3d56 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -375,7 +375,8 @@ class CustomKernel : public Primitive { std::optional init_value, std::vector scalar_arguments, bool is_precompiled, - int shared_memory) + int shared_memory, + std::optional metal_math_mode = std::nullopt) : Primitive(stream), name_(std::move(name)), source_(std::move(source)), @@ -386,7 +387,8 @@ class CustomKernel : public Primitive { init_value_(init_value), scalar_arguments_(std::move(scalar_arguments)), is_precompiled_(is_precompiled), - shared_memory_(shared_memory) {} + shared_memory_(shared_memory), + metal_math_mode_(metal_math_mode) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -408,7 +410,8 @@ class CustomKernel : public Primitive { init_value_, scalar_arguments_, is_precompiled_, - shared_memory_); + shared_memory_, + metal_math_mode_); } private: @@ -422,6 +425,7 @@ class CustomKernel : public Primitive { std::vector scalar_arguments_; bool is_precompiled_; int shared_memory_; + std::optional metal_math_mode_; }; } // namespace mlx::core::fast diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 1a43d89d9b..fc8af5a9da 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -75,6 +75,19 @@ struct PyCustomKernelFunction { const char* tag_; }; +mx::fast::MetalKernelMathMode parse_metal_math_mode( + const std::string& math_mode) { + if (math_mode == "safe") { + return mx::fast::MetalKernelMathMode::Safe; + } else if (math_mode == "relaxed") { + return mx::fast::MetalKernelMathMode::Relaxed; + } else if (math_mode == "fast") { + return mx::fast::MetalKernelMathMode::Fast; + } + throw std::invalid_argument( + "[metal_kernel] Expected math_mode to be 'safe', 'relaxed', or 'fast'."); +} + } // namespace void init_fast(nb::module_& parent_module) { @@ -304,7 +317,8 @@ void init_fast(nb::module_& parent_module) { const std::string& source, const std::string& header, bool ensure_row_contiguous, - bool atomic_outputs) { + bool atomic_outputs, + const std::string& math_mode) { auto kernel = mx::fast::metal_kernel( name, input_names, @@ -312,7 +326,8 @@ void init_fast(nb::module_& parent_module) { source, header, ensure_row_contiguous, - atomic_outputs); + atomic_outputs, + parse_metal_math_mode(math_mode)); return nb::cpp_function( PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"), nb::kw_only(), @@ -356,6 +371,7 @@ void init_fast(nb::module_& parent_module) { "header"_a = "", "ensure_row_contiguous"_a = true, "atomic_outputs"_a = false, + "math_mode"_a = "safe", R"pbdoc( A jit-compiled custom Metal kernel defined from a source string. @@ -376,6 +392,10 @@ void init_fast(nb::module_& parent_module) { before the kernel runs. Default: ``True``. atomic_outputs (bool): Whether to use atomic outputs in the function signature e.g. ``device atomic``. Default: ``False``. + math_mode (str): The Metal math mode to compile the kernel with: + ``"safe"``, ``"relaxed"``, or ``"fast"``. ``"safe"`` preserves + IEEE behavior for special values such as ``exp(-inf) == 0``. + Default: ``"safe"``. Returns: Callable ``metal_kernel``. diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 87e7b31ced..971f8fc264 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -638,6 +638,38 @@ def call_cpu(a): with self.assertRaisesRegex(RuntimeError, "No Metal back-end"): mx.eval(mx.compile(call)(a)) + def test_export_custom_metal_kernel_with_math_mode(self): + source = """ + uint elem = thread_position_in_grid.x; + out[elem] = metal::exp(a[elem]); + """ + kernel = mx.fast.metal_kernel( + name="math_mode_export", + input_names=["a"], + output_names=["out"], + source=source, + math_mode="safe", + ) + + def call(a): + return kernel( + inputs=[a], + grid=(a.size, 1, 1), + threadgroup=(min(a.size, 256), 1, 1), + output_shapes=[a.shape], + output_dtypes=[a.dtype], + stream=mx.gpu, + )[0] + + a = mx.array([-float("inf"), 0.0]) + path = os.path.join(self.test_dir, "metal_kernel_math_mode.mlxfn") + mx.export_function(path, call, a) + self.assertTrue(os.path.exists(path)) + + if mx.metal.is_available(): + imported = mx.import_function(path) + self.assertTrue(mx.array_equal(imported(a)[0], call(a))) + def test_export_import_multi_with_constants(self): path = os.path.join(self.test_dir, "fn.mlxfn") diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index b1c84c987d..69e07552e9 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -1026,6 +1026,39 @@ def call_kernel(a: mx.array, source): out = call_kernel(a, source) self.assertTrue(mx.array_equal(out, mx.ones_like(out))) + def test_custom_metal_kernel_invalid_math_mode(self): + with self.assertRaises(ValueError): + mx.fast.metal_kernel( + name="invalid_math_mode", + input_names=["inp"], + output_names=["out"], + source="out[0] = inp[0];", + math_mode="precise", + ) + + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") + def test_custom_metal_kernel_safe_math_mode(self): + kernel = mx.fast.metal_kernel( + name="safe_math_mode", + input_names=["inp"], + output_names=["out"], + source=""" + uint elem = thread_position_in_grid.x; + out[elem] = metal::exp(inp[elem]); + """, + math_mode="safe", + ) + a = mx.array([-float("inf"), 0.0], dtype=mx.float32) + out = kernel( + inputs=[a], + grid=(a.size, 1, 1), + threadgroup=(a.size, 1, 1), + output_shapes=[a.shape], + output_dtypes=[a.dtype], + stream=mx.gpu, + )[0] + self.assertTrue(mx.array_equal(out, mx.array([0.0, 1.0]))) + @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_custom_kernel_mixed_dtypes(self): # Calling the same kernel with different input dtypes in a single From 58cf13a410b6965e6f23439ba0e0188038840133 Mon Sep 17 00:00:00 2001 From: Shubh Date: Sat, 20 Jun 2026 13:09:06 +0530 Subject: [PATCH 2/4] Refactor Metal compile options --- docs/src/dev/custom_metal_kernels.rst | 14 +++++---- mlx/backend/common/metal_kernel.cpp | 4 +-- mlx/backend/metal/custom_kernel.cpp | 3 +- mlx/backend/metal/device.cpp | 17 ++++++----- mlx/backend/metal/device.h | 8 +++-- mlx/fast.h | 6 +++- python/src/fast.cpp | 44 ++++++++++++++++++++++----- python/tests/test_export_import.py | 2 +- python/tests/test_fast.py | 14 +++++++-- 9 files changed, 82 insertions(+), 30 deletions(-) diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst index 9629475619..ce5880952f 100644 --- a/docs/src/dev/custom_metal_kernels.rst +++ b/docs/src/dev/custom_metal_kernels.rst @@ -53,13 +53,15 @@ JIT compiled. To reduce the overhead from that, build the kernel once with Math Mode --------- -By default :func:`fast.metal_kernel` compiles kernels with ``math_mode="safe"`` -so special values follow IEEE behavior, for example ``exp(-inf) == 0``. This is -important for kernels such as masked softmax where causal or sliding-window -masks depend on exponentiating ``-inf``. +By default :func:`fast.metal_kernel` compiles kernels with +``compile_options={"math_mode": "safe"}`` so special values follow IEEE +behavior, for example ``exp(-inf) == 0``. This is important for kernels such as +masked softmax where causal or sliding-window masks depend on exponentiating +``-inf``. If your kernel does not rely on these edge cases, you can opt in to less strict -math with ``math_mode="relaxed"`` or ``math_mode="fast"``: +math with ``compile_options={"math_mode": "relaxed"}`` or +``compile_options={"math_mode": "fast"}``: .. code-block:: python @@ -68,7 +70,7 @@ math with ``math_mode="relaxed"`` or ``math_mode="fast"``: input_names=["x"], output_names=["y"], source=source, - math_mode="relaxed", + compile_options={"math_mode": "relaxed"}, ) The full function signature will be generated using: diff --git a/mlx/backend/common/metal_kernel.cpp b/mlx/backend/common/metal_kernel.cpp index af09cacd3b..edbf75c46e 100644 --- a/mlx/backend/common/metal_kernel.cpp +++ b/mlx/backend/common/metal_kernel.cpp @@ -207,7 +207,7 @@ CustomKernelFunction metal_kernel( const std::string& header /* = "" */, bool ensure_row_contiguous /* = true */, bool atomic_outputs /* = false */, - MetalKernelMathMode math_mode /* = MetalKernelMathMode::Safe */) { + CompileOptions compile_options /* = {} */) { if (output_names.empty()) { throw std::invalid_argument( "[metal_kernel] Must specify at least one output."); @@ -362,7 +362,7 @@ CustomKernelFunction metal_kernel( std::vector{}, false, 0, - static_cast(math_mode)), + static_cast(compile_options.math_mode)), std::move(inputs)); }; } diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 23bb8891bd..1e47428088 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -3,6 +3,7 @@ #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/utils.h" +#include "mlx/fast.h" #include "mlx/fast_primitives.h" namespace mlx::core::fast { @@ -90,7 +91,7 @@ void CustomKernel::eval_gpu( auto lib = d.get_library( name_, [this] { return metal::utils() + source_; }, - to_mtl_math_mode(metal_math_mode_)); + metal::CompileOptions{to_mtl_math_mode(metal_math_mode_)}); auto kernel = d.get_kernel(name_, lib); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index e915bbec8d..de7ad6d159 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -607,7 +607,7 @@ MTL::Library* Device::get_library( NS::SharedPtr Device::build_library_( const std::string& source_string, - std::optional math_mode) { + CompileOptions compile_options) { auto pool = new_scoped_memory_pool(); auto ns_code = @@ -615,16 +615,17 @@ NS::SharedPtr Device::build_library_( NS::Error* error = nullptr; auto options = MTL::CompileOptions::alloc()->init()->autorelease(); - if (math_mode) { + if (compile_options.math_mode) { + auto math_mode = *compile_options.math_mode; if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) { - options->setMathMode(*math_mode); + options->setMathMode(math_mode); } else { - if (*math_mode == MTL::MathModeRelaxed) { + if (math_mode == MTL::MathModeRelaxed) { throw std::runtime_error( "[metal::Device] Metal math mode `relaxed` requires macOS 15, " "iOS 18, tvOS 18, or visionOS 2."); } - options->setFastMathEnabled(*math_mode == MTL::MathModeFast); + options->setFastMathEnabled(math_mode == MTL::MathModeFast); } } else { options->setFastMathEnabled(false); @@ -770,13 +771,13 @@ NS::SharedPtr Device::get_kernel_( MTL::Library* Device::get_library( const std::string& name, const std::function& builder) { - return get_library(name, builder, std::nullopt); + return get_library(name, builder, {}); } MTL::Library* Device::get_library( const std::string& name, const std::function& builder, - std::optional math_mode) { + CompileOptions compile_options) { { std::shared_lock rlock(library_mtx_); if (auto it = library_map_.find(name); it != library_map_.end()) { @@ -789,7 +790,7 @@ MTL::Library* Device::get_library( return it->second.get(); } - auto mtl_lib = build_library_(builder(), math_mode); + auto mtl_lib = build_library_(builder(), compile_options); library_map_.insert({name, mtl_lib}); return mtl_lib.get(); } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index d9cb910c4e..e04cb24228 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -20,6 +20,10 @@ namespace mlx::core::metal { using MTLFCList = std::vector>; +struct CompileOptions { + std::optional math_mode = std::nullopt; +}; + class Device; class EventImpl; @@ -174,7 +178,7 @@ class MLX_API Device { MTL::Library* get_library( const std::string& name, const std::function& builder, - std::optional math_mode); + CompileOptions compile_options); void clear_library(const std::string& name); @@ -198,7 +202,7 @@ class MLX_API Device { private: NS::SharedPtr build_library_( const std::string& source_string, - std::optional math_mode = std::nullopt); + CompileOptions compile_options = {}); NS::SharedPtr get_function_( const std::string& name, diff --git a/mlx/fast.h b/mlx/fast.h index 58b4ad8a7d..2e70ba8ac5 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -63,6 +63,10 @@ enum class MetalKernelMathMode { Fast = 2, }; +struct CompileOptions { + MetalKernelMathMode math_mode = MetalKernelMathMode::Safe; +}; + using CustomKernelFunction = std::function( const std::vector&, const std::vector&, @@ -82,7 +86,7 @@ MLX_API CustomKernelFunction metal_kernel( const std::string& header = "", bool ensure_row_contiguous = true, bool atomic_outputs = false, - MetalKernelMathMode math_mode = MetalKernelMathMode::Safe); + CompileOptions compile_options = {}); MLX_API CustomKernelFunction cuda_kernel( const std::string& name, diff --git a/python/src/fast.cpp b/python/src/fast.cpp index fc8af5a9da..b24732fc06 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -88,6 +88,34 @@ mx::fast::MetalKernelMathMode parse_metal_math_mode( "[metal_kernel] Expected math_mode to be 'safe', 'relaxed', or 'fast'."); } +mx::fast::CompileOptions parse_compile_options( + const nb::object& compile_options) { + mx::fast::CompileOptions options; + + if (compile_options.is_none()) { + return options; + } + + if (!nb::isinstance(compile_options)) { + throw std::invalid_argument( + "[metal_kernel] Expected `compile_options` to be a dict."); + } + + nb::dict dict = nb::cast(compile_options); + for (auto [key, value] : dict) { + auto key_str = nb::cast(key); + if (key_str == "math_mode") { + options.math_mode = + parse_metal_math_mode(nb::cast(value)); + } else { + std::ostringstream msg; + msg << "[metal_kernel] Unknown compile option `" << key_str << "`."; + throw std::invalid_argument(msg.str()); + } + } + return options; +} + } // namespace void init_fast(nb::module_& parent_module) { @@ -318,7 +346,7 @@ void init_fast(nb::module_& parent_module) { const std::string& header, bool ensure_row_contiguous, bool atomic_outputs, - const std::string& math_mode) { + const nb::object& compile_options) { auto kernel = mx::fast::metal_kernel( name, input_names, @@ -327,7 +355,7 @@ void init_fast(nb::module_& parent_module) { header, ensure_row_contiguous, atomic_outputs, - parse_metal_math_mode(math_mode)); + parse_compile_options(compile_options)); return nb::cpp_function( PyCustomKernelFunction(std::move(kernel), "[metal_kernel]"), nb::kw_only(), @@ -371,7 +399,7 @@ void init_fast(nb::module_& parent_module) { "header"_a = "", "ensure_row_contiguous"_a = true, "atomic_outputs"_a = false, - "math_mode"_a = "safe", + "compile_options"_a = nb::none(), R"pbdoc( A jit-compiled custom Metal kernel defined from a source string. @@ -392,10 +420,12 @@ void init_fast(nb::module_& parent_module) { before the kernel runs. Default: ``True``. atomic_outputs (bool): Whether to use atomic outputs in the function signature e.g. ``device atomic``. Default: ``False``. - math_mode (str): The Metal math mode to compile the kernel with: - ``"safe"``, ``"relaxed"``, or ``"fast"``. ``"safe"`` preserves - IEEE behavior for special values such as ``exp(-inf) == 0``. - Default: ``"safe"``. + compile_options (dict, optional): Options to compile the Metal kernel + with. Supported options: + + * ``"math_mode"``: The Metal math mode: ``"safe"``, ``"relaxed"``, + or ``"fast"``. ``"safe"`` preserves IEEE behavior for special + values such as ``exp(-inf) == 0``. Default: ``"safe"``. Returns: Callable ``metal_kernel``. diff --git a/python/tests/test_export_import.py b/python/tests/test_export_import.py index 971f8fc264..c45dd4a606 100644 --- a/python/tests/test_export_import.py +++ b/python/tests/test_export_import.py @@ -648,7 +648,7 @@ def test_export_custom_metal_kernel_with_math_mode(self): input_names=["a"], output_names=["out"], source=source, - math_mode="safe", + compile_options={"math_mode": "safe"}, ) def call(a): diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 69e07552e9..e2baf93f95 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -1033,7 +1033,17 @@ def test_custom_metal_kernel_invalid_math_mode(self): input_names=["inp"], output_names=["out"], source="out[0] = inp[0];", - math_mode="precise", + compile_options={"math_mode": "precise"}, + ) + + def test_custom_metal_kernel_invalid_compile_options(self): + with self.assertRaises(ValueError): + mx.fast.metal_kernel( + name="invalid_compile_options", + input_names=["inp"], + output_names=["out"], + source="out[0] = inp[0];", + compile_options={"unknown": "value"}, ) @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") @@ -1046,7 +1056,7 @@ def test_custom_metal_kernel_safe_math_mode(self): uint elem = thread_position_in_grid.x; out[elem] = metal::exp(inp[elem]); """, - math_mode="safe", + compile_options={"math_mode": "safe"}, ) a = mx.array([-float("inf"), 0.0], dtype=mx.float32) out = kernel( From 7477e99343822b3dc245a7df86547e48c9dd8e37 Mon Sep 17 00:00:00 2001 From: Shubh Date: Sat, 20 Jun 2026 16:33:08 +0530 Subject: [PATCH 3/4] Apply clang-format to fast bindings --- python/src/fast.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/src/fast.cpp b/python/src/fast.cpp index b24732fc06..af764ad33f 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -105,8 +105,7 @@ mx::fast::CompileOptions parse_compile_options( for (auto [key, value] : dict) { auto key_str = nb::cast(key); if (key_str == "math_mode") { - options.math_mode = - parse_metal_math_mode(nb::cast(value)); + options.math_mode = parse_metal_math_mode(nb::cast(value)); } else { std::ostringstream msg; msg << "[metal_kernel] Unknown compile option `" << key_str << "`."; From 85a818174990c33dad3451debe87ec33c2290c66 Mon Sep 17 00:00:00 2001 From: Shubh Date: Sun, 28 Jun 2026 06:25:40 +0530 Subject: [PATCH 4/4] Test math mode via __FAST_MATH__ instead of exp(-inf) exp(-inf) agrees across math modes, so the old test passed regardless of the selected mode. Branch on the Metal compiler's __FAST_MATH__ macro instead, which is only defined under fast math, so the test fails if the math mode is not forwarded to the compiler. Reuse one kernel name across modes to exercise the library-cache rebuild path. --- python/tests/test_fast.py | 62 ++++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index e2baf93f95..3a28a2c484 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -1047,27 +1047,47 @@ def test_custom_metal_kernel_invalid_compile_options(self): ) @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") - def test_custom_metal_kernel_safe_math_mode(self): - kernel = mx.fast.metal_kernel( - name="safe_math_mode", - input_names=["inp"], - output_names=["out"], - source=""" - uint elem = thread_position_in_grid.x; - out[elem] = metal::exp(inp[elem]); - """, - compile_options={"math_mode": "safe"}, - ) - a = mx.array([-float("inf"), 0.0], dtype=mx.float32) - out = kernel( - inputs=[a], - grid=(a.size, 1, 1), - threadgroup=(a.size, 1, 1), - output_shapes=[a.shape], - output_dtypes=[a.dtype], - stream=mx.gpu, - )[0] - self.assertTrue(mx.array_equal(out, mx.array([0.0, 1.0]))) + def test_custom_metal_kernel_math_mode(self): + # Numerical special cases such as exp(-inf) can agree between math + # modes, so they don't reliably detect whether the mode was applied. + # Branch on the compiler's __FAST_MATH__ macro instead: it is defined + # only when fast math is enabled, so the test fails if the selected + # math mode is not forwarded to the Metal compiler. + source = """ + uint elem = thread_position_in_grid.x; + #if defined(__FAST_MATH__) && __FAST_MATH__ + out[elem] = 1.0f; + #else + out[elem] = 0.0f; + #endif + """ + + a = mx.zeros((4,), dtype=mx.float32) + expected = { + "safe": mx.zeros_like(a), + "fast": mx.ones_like(a), + } + + # Reuse the same kernel name across modes so the library cache is forced + # to rebuild when the math mode changes, guarding against a stale build + # being returned for a different mode. + for mode, expected_out in expected.items(): + kernel = mx.fast.metal_kernel( + name="math_mode", + input_names=["inp"], + output_names=["out"], + source=source, + compile_options={"math_mode": mode}, + ) + out = kernel( + inputs=[a], + grid=(a.size, 1, 1), + threadgroup=(a.size, 1, 1), + output_shapes=[a.shape], + output_dtypes=[a.dtype], + stream=mx.gpu, + )[0] + self.assertTrue(mx.array_equal(out, expected_out)) @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") def test_custom_kernel_mixed_dtypes(self):