
[BUG] Complex vjps are wrong for square/sin/sinh/cosh/tan/tanh/log1p (missing conjugate) #3765

Description

@obchain

Describe the bug

Several holomorphic unary ops return the wrong gradient for complex inputs because their vjp delegates to their jvp. For a holomorphic f the jvp is f'(z) * t, but the vjp must be cotangent * conj(f'(z)), the convention already followed by, e.g., mx.exp and mx.log. Delegating to the jvp drops the conjugate.
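A quick way to see why the conjugate is needed, sketched in plain Python with the stdlib `cmath` module (not MLX): take the real loss L(z) = Re(conj(c) * f(z)), whose cotangent at w = f(z) is exactly c, and differentiate it numerically with respect to Re z and Im z, packing the result as dL/dRe z + i*dL/dIm z (the same reading as the `(4, -6) -> 4-6j` comment in the repro). With exp, which follows the conjugated convention, the numeric gradient matches c * conj(f'(z)) and not c * f'(z). The helper name `grad_real_loss`, the cotangent value and the step size are illustrative assumptions.

```python
import cmath

def grad_real_loss(f, z, c, h=1e-6):
    """Central-difference gradient dL/dx + i*dL/dy of L(z) = Re(conj(c) * f(z)).

    For this L, dL/dRe(w) + i*dL/dIm(w) == c at w = f(z), so the result is the
    vjp of f at z for cotangent c (hypothetical helper, for illustration only).
    """
    L = lambda u: (c.conjugate() * f(u)).real
    dx = (L(z + h) - L(z - h)) / (2 * h)
    dy = (L(z + 1j * h) - L(z - 1j * h)) / (2 * h)
    return complex(dx, dy)

z = 2 + 3j
c = 0.5 - 1.5j                  # arbitrary complex cotangent
fprime = cmath.exp(z)           # d/dz exp(z) = exp(z)

numeric = grad_real_loss(cmath.exp, z, c)
with_conj = c * fprime.conjugate()   # cotangent * conj(f'(z))
without_conj = c * fprime            # what reusing the jvp rule yields

print(abs(numeric - with_conj) / abs(numeric))     # close to 0 (finite-difference noise)
print(abs(numeric - without_conj) / abs(numeric))  # about 0.28
```

The second relative error is 2*|sin(Im z)|, so it vanishes only where f'(z) happens to be real, which is why the bug is invisible for real inputs.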

Affected: square, sin, sinh, cosh, tan, tanh, log1p.

To Reproduce

```python
import mlx.core as mx

z = mx.array(2 + 3j, mx.complex64)
print(mx.grad(lambda z: mx.real(mx.square(z)))(z))  # 4+6j  (wrong)
# Re(z^2) = x^2 - y^2, so (d/dx, d/dy) = (2x, -2y) = (4, -6) -> 4-6j = 2*conj(z)
```

mx.grad returns 2*z instead of 2*conj(z); the same conjugate is dropped for the other ops.
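The repro value can be reconstructed by hand. For L = Re(w) with w = z^2, the cotangent arriving at w is 1 and f'(z) = 2z. The two hypothetical helpers below (`vjp_delegating_to_jvp`, `vjp_conjugated`) are plain-Python sketches, not MLX's actual backward code; they contrast the rule the affected ops currently apply with the conjugated rule this issue asks for.

```python
def vjp_delegating_to_jvp(fprime, cotangent):
    # jvp rule f'(z) * t reused with t := cotangent (the reported bug)
    return fprime * cotangent

def vjp_conjugated(fprime, cotangent):
    # expected rule: cotangent * conj(f'(z))
    return cotangent * fprime.conjugate()

z = 2 + 3j
fprime = 2 * z      # d/dz z^2
cot = 1 + 0j        # cotangent of Re(w) with respect to w

print(vjp_delegating_to_jvp(fprime, cot))  # (4+6j), the value reported above
print(vjp_conjugated(fprime, cot))         # (4-6j), i.e. 2*conj(z)
```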

Expected behavior

The vjp equals cotangent * conj(f'(z)), consistent with mx.exp/mx.log and with finite differences.
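A finite-difference check in the spirit of the expected behavior, sketched in plain Python with `cmath` (not MLX): for each affected op plus the reference ops exp and log, it compares the numeric gradient of Re(f(z)), whose cotangent is 1, against conj(f'(z)) and against the unconjugated f'(z). The derivative table, the test point (chosen away from the log branch cut), the step size and the tolerance are assumptions for illustration; log1p is written as log(1 + z) because `cmath` has no log1p.

```python
import cmath

# (f, f') pairs for the affected ops, with exp/log as reference ops.
OPS = {
    "square": (lambda z: z * z,            lambda z: 2 * z),
    "sin":    (cmath.sin,                  cmath.cos),
    "sinh":   (cmath.sinh,                 cmath.cosh),
    "cosh":   (cmath.cosh,                 cmath.sinh),
    "tan":    (cmath.tan,                  lambda z: 1 / cmath.cos(z) ** 2),
    "tanh":   (cmath.tanh,                 lambda z: 1 / cmath.cosh(z) ** 2),
    "log1p":  (lambda z: cmath.log(1 + z), lambda z: 1 / (1 + z)),
    "exp":    (cmath.exp,                  cmath.exp),
    "log":    (cmath.log,                  lambda z: 1 / z),
}

def fd_grad_of_real(f, z, h=1e-6):
    # Central-difference gradient of Re(f(z)) w.r.t. (Re z, Im z), packed as dx + i*dy.
    g = lambda w: f(w).real
    dx = (g(z + h) - g(z - h)) / (2 * h)
    dy = (g(z + 1j * h) - g(z - 1j * h)) / (2 * h)
    return complex(dx, dy)

def check(z=0.3 + 0.4j, rtol=1e-5):
    """Return {op: (matches conj(f'(z)), matches unconjugated f'(z))} at z."""
    results = {}
    for name, (f, fp) in OPS.items():
        numeric = fd_grad_of_real(f, z)
        tol = rtol * max(1.0, abs(fp(z)))
        with_conj = abs(numeric - fp(z).conjugate()) <= tol
        without_conj = abs(numeric - fp(z)) <= tol
        results[name] = (with_conj, without_conj)
    return results

for name, (ok, unconj) in check().items():
    print(f"{name:7s} conj(f'): {ok}   f' (no conj): {unconj}")
```

Every op, including the affected ones, matches conj(f'(z)) and none matches f'(z), so the fix amounts to applying the same conjugation that exp and log already use.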

Desktop

  • OS: macOS
  • Version: main (0.32.0.dev)
