@yukiu00 yukiu00 commented Feb 4, 2026

Summary

This PR optimizes norm backward passes for LoRA/PEFT training by skipping weight/bias gradient computation when parameters are frozen (requires_grad=False).

Closes #1067

Motivation

When using LoRA/PEFT, normalization weights are typically frozen, but Liger still computed their gradients, wasting computation and memory, especially at the large hidden sizes common in modern LLMs. See the linked issue for detailed motivation and benchmarks.

Changes

Kernel optimizations:

  • Add compute_dW/compute_dB flags to backward kernels (as tl.constexpr for dead-code elimination; see the kernel sketch after this list)
  • Skip gradient buffer allocation when not needed
  • Check ctx.needs_input_grad in all norm backward passes
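As a rough illustration (a minimal sketch, not the actual Liger kernel; the names, the atomic dW reduction, and the one-program-per-row launch are simplifying assumptions), the pattern looks like this: when COMPUTE_DW is a tl.constexpr set to False, Triton removes the dW branch at compile time, so a frozen weight costs neither the atomics nor the fp32 dW buffer.

import triton
import triton.language as tl

@triton.jit
def _rms_norm_bwd_sketch(
    dY_ptr, X_ptr, W_ptr, Rstd_ptr,   # saved from the forward pass
    dX_ptr, dW_ptr,                   # dW_ptr may be a dummy pointer when frozen
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    COMPUTE_DW: tl.constexpr,         # compile-time flag; False removes the dW path
):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols

    dy = tl.load(dY_ptr + row * n_cols + cols, mask=mask, other=0.0).to(tl.float32)
    x = tl.load(X_ptr + row * n_cols + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    rstd = tl.load(Rstd_ptr + row)

    x_hat = x * rstd
    wdy = w * dy
    # dX = rstd * (w*dy - x_hat * mean(x_hat * w * dy))
    c = tl.sum(x_hat * wdy, axis=0) / n_cols
    dx = (wdy - x_hat * c) * rstd
    tl.store(dX_ptr + row * n_cols + cols, dx.to(dX_ptr.dtype.element_ty), mask=mask)

    if COMPUTE_DW:
        # Dead-code eliminated when COMPUTE_DW=False: no atomics, no dW traffic.
        tl.atomic_add(dW_ptr + cols, dy * x_hat, mask=mask)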

Affected ops:

  • RMSNorm
  • FusedAddRMSNorm
  • LayerNorm
  • GroupNorm
  • PolyNorm

Tests:

  • Add frozen weight/bias test coverage for all norm ops
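For context, a frozen-weight check can be as simple as the following sketch (assumed test shape and import path, not the exact tests added in this PR): with weight.requires_grad set to False, the backward pass must still produce dX while leaving weight.grad as None.

import torch
from liger_kernel.transformers.rms_norm import LigerRMSNorm  # assumed import path

def test_rms_norm_frozen_weight_sketch():
    hidden = 4096
    x = torch.randn(8, 128, hidden, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    norm = LigerRMSNorm(hidden).to(device="cuda", dtype=torch.bfloat16)
    norm.weight.requires_grad_(False)  # LoRA/PEFT-style frozen norm weight

    norm(x).sum().backward()

    assert x.grad is not None          # dX is still produced
    assert norm.weight.grad is None    # no dW should be allocated or accumulated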

Benchmarks:

  • Add freeze_weight option to RMSNorm benchmark
  • Add mixed benchmark (RMSNorm + LoRA Linear)

Bug fixes:

  • Fix dS_out None check in fused_add_rms_norm_backward

Benchmark Results

RTX 3090, bf16, M=2048

Hidden Size    Backward Speedup    Full Speedup
H=1024         1.25×               1.12×
H=4096         1.11×               1.05×
H=16384        1.37×               1.22×
H=32768        3.12×               2.41×

API Impact

  • No public API changes
  • Internal *_backward helpers now accept compute_dW/compute_dB flags (see the sketch below)
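For reference, a simplified, hypothetical view of what the helper shape amounts to (pure PyTorch math standing in for the Triton launch sketched above; the name and signature are illustrative, not the actual Liger code):

import torch

def rms_norm_backward_sketch(dY, X, W, eps, compute_dW):
    # X, dY: (M, H); W: (H,). Math in fp32, matching the reference backward.
    X32, dY32, W32 = X.float(), dY.float(), W.float()
    rstd = torch.rsqrt(X32.pow(2).mean(dim=-1, keepdim=True) + eps)
    x_hat = X32 * rstd

    wdy = W32 * dY32
    dX = (wdy - x_hat * (x_hat * wdy).mean(dim=-1, keepdim=True)) * rstd

    # Skip the dW reduction (and its buffer) entirely when the weight is frozen.
    dW = (dY32 * x_hat).sum(dim=0) if compute_dW else None
    return dX.to(X.dtype), dW

# Inside the autograd Function, the flag is derived from ctx.needs_input_grad, e.g.:
#   compute_dW = ctx.needs_input_grad[1]  # index of the weight input (illustrative)
#   dX, dW = rms_norm_backward_sketch(dY, X, W, eps, compute_dW)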

Test Plan

  • All existing norm tests pass
  • New frozen weight/bias tests pass for all norm ops
  • Lint/format checks pass

# Run frozen weight tests
pytest test/transformers/test_rms_norm.py test/transformers/test_layer_norm.py \
       test/transformers/test_group_norm.py test/transformers/test_poly_norm.py \
       test/transformers/test_fused_add_rms_norm.py -v -k "frozen"

# Run benchmarks
python benchmark/scripts/benchmark_rms_norm.py --overwrite
python benchmark/scripts/benchmark_rms_norm_mixed.py --overwrite

@yukiu00 yukiu00 force-pushed the fix/norm-freeze-grads branch from 23a0db5 to 8777c9f on February 4, 2026 at 10:54