
[GRPO] add grpo loss types #993

Open
kashif wants to merge 9 commits into linkedin:main from kashif:update-grpo-loss-type

Conversation

@kashif (Contributor) commented Jan 1, 2026

Summary

Add various GRPO loss types.

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@kashif (Contributor, Author) commented Jan 1, 2026

[Figure: grpo_comparison_subplots, comparison plots for the GRPO loss types]

@yukiu00 (Contributor) commented Feb 9, 2026

Hi @kashif, thank you for the work on this PR — it was very helpful as a reference.

Since this PR has been inactive for over a month and the vLLM IS correction issue (#1082) is actively affecting users running GRPOTrainer with use_vllm=True + use_liger_kernel=True, I've opened #1088 with a minimal, focused fix for this issue.

Differences from this PR

| | This PR (#993) | #1088 |
| --- | --- | --- |
| Scope | `vllm_is_ratio` + sequence-level IS for Triton + loss-reduction refactoring (~600 lines) | `vllm_is_ratio` only (~245 lines) |
| Triton sequence-level IS | New `_grpo_loss_fwd_kernel_seq` / `_bwd_kernel_seq` kernels | Not included (Triton path remains token-level only; sequence-level works via the chunked loss path) |
| Loss reduction | Moved into `GrpoLossFunction.forward` | Kept in the `triton_grpo_loss` wrapper (existing architecture) |
| CISPO / SAPO support | Not included (pre-#1074) | Supported across all 6 loss types |
| `dr_grpo` `max_completion_length` | Falls back to `seq_len` if `None` | Keeps the existing `ValueError` (required) |

The core mechanism (applying `per_token_loss *= vllm_is_ratio` before the KL penalty) is the same in both PRs.
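That mechanism can be sketched as a reference-style per-token GRPO loss. This is an illustrative NumPy sketch, not the Liger kernel or TRL code: the function name, the k3 KL estimator, and the default hyperparameters are assumptions made for the example.

```python
import numpy as np

def grpo_per_token_loss(logp, old_logp, ref_logp, advantages,
                        eps=0.2, beta=0.04, vllm_is_ratio=None):
    """Illustrative per-token GRPO loss; all arrays have shape [B, T]."""
    ratio = np.exp(logp - old_logp)                        # policy ratio
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps)
    per_token = -np.minimum(ratio * advantages, clipped * advantages)
    if vllm_is_ratio is not None:
        # vLLM importance-sampling correction: applied after the
        # PPO clipped loss and before the KL penalty
        per_token = per_token * vllm_is_ratio
    if beta > 0.0:
        # k3 KL estimator against the reference policy
        kl = np.exp(ref_logp - logp) - (ref_logp - logp) - 1.0
        per_token = per_token + beta * kl
    return per_token
```

Note that `vllm_is_ratio=None` leaves the loss unchanged, and a ratio of all ones is equivalent to `None`, which is the identity property a test can check.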

@kashif (Contributor, Author) commented Feb 9, 2026

Thanks @yukiu00! Feel free to take everything you need from this PR; I'm happy to close it if your fixes are a superset. The reason I had paused this was the Transformers v5 changes and other changes in the GRPO trainer that I was waiting on.

@yukiu00 (Contributor) commented Feb 9, 2026

Thanks @kashif!

To clarify, #1088 is currently a subset (focused on the immediate vLLM fix) rather than a superset.

However, I would be happy to create a separate follow-up PR to incorporate the rest of your features (such as the sequence-level kernels) after #1088 is merged. Your implementation is excellent, and I definitely want to see it in the codebase.

Would you be okay if I ported your work to a new PR to continue the integration?

@kashif (Contributor, Author) commented Feb 9, 2026

@yukiu00 OK, understood. Let's try to get your fix in ASAP since it's urgent, and then I can rebase my changes on top.

Tcc0403 pushed a commit that referenced this pull request Feb 9, 2026
## Summary

Fixes the **primary cause** (item 1) of #1082 —
`LigerFusedLinearGRPOLoss` produces ~100x larger `grad_norm` than TRL's
non-Liger path when using vLLM.

**Root cause:** TRL's `GRPOTrainer` applies `per_token_loss *=
importance_sampling_ratio`
([source](https://github.com/huggingface/trl/blob/v0.27.2/trl/trainer/grpo_trainer.py#L2351-L2352))
to correct for distribution mismatch from vLLM's rejection/stratified
sampling. Liger-Kernel had no mechanism to accept or apply this
correction, so the IS ratio was silently ignored, resulting in
uncorrected (and much larger) gradients.

**This is a high-priority fix** — any user running `GRPOTrainer` with
`use_vllm=True` and `use_liger_kernel=True` is affected, and the
resulting ~100x gradient mismatch can cause training instability or
divergence.

### Changes

- Add optional `vllm_is_ratio` parameter (`[B, T]` tensor or `None`) to
both code paths:
- **Chunked loss path**: `LigerFusedLinearGRPOLoss`,
`LigerFusedLinearGRPOFunction`, `ppo_loss_fn`, and the base class
`LigerFusedLinearPPOBase` chunking pipeline
- **Triton kernel path**: `triton_grpo_loss`, `GrpoLossFunction`, and
the Triton fwd/bwd kernels (`_grpo_loss_fwd_kernel`,
`_grpo_loss_bwd_kernel`)
- The IS correction is applied **after** PPO clipped loss computation
and **before** KL penalty, matching TRL's behavior exactly
- `vllm_is_ratio=None` (default) preserves existing behavior — no
breaking changes
- Works with all loss types: `grpo`, `dapo`, `bnpo`, `dr_grpo`, `cispo`,
`sapo`

### Verification

With `IS_RATIO=0.01`, the `grad_norm` ratio matches exactly:

```
Chunked loss path:
  grad_norm WITHOUT vllm_is_ratio: 1.052219e-01
  grad_norm WITH    vllm_is_ratio: 1.052219e-03
  ratio: 0.010000 ✓

Triton path:
  grad_norm WITHOUT vllm_is_ratio: 1.461673e-02
  grad_norm WITH    vllm_is_ratio: 1.461673e-04
  ratio: 0.010000 ✓
```
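The exact 0.010000 ratio is expected: a constant `vllm_is_ratio` multiplies the per-token loss, and therefore every gradient, linearly (the exact match suggests the check ran without the KL penalty, which is not scaled by the ratio). A toy finite-difference check of that scaling property, using a hypothetical quadratic `toy_loss` as a stand-in rather than the GRPO loss itself:

```python
import numpy as np

def toy_loss(w, x, target, is_ratio):
    # stand-in per-token loss, scaled by a constant IS ratio (no KL term)
    per_token = (w * x - target) ** 2
    return np.mean(is_ratio * per_token)

# central finite difference of d(toy_loss)/dw
w, x, t, h = 0.7, np.array([1.0, 2.0, 3.0]), np.ones(3), 1e-6

def grad(is_ratio):
    return (toy_loss(w + h, x, t, is_ratio)
            - toy_loss(w - h, x, t, is_ratio)) / (2 * h)

print(grad(0.01) / grad(1.0))  # close to 0.01, like the grad_norm ratio above
```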

## Test plan

- [x] Extended existing `test_correctness` in
`test/chunked_loss/test_grpo_loss.py` with `use_vllm_is_ratio`
parametrize — covers all 6 loss types × 2 IS levels × 2 beta values ×
with/without vllm_is_ratio
- [x] Added `test_grpo_loss_with_vllm_is_ratio` in
`test/transformers/test_grpo_loss.py` — compares Triton output against
PyTorch reference with IS correction, plus `vllm_is_ratio=None` ==
`vllm_is_ratio=ones` identity check
- [x] All existing tests continue to pass (no regressions)
- [x] `make checkstyle` passes

## Related

- Reference implementation: #993
- Issue: #1082
