Title
value_and_grad fails when function arguments contain native Python numbers
Description
Problem
In MLX, `mlx.core.value_and_grad` expects every input, including the leaves of nested lists, tuples, and dicts, to be an `mlx.core.array`. Passing native Python numeric types (`float`, `int`) raises a `ValueError`.
Reproduction Code
```python
import mlx.core as mx

# Simple function
def f(x):
    return x * 2

# Wrap with value_and_grad
fn = mx.value_and_grad(f)

# Pass a native Python float
fn(3.0)
```
Traceback
```
ValueError: [tree_flatten] The argument should contain only arrays
```
Comparison with JAX
In JAX, `jax.value_and_grad` accepts native Python floats, implicitly converting them to weakly typed arrays:
```python
import jax

fn = jax.value_and_grad(lambda x: x * 2)
print(fn(3.0))
# Output: (Array(6., dtype=float32, weak_type=True), Array(2., dtype=float32, weak_type=True))
```
Proposed Solution
Modify `mlx.core.value_and_grad` (and the other transform API entry points) to automatically convert native Python numbers to `mlx.core.array` before calling `tree_flatten`.
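A minimal sketch of the proposed pre-processing step, written in plain Python so it does not depend on MLX internals. `coerce_numbers` and `with_number_coercion` are hypothetical helpers, not part of MLX; the conversion function is passed in (in practice `mx.array`). Leaving `bool` untouched and rebuilding namedtuples as plain tuples are assumptions of this sketch, and whether integer leaves should be promoted to a floating-point dtype before differentiation is left open.

```python
from typing import Any, Callable

def coerce_numbers(tree: Any, to_array: Callable[[Any], Any]) -> Any:
    """Return a copy of `tree` with Python int/float leaves passed through `to_array`.

    Lists, tuples, and dicts are walked recursively; any other leaf
    (strings, existing arrays, ...) is returned unchanged.
    """
    # bool is a subclass of int; this sketch deliberately leaves it as-is
    if isinstance(tree, bool):
        return tree
    if isinstance(tree, (int, float)):
        return to_array(tree)
    if isinstance(tree, list):
        return [coerce_numbers(t, to_array) for t in tree]
    if isinstance(tree, tuple):
        # namedtuples come back as plain tuples here (simplification)
        return tuple(coerce_numbers(t, to_array) for t in tree)
    if isinstance(tree, dict):
        return {k: coerce_numbers(v, to_array) for k, v in tree.items()}
    return tree

def with_number_coercion(transformed: Callable, to_array: Callable[[Any], Any]) -> Callable:
    """Wrap an already-transformed function so its arguments are coerced first."""
    def wrapped(*args, **kwargs):
        return transformed(
            *coerce_numbers(args, to_array),
            **coerce_numbers(kwargs, to_array),
        )
    return wrapped
```

With MLX installed, `with_number_coercion(mx.value_and_grad(f), mx.array)(3.0)` would then behave like `fn(mx.array(3.0))`. Built into MLX, the equivalent walk would run on the arguments before they reach `tree_flatten`, so users would not need a wrapper at all.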