Motivation
In scientific computing, one often has to solve an overdetermined or underdetermined linear system. Libraries like JAX and NumPy ship a native least-squares solver via jax.numpy.linalg.lstsq and np.linalg.lstsq. mlx doesn't support this natively, and I think it would be a good addition to the library.
Roadmap
If mlx is interested in supporting this feature, one possible roadmap is:
- Initial Implementation: As a proof of concept, build a first version on the existing linalg::qr or linalg::svd primitives to establish the API. Once that is in place, we can focus on optimizing performance.
- CPU Optimization: Move the CPU backend to LAPACK gelsd. The current CPU SVD uses the LAPACK routine gesdd with jobz='A', which materializes the full $U$ and $V$ matrices. This wastes memory for heavily overdetermined or underdetermined systems, i.e. when the coefficient matrix is a thin rectangle. Using gelsd lets us solve the system without fully materializing $U$ and $V$.
- Optional: As a further optimization, default to QR-decomposition-based methods when the coefficient matrix is known or assumed to be full rank or close to it.
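To make the first and third steps concrete, here is a minimal sketch of what an SVD-based and a QR-based solver compute. It uses NumPy as a stand-in, since the mlx API does not exist yet; the function names lstsq_via_svd and lstsq_via_qr are hypothetical, and the rcond default is an assumption chosen to mirror np.linalg.lstsq.

```python
import numpy as np


def lstsq_via_svd(a, b, rcond=None):
    """Minimum-norm least-squares solution of a @ x ~ b (hypothetical helper)."""
    m, n = a.shape
    # Thin SVD: u is (m, k), vt is (k, n) with k = min(m, n),
    # so the full square U and V are never formed.
    u, s, vt = np.linalg.svd(a, full_matrices=False)
    if rcond is None:
        # Assumption: same default cutoff as np.linalg.lstsq(rcond=None).
        rcond = np.finfo(s.dtype).eps * max(m, n)
    cutoff = rcond * s.max() if s.size else 0.0
    keep = s > cutoff
    rank = int(keep.sum())
    # Invert only the singular values above the cutoff.
    s_inv = np.zeros_like(s)
    s_inv[keep] = 1.0 / s[keep]
    # x = V @ diag(1/s) @ U^H @ b, written to work for 1-D and 2-D b.
    utb = u.conj().T @ b
    x = vt.conj().T @ (s_inv * utb.T).T
    return x, rank, s


def lstsq_via_qr(a, b):
    """QR-based solve; only valid for a tall matrix with full column rank."""
    q, r = np.linalg.qr(a)  # reduced QR: q is (m, n), r is (n, n)
    return np.linalg.solve(r, q.conj().T @ b)


rng = np.random.default_rng(0)
a_tall = rng.standard_normal((50, 3))   # overdetermined
b_tall = rng.standard_normal(50)
a_wide = rng.standard_normal((3, 50))   # underdetermined
b_wide = rng.standard_normal(3)

x_svd, rank_svd, _ = lstsq_via_svd(a_tall, b_tall)
x_qr = lstsq_via_qr(a_tall, b_tall)
x_min_norm, _, _ = lstsq_via_svd(a_wide, b_wide)
```

The SVD path handles rank-deficient and underdetermined inputs (it returns the minimum-norm solution), while the QR path is cheaper but assumes full column rank, which is why it is listed as an optional fast path.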
API Parity
We should aim to match the NumPy API (np.linalg.lstsq) as closely as possible.
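For reference, this is the NumPy call shape that an mlx version would presumably mirror: np.linalg.lstsq takes the coefficient matrix, the right-hand side and an rcond cutoff, and returns a 4-tuple. The line-fit data below is only illustrative.

```python
import numpy as np

# Fit y = m * x + c through four points.
x = np.array([0.0, 1.0, 2.0, 3.0])
y = np.array([-1.0, 0.2, 0.9, 2.1])
a = np.stack([x, np.ones_like(x)], axis=1)  # columns: x, 1

# Returns the solution, the sum of squared residuals,
# the effective rank of a, and the singular values of a.
solution, residuals, rank, singular_values = np.linalg.lstsq(a, y, rcond=None)
m, c = solution
```

Whether mlx should also return residuals, rank and singular values, or only the solution, is an open design question for the proposal.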