
Conversation

Contributor

@yukiu00 yukiu00 commented Feb 4, 2026

Add mHC fused kernels + LigerMHC API + benchmarks

Reference Issue

Summary

This PR adds an opt-in, paper-aligned mHC implementation to Liger-Kernel: fused Triton kernels, a LigerMHC module, functional APIs, tests, and benchmarks.
No existing default patching behavior is changed.

Background (Paper)

mHC: Manifold-Constrained Hyper-Connections (arXiv:2512.24880v2)
https://arxiv.org/abs/2512.24880

Key idea: constrain H_res via Sinkhorn-Knopp onto the doubly-stochastic set (Birkhoff polytope), restoring identity-mapping stability while preserving multi-stream residual benefits. The paper also emphasizes fused kernels + mixed precision + recompute (Sec. 4.3.1, Eq.(14)–(19)).
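
As a rough illustration of the projection only (a minimal PyTorch sketch, not the fused kernel; the function name and the n_iters/eps arguments are assumptions):

import torch

def sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 5, eps: float = 1e-6) -> torch.Tensor:
    """Project an (n, n) score matrix toward the doubly-stochastic set
    (Birkhoff polytope) by alternating row and column normalization."""
    m = torch.exp(logits)                             # strictly positive entries
    for _ in range(n_iters):
        m = m / (m.sum(dim=-1, keepdim=True) + eps)   # rows sum to ~1
        m = m / (m.sum(dim=-2, keepdim=True) + eps)   # columns sum to ~1
    return m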

What’s included

  • Triton mHC kernels (coeffs / Sinkhorn / apply; fwd + bwd).
  • API: LigerMHC + liger_mhc_* functional APIs (Liger naming).
  • allow_fp32 opt-in (default remains BF16/FP16 mixed precision; intended for specific/debug use cases).
  • Benchmarks: benchmark/scripts/benchmark_mhc_lm.py
  • Tests: ops correctness, transformer-level tests, convergence test.

Benchmarks (RTX 3090, BF16, B=2, T=256, n(HC)=4, layers=2, heads=8, vocab=4096)

see #1066 (comment)

Out of scope

  • Block-wise recomputation over groups of $L_r$ layers (paper Sec. 4.3.2) is out of scope for Liger-Kernel. Users can achieve the memory savings described in the paper by applying torch.utils.checkpoint to groups of mHC layers in their training loop (see the sketch after this list).
  • DualPipe schedule / distributed pipeline optimization (paper Sec. 4.3.3)
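
For the first bullet above, a minimal sketch of what that user-side checkpointing might look like (the blocks grouping and segment count are illustrative assumptions, not part of this PR):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Stand-in for a stack of mHC-wrapped transformer layers.
blocks = nn.Sequential(*[nn.Linear(256, 256) for _ in range(8)])

def forward_with_recompute(x: torch.Tensor) -> torch.Tensor:
    # Recompute activations group by group during backward,
    # trading extra compute for lower peak memory (4 segments of 2 layers here).
    return checkpoint_sequential(blocks, 4, x, use_reentrant=False)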

- Add mHC kernels and APIs
- Provide reference implementations for tests and benchmarks
- Add/adjust tests, tolerances, and benchmarks
- Document memory trade-offs, usage notes, and options
@yukiu00 yukiu00 marked this pull request as ready for review February 4, 2026 09:50
@yukiu00 yukiu00 changed the title Add mHC fused kernels + LigerMHC API + benchmarks (paper-aligned) Add mHC fused kernels + LigerMHC API + benchmarks Feb 4, 2026
Collaborator

@Tcc0403 Tcc0403 left a comment

Great work🚀 Looking forward to having mHC in Liger Kernel!

I've only skimmed through the code, so most comments are about code structure. I'll do a thorough review of the implementation tomorrow. Thanks for your patience!

Comment on lines 13 to 23
def _time_loop(fn, iters=200, warmup=50) -> float:
    torch.cuda.synchronize()
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    t1 = time.time()
    return (t1 - t0) * 1e3 / iters
Collaborator

We prefer using triton.testing.do_bench() to benchmark our kernels. Refer to other benchmark scripts for concrete examples.

Contributor Author

Acknowledged. I switched benchmark_mhc.py to use triton.testing.do_bench() with QUANTILES for all forward/backward measurements.
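
For reference, a minimal sketch of that measurement pattern (the QUANTILES values and helper name are illustrative):

import triton

QUANTILES = [0.5, 0.2, 0.8]

def bench_forward(fn):
    # Returns median, 20th-, and 80th-percentile latency in milliseconds.
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=QUANTILES, rep=100)
    return ms, min_ms, max_ms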

Comment on lines 26 to 30
def _peak_bytes(fn) -> int:
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return int(torch.cuda.max_memory_allocated())
Collaborator

def _test_memory(

We have utils._test_memory() for checking memory.

Contributor Author

Done. Memory measurement now uses utils._test_memory() consistently.


import torch

from utils import mhc_coeffs_ref
Collaborator

Let's keep only one reference, for a single source of truth. Instead of writing another one in benchmark/scripts/utils.py, reuse the one from test.transformers.test_mhc to avoid inconsistent references in future updates.

Add the root directory to your path:

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

Then you can access the test directory in your function:

from test.transformers.test_dyt import TorchDyT

Contributor Author

I moved mhc_sinkhorn_ref / mhc_coeffs_ref into test_mhc.py and updated benchmarks to import from there. I also added the repo root to sys.path in the benchmark scripts.

Comment on lines 8 to 15
from liger_kernel.triton.mhc import mhc_mm_norm_bwd
from liger_kernel.triton.mhc import mhc_mm_norm_fwd
from liger_kernel.triton.mhc import mhc_post_res_bwd
from liger_kernel.triton.mhc import mhc_post_res_fwd
from liger_kernel.triton.mhc import mhc_pre_bwd
from liger_kernel.triton.mhc import mhc_pre_fwd
from liger_kernel.triton.mhc import mhc_sinkhorn_bwd
from liger_kernel.triton.mhc import mhc_split_sinkhorn_fwd
Collaborator

Let's define them in this file directly for codebase structure. Happy to discuss other approaches if you feel the file becomes too large.

Contributor Author

Done. I moved the Triton kernels into this file directly and removed the separate liger_kernel/triton/mhc.py module.

def test_mhc_mini_lm_convergence():
    set_seed(0)

    device = "cuda"
Collaborator

We use liger_kernel.utils.infer_device to get the device, since we support multiple backends, not just CUDA.
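
A minimal sketch of the device-agnostic pattern being suggested (the tensor shape is arbitrary):

import torch

from liger_kernel.utils import infer_device

device = infer_device()  # backend-appropriate device string instead of a hardcoded "cuda"
x = torch.randn(2, 256, 64, device=device)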

HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_multimodal.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_with_logits.py
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mhc_mini_lm.py

Collaborator

I'll discuss with folks and decide whether to test the mHC architecture or not.

yukiu00 and others added 4 commits February 5, 2026 10:00
- Remove backward-compatible alias functions (108 lines)
- Add docstring and comments to _post_res_default_meta
- Use Union[] instead of | for Python 3.9 compatibility
- Replace assert with ValueError for better debugging
- Add runtime warning when hc > 16
Collaborator

Perhaps we can move this part to test_mhc.py. There is no need to check whether the loss decreases; we only have to check that the outputs generated by the two models, one using the torch references and the other using Liger's mHC components, are close enough, using torch.testing.assert_close() or utils.assert_verbose_allclose().
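
As an illustration of that suggestion, a minimal sketch of such a comparison test (the liger_mhc/torch_ref_mhc callables and tolerances are placeholders, not the final test code):

import torch

def check_mhc_matches_reference(liger_mhc, torch_ref_mhc, device):
    # `liger_mhc` is the fused module, `torch_ref_mhc` the plain-PyTorch reference;
    # both are placeholders for this sketch.
    torch.manual_seed(42)
    x = torch.randn(2, 8, 4, 64, device=device, dtype=torch.bfloat16)
    torch.testing.assert_close(liger_mhc(x), torch_ref_mhc(x), atol=1e-2, rtol=1e-2)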

Collaborator

ditto

Contributor Author

Done. Applied the same restructuring to benchmark_mhc_lm.py.

Comment on lines +330 to +334
int(tmax),
float(rms_eps),
float(pre_eps),
float(sinkhorn_eps),
float(post_mult),
Collaborator

Any considerations why we have to cast them?

Contributor Author

The casts (int(tmax), float(rms_eps), etc.) convert config scalars from tensors or numpy types into plain Python types, ensuring they are not accidentally included in the autograd graph. Added a clarifying comment at L322.

Wraps a layer F: [..., C] -> [..., C] with mHC residual streams: [..., HC, C].

Args:
layer: module applied to the aggregated stream input
Collaborator

Can you add an example in the docstring? It's still unclear to me how we can wrap the existing modules with LigerMHC

Contributor Author

Done. Added an Example:: section in the docstring (L61-77) that shows how to wrap a linear layer with LigerMHC and how to use it inside a transformer block.
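
Roughly, the wrapping pattern looks like the sketch below (the import path and every constructor argument other than layer are assumptions; see the docstring example in the PR for the actual usage):

import torch
import torch.nn as nn

from liger_kernel.transformers import LigerMHC  # import path is an assumption

mlp = nn.Linear(512, 512)                 # F: [..., C] -> [..., C]
block = LigerMHC(layer=mlp, hc=4)         # `hc` (number of residual streams) is an assumed arg name

x = torch.randn(2, 128, 4, 512)           # [..., HC, C] with HC=4, C=512
y = block(x)                              # output keeps the multi-stream shape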

- Remove all `assert x.is_cuda` checks for multi-backend support
- Eliminate `_as_scalar()` GPU sync by passing alpha params as pointers
  to Triton kernels (use `tl.load()` instead of `.item()`)
- Merge duplicate TC/TF32 kernel pairs into unified kernels with
  `CAST_FP32: tl.constexpr` compile-time flag (~180 lines removed)
- Replace `view(N, ...)` with `view(-1, ...)` across autograd Functions
- Move functional APIs from `ops.mhc` to `transformers.functional`
- Improve `LigerMHC` docstring with architecture, args, and examples
- Rewrite `benchmark_mhc.py` to standard framework (run_benchmarks)
- Use `infer_device()` in convergence test instead of hardcoded "cuda"
… skipif

- benchmark_mhc.py: pass all config params via extra_benchmark_configs
  following the DPO benchmark pattern
- test_mhc_mini_lm.py: remove redundant torch.cuda.is_available() skipif
  (supports_bfloat16() already covers this case)
…st_mhc.py for improved organization and maintainability of convergence tests.
- Remove CUDA-only skipif decorators from tests for multi-backend support
- Simplify _flatten_tokens to return x_shape, remove _unflatten_tokens helper
- Remove dead Makefile reference to deleted test_mhc_mini_lm.py
…istency and clarity. Update function signatures in `_post_res_default_meta` and `_post_res_meta` to use `Tuple` from the `typing` module.
- Remove no-op mask=True from Sinkhorn backward kernels
- Drop unused rms_eps/pre_eps from ctx.meta in coeffs backward
- Remove redundant .contiguous() calls inside @ensure_contiguous methods
- Simplify grad_x reshape to use x_shape directly
- Simplify device detection in LigerMHC to try/except pattern
- Replace torch.allclose with assert_verbose_allclose in tests
- Standardize seed to set_seed(42) across all tests
- Merge test_mhc_coeffs_allow_fp32 into test_mhc_coeffs_forward_backward
- Add backward coverage to test_mhc_pre_and_post_res_match_reference
- Widen bf16 tolerance for layer.weight.grad and phi.grad in module test
- Move hardcoded B into extra_benchmark_configs (benchmark_mhc.py)
- Rename MiniMHCLM to BenchMiniMHCLM in benchmark_mhc_lm.py
- Split _build_models into single-provider _build_model
Contributor Author

yukiu00 commented Feb 9, 2026

Thank you for the detailed review, @Tcc0403! All review comments have been addressed — I've replied to each one individually above.

In addition to the review feedback, I also made a few minor cleanups:

  • Replaced torch.allclose with assert_verbose_allclose across all tests
  • Standardized random seeds to set_seed(42)
  • Merged the test_mhc_coeffs_allow_fp32 test into the main parametrized test_mhc_coeffs_forward_backward
  • Added backward pass coverage to test_mhc_pre_and_post_res_match_reference
  • Removed redundant .contiguous() calls inside @ensure_contiguous-decorated methods
  • Simplified ctx.meta tuple by dropping unused fields (rms_eps, pre_eps)
  • Simplified device detection in LigerMHC.__init__
  • Renamed benchmark model class to BenchMiniMHCLM to avoid name collision with test
  • Split _build_models into single-provider _build_model to avoid unnecessary GPU allocation
  • Moved hardcoded B into extra_benchmark_configs in benchmark_mhc.py

Please let me know if there's anything else that needs attention!

Comment on lines +1418 to +1419
x_shape = x.shape
x_flat, _ = _flatten_tokens(x)
Collaborator

Put it together for readability

Suggested change:
- x_shape = x.shape
- x_flat, _ = _flatten_tokens(x)
+ assert x.dim() >= 3, "x must be [..., HC, C]"
+ x_shape = x.shape
+ x_flat = x.contiguous().view(-1, x.shape[-2], x.shape[-1])

Contributor Author

Done. Removed _flatten_tokens and inlined the assert + view logic at all 3 call sites.


Successfully merging this pull request may close these issues.

Add mHC (Manifold-Constrained Hyper-Connections) fused kernels to Liger-Kernel
