[Feature]: Support native Python numeric types in the function returned by mlx.core.value_and_grad #3774

Description

@aaishwarymishra

Title

value_and_grad fails when function arguments contain native Python numbers

Problem

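In MLX, mlx.core.value_and_grad expects the arguments it differentiates (the first positional argument by default, or those selected with argnums) to be mlx.core.array instances or nested lists, tuples, and dicts of arrays. Passing native Python numeric types (float, int) instead raises a ValueError.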

Reproduction Code

import mlx.core as mx

# Simple function
def f(x):
    return x * 2

# Wrap with value_and_grad
fn = mx.value_and_grad(f)

# Passing a native Python float raises ValueError
fn(3.0)

Traceback

ValueError: [tree_flatten] The argument should contain only arrays

Comparison with JAX

In JAX, jax.value_and_grad accepts native Python numbers and implicitly converts them to weakly typed arrays:

import jax

fn = jax.value_and_grad(lambda x: x * 2)
print(fn(3.0))
# Output: (Array(6., dtype=float32, weak_type=True), Array(2., dtype=float32, weak_type=True))

Proposed Solution

Modify mlx.core.value_and_grad (and the other transform entry points, such as mlx.core.grad) so that native Python numbers are automatically converted to mlx.core.array before tree_flatten runs.
