Hi @kashif, thank you for the work on this PR — it was very helpful as a reference. Since this PR has been inactive for over a month and the vLLM IS correction issue (#1082) is actively affecting users running …

**Differences from this PR**

The core mechanism (applying …
Thanks @yukiu00 — can you take everything you need from this PR? I'm happy to close this if your fixes are a superset. The reason I had paused this was the transformers v5 changes and other changes in the GRPO trainer that I was waiting on.
Thanks @kashif! To clarify, #1088 is currently a subset (focused on the immediate vLLM fix) rather than a superset. However, I would be happy to create a separate follow-up PR to incorporate the rest of your features (such as the sequence-level kernels) after #1088 is merged. Your implementation is excellent, and I definitely want to see it in the codebase. Would you be okay if I ported your work to a new PR to continue the integration?
@yukiu00 ok, understood — let's try to get your fix in ASAP since it's urgent, and then I can rebase my stuff.
## Summary

Fixes the **primary cause** (item 1) of #1082 — `LigerFusedLinearGRPOLoss` produces ~100x larger `grad_norm` than TRL's non-Liger path when using vLLM.

**Root cause:** TRL's `GRPOTrainer` applies `per_token_loss *= importance_sampling_ratio` ([source](https://github.com/huggingface/trl/blob/v0.27.2/trl/trainer/grpo_trainer.py#L2351-L2352)) to correct for the distribution mismatch from vLLM's rejection/stratified sampling. Liger-Kernel had no mechanism to accept or apply this correction, so the IS ratio was silently ignored, resulting in uncorrected (and much larger) gradients.

**This is a high-priority fix** — any user running `GRPOTrainer` with `use_vllm=True` and `use_liger_kernel=True` is affected, and the resulting ~100x gradient mismatch can cause training instability or divergence.

### Changes

- Add an optional `vllm_is_ratio` parameter (`[B, T]` tensor or `None`) to both code paths:
  - **Chunked loss path**: `LigerFusedLinearGRPOLoss`, `LigerFusedLinearGRPOFunction`, `ppo_loss_fn`, and the base class `LigerFusedLinearPPOBase` chunking pipeline
  - **Triton kernel path**: `triton_grpo_loss`, `GrpoLossFunction`, and the Triton fwd/bwd kernels (`_grpo_loss_fwd_kernel`, `_grpo_loss_bwd_kernel`)
- The IS correction is applied **after** the PPO clipped loss computation and **before** the KL penalty, matching TRL's behavior exactly
- `vllm_is_ratio=None` (default) preserves existing behavior — no breaking changes
- Works with all loss types: `grpo`, `dapo`, `bnpo`, `dr_grpo`, `cispo`, `sapo`

### Verification

With `IS_RATIO=0.01`, the `grad_norm` ratio matches exactly:

```
Chunked loss path:
  grad_norm WITHOUT vllm_is_ratio: 1.052219e-01
  grad_norm WITH    vllm_is_ratio: 1.052219e-03
  ratio: 0.010000 ✓

Triton path:
  grad_norm WITHOUT vllm_is_ratio: 1.461673e-02
  grad_norm WITH    vllm_is_ratio: 1.461673e-04
  ratio: 0.010000 ✓
```

## Test plan

- [x] Extended existing `test_correctness` in `test/chunked_loss/test_grpo_loss.py` with a `use_vllm_is_ratio` parametrize — covers all 6 loss types × 2 IS levels × 2 beta values × with/without `vllm_is_ratio`
- [x] Added `test_grpo_loss_with_vllm_is_ratio` in `test/transformers/test_grpo_loss.py` — compares Triton output against a PyTorch reference with IS correction, plus a `vllm_is_ratio=None` == `vllm_is_ratio=ones` identity check
- [x] All existing tests continue to pass (no regressions)
- [x] `make checkstyle` passes

## Related

- Reference implementation: #993
- Issue: #1082
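The ordering described above (IS correction after the clipped PPO loss, before the KL penalty) can be sketched in plain PyTorch. This is an illustrative reference, not the actual Liger chunked or Triton implementation; the function name, the epsilon defaults, and the k3 KL estimator form are assumptions:

```python
import torch

def grpo_per_token_loss(log_probs, old_log_probs, advantages,
                        ref_log_probs=None, beta=0.0,
                        vllm_is_ratio=None, eps_low=0.2, eps_high=0.2):
    """Hypothetical reference for the per-token GRPO loss with vLLM IS correction."""
    # Standard PPO-style clipped objective per token.
    ratio = torch.exp(log_probs - old_log_probs)          # [B, T]
    adv = advantages.unsqueeze(1)                         # [B, 1], broadcast over T
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1 - eps_low, 1 + eps_high) * adv
    per_token_loss = -torch.min(unclipped, clipped)

    # vLLM importance-sampling correction: applied AFTER the clipped loss
    # and BEFORE the KL penalty, mirroring TRL's GRPOTrainer.
    if vllm_is_ratio is not None:
        per_token_loss = per_token_loss * vllm_is_ratio   # [B, T]

    # Optional KL penalty against the reference model (k3 estimator).
    if beta > 0.0 and ref_log_probs is not None:
        delta = ref_log_probs - log_probs
        kl = torch.exp(delta) - delta - 1
        per_token_loss = per_token_loss + beta * kl
    return per_token_loss
```

Passing `vllm_is_ratio=None` (or a tensor of ones) leaves the loss unchanged, which is the identity property the new tests check.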

## Summary

Add various GRPO loss types.

## Testing Done

- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence
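The loss types mentioned in this thread differ mainly in how the per-token loss is aggregated into a scalar. A hedged sketch of three of the variants, assuming TRL-style aggregation semantics (the function name is hypothetical, and `dapo`/`cispo`/`sapo` are omitted since their objectives differ beyond aggregation):

```python
import torch

def aggregate_loss(per_token_loss, mask, loss_type, max_completion_len=None):
    """Illustrative aggregation for a few GRPO loss variants (not Liger's code).

    per_token_loss: [B, T] float tensor; mask: [B, T] 0/1 completion mask.
    """
    masked = per_token_loss * mask
    if loss_type == "grpo":
        # Mean over valid tokens per sequence, then mean over the batch.
        return (masked.sum(-1) / mask.sum(-1).clamp(min=1)).mean()
    if loss_type == "bnpo":
        # Single mean over all valid tokens in the batch.
        return masked.sum() / mask.sum().clamp(min=1)
    if loss_type == "dr_grpo":
        # Dr. GRPO: normalize by batch_size * max_completion_length,
        # independent of how many tokens were actually generated.
        return masked.sum() / (mask.size(0) * max_completion_len)
    raise ValueError(f"unknown loss_type: {loss_type}")
```

With a fully populated mask the three variants coincide; they diverge only when completion lengths vary across the batch, which is what makes the choice matter for gradient scale.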