Describe the bug
Several holomorphic unary ops return the wrong gradient for complex inputs because their vjp delegates to their jvp. For a holomorphic f, the jvp is f'(z) * tangent, but the vjp must be cotangent * conj(f'(z)) (the convention used by e.g. mx.exp/mx.log); delegating drops the conjugate.
Affected: square, sin, sinh, cosh, tan, tanh, log1p.
To Reproduce
import mlx.core as mx
z = mx.array(2 + 3j, mx.complex64)
print(mx.grad(lambda z: mx.real(mx.square(z)))(z)) # 4+6j (wrong)
# d Re(z^2) = (2*Re(z), -2*Im(z)) -> 4 - 6j = 2*conj(z)
mx.grad returns 2*z instead of 2*conj(z); the same conjugate is dropped for the other ops.
Expected behavior
The vjp equals cotangent * conj(f'(z)), consistent with mx.exp/mx.log and with finite differences.
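The finite-difference claim can be checked independently of MLX. The following is a minimal sketch in plain Python (standard-library cmath only, not MLX code); the helper name fd_grad, the step size h, and the table OPS are assumptions made for illustration. It estimates the gradient of Re(f(z)) with respect to (Re z, Im z) by central differences, packs it as gx + 1j*gy, and compares it with conj(f'(z)) for each affected op, taking the cotangent to be 1.

```python
import cmath


def fd_grad(f, z, h=1e-6):
    """Central-difference gradient of Re(f(z)) w.r.t. (Re z, Im z), packed as gx + 1j*gy."""
    gx = (f(z + h).real - f(z - h).real) / (2 * h)
    gy = (f(z + 1j * h).real - f(z - 1j * h).real) / (2 * h)
    return complex(gx, gy)


# Hypothetical table of (f, f') pairs for the ops listed as affected.
# log1p is written as log(1 + z) because cmath has no log1p.
OPS = {
    "square": (lambda z: z * z, lambda z: 2 * z),
    "sin": (cmath.sin, cmath.cos),
    "sinh": (cmath.sinh, cmath.cosh),
    "cosh": (cmath.cosh, cmath.sinh),
    "tan": (cmath.tan, lambda z: 1 / cmath.cos(z) ** 2),
    "tanh": (cmath.tanh, lambda z: 1 / cmath.cosh(z) ** 2),
    "log1p": (lambda z: cmath.log(1 + z), lambda z: 1 / (1 + z)),
}

z = complex(2, 3)
errors = {}
for name, (f, df) in OPS.items():
    expected = df(z).conjugate()  # cotangent (= 1) * conj(f'(z))
    errors[name] = abs(fd_grad(f, z) - expected)
    print(f"{name}: |fd - conj(f'(z))| = {errors[name]:.2e}")
```

For square at z = 2 + 3j, the finite-difference estimate agrees with 2*conj(z) = 4 - 6j rather than 2*z = 4 + 6j, which matches the expected value in the reproduction above.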
Desktop
- OS: macOS
- Version: main (0.32.0.dev)