
Feature Request: Least Squares Solver #3773

Description

@danlee2002

Motivation

In scientific computing, one often has to solve an overdetermined or underdetermined linear system. Libraries such as JAX and NumPy provide a native least-squares solver (jax.numpy.linalg.lstsq and np.linalg.lstsq, respectively). MLX does not currently support this natively, and I think it would be a good addition to the library.

Roadmap

If MLX is interested in supporting this feature, one potential roadmap is:

  1. Initial implementation: As a proof of concept, build an initial implementation on top of the existing linalg::qr or linalg::svd primitives to establish the API. Once that exists, we can focus on optimizing performance.
  2. CPU optimization: Switch the CPU backend to LAPACK gelsd. The current CPU SVD uses the LAPACK routine gesdd with jobz='A', which materializes the full $U$ and $V$ matrices. This wastes memory for strongly overdetermined or underdetermined systems, i.e. when the coefficient matrix is a thin rectangle: for an $m \times n$ matrix with $m \gg n$, the full $U$ is $m \times m$ even though only its first $n$ columns are needed. Using gelsd lets us solve the system without fully materializing $U$ and $V$.
  3. Optional: Another potential optimization is to default to a QR-decomposition-based method when the coefficient matrix is known or assumed to be (near) full rank.
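As a rough illustration of steps 1 and 3, here is a NumPy sketch (not MLX code) of the two solver strategies. The names lstsq_svd and lstsq_qr are hypothetical; an MLX port would presumably call its own linalg.svd / linalg.qr primitives instead, which is an assumption about how the prototype would look. The SVD path returns the minimum-norm solution and zeroes singular values below rcond times the largest one, mirroring the default cutoff documented for np.linalg.lstsq. It uses a thin SVD (full_matrices=False), so $U$ is $m \times \min(m, n)$ rather than $m \times m$.

```python
import numpy as np


def lstsq_svd(a, b, rcond=None):
    """Hypothetical helper: minimum-norm least-squares solve of a @ x ~= b via a thin SVD."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    m, n = a.shape
    if rcond is None:
        # Same default cutoff np.linalg.lstsq documents: eps * max(M, N).
        rcond = np.finfo(a.dtype).eps * max(m, n)
    # Thin SVD: u is (m, k), vh is (k, n) with k = min(m, n).
    u, s, vh = np.linalg.svd(a, full_matrices=False)
    # Treat singular values below rcond * s_max as zero (rank truncation).
    cutoff = rcond * s.max() if s.size else 0.0
    mask = s > cutoff
    s_inv = np.where(mask, 1.0 / np.where(mask, s, 1.0), 0.0)
    # x = V diag(1/s) U^T b, applied right to left instead of forming pinv(a).
    utb = u.T @ b
    scaled = s_inv * utb if b.ndim == 1 else s_inv[:, None] * utb
    return vh.T @ scaled, int(mask.sum())


def lstsq_qr(a, b):
    """Hypothetical helper: least squares assuming a is tall (m >= n) with full column rank."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    m, n = a.shape
    if m < n:
        raise ValueError("this QR sketch assumes m >= n")
    # Reduced QR: q is (m, n) with orthonormal columns, r is (n, n) upper triangular.
    q, r = np.linalg.qr(a, mode="reduced")
    # Solve r x = q^T b. A real implementation would use a triangular solve here.
    return np.linalg.solve(r, q.T @ b)


rng = np.random.default_rng(0)
a = rng.standard_normal((50, 3))
b = rng.standard_normal(50)
x_svd, rank = lstsq_svd(a, b)
x_qr = lstsq_qr(a, b)
print(rank, np.allclose(x_svd, x_qr))
```

The QR path is cheaper, but as sketched it has no rank detection and fails or loses accuracy when a is rank-deficient, which is why it fits best as the optional fast path from step 3 rather than the default.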

API Parity

We should aim to emulate the NumPy API as closely as possible.
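For reference, np.linalg.lstsq returns a 4-tuple (x, residuals, rank, s), and residuals is an empty array unless the system is full rank and strictly overdetermined (rank == N and M > N); for a 1-D b it has shape (1,), and for a 2-D b of shape (M, K) it has shape (K,). The hypothetical lstsq_sketch below shows one way a prototype could reproduce that contract on top of a thin SVD. It is plain NumPy used for illustration, not an existing MLX function.

```python
import numpy as np


def lstsq_sketch(a, b, rcond=None):
    """Hypothetical NumPy-compatible wrapper returning (x, residuals, rank, s)."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    m, n = a.shape
    if rcond is None:
        # Default cutoff documented for np.linalg.lstsq: eps * max(M, N).
        rcond = np.finfo(a.dtype).eps * max(m, n)
    u, s, vh = np.linalg.svd(a, full_matrices=False)
    cutoff = rcond * s.max() if s.size else 0.0
    mask = s > cutoff
    rank = int(mask.sum())
    # Invert only the singular values kept by the cutoff.
    s_inv = np.where(mask, 1.0 / np.where(mask, s, 1.0), 0.0)
    utb = u.T @ b
    x = vh.T @ (s_inv * utb if b.ndim == 1 else s_inv[:, None] * utb)
    # NumPy only reports residuals for full-rank, strictly overdetermined systems.
    if rank == n and m > n:
        r = b - a @ x
        residuals = np.atleast_1d(np.sum(r * r, axis=0))
    else:
        residuals = np.empty((0,), dtype=a.dtype)
    return x, residuals, rank, s


rng = np.random.default_rng(1)
a = rng.standard_normal((20, 4))
b = rng.standard_normal(20)
x, residuals, rank, s = lstsq_sketch(a, b)
print(residuals.shape, rank, s.shape)
```

Matching these edge cases (empty residuals, shape of residuals for 1-D vs 2-D b, singular values always returned) is likely where parity bugs would show up, so a test suite comparing against np.linalg.lstsq on tall, wide, rank-deficient, and batched right-hand sides seems worthwhile.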
