Conversation

@SolarWindRider
Contributor

🚀 Reverse Thinking Policy Optimization (RTPO)

This PR introduces Reverse Thinking Policy Optimization (RTPO), a new RL training method for LLMs built on top of GRPOTrainer.

πŸ” Motivation

Current GRPO-based RL methods require the model to autonomously generate a full chain-of-thought before producing the final answer.
However, many training datasets already contain complete, high-quality reasoning traces that the model could benefit from.

RTPO is designed to:

  • Utilize existing reasoning traces as auxiliary CoT to support early-stage rollouts.
  • Force the model to gradually reconstruct its own reasoning by shortening the auxiliary CoT step by step.
  • Enable a reverse learning schedule: the model first learns to output correct answers, then progressively learns how to reason.

🧠 Method Overview

RTPO modifies the standard GRPO rollout process:

Full Auxiliary CoT Injection

At rollout step 0, the full reasoning chain from the dataset is concatenated into the input prompt.

Model behavior:

  • Only needs to generate the final answer.
  • Benefits from a high-quality reasoning scaffold.
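
For illustration, a minimal sketch of this injection step in plain Python. The `<think>` delimiter, the field name `gold_analysis`, and the word-level truncation are assumptions made for the sketch; the actual RTPOTrainer builds the rollout inputs (including images) internally.

def inject_auxiliary_cot(prompt: str, reasoning_trace: str, keep_fraction: float = 1.0) -> str:
    """Concatenate a (possibly truncated) auxiliary reasoning trace into the rollout prompt."""
    words = reasoning_trace.split()  # word-level split stands in for tokenizer-based truncation
    kept = " ".join(words[: int(len(words) * keep_fraction)])
    return f"{prompt}\n<think>\n{kept}"  # the '<think>' scaffold format is an assumption

# At rollout step 0 the full trace is kept, so the model only needs to emit the final answer.
example = {"prompt": "Question: ...", "gold_analysis": "Step 1: ... Step 2: ... So the answer is B."}
rollout_prompt = inject_auxiliary_cot(example["prompt"], example["gold_analysis"], keep_fraction=1.0)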

Reverse Annealing of Auxiliary CoT

As training steps increase, RTPO gradually removes tokens from the end of the auxiliary CoT based on a configurable schedule:

full_reasoning → partial_reasoning → short_hint → empty

Expected model behavior:

  • "Fills in" the removed reasoning process on its own.
  • Learns to produce longer reasoning as annealing progresses.
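
A hedged sketch of how the keep fraction behind this annealing could be computed. The schedule names mirror the schedule_type options supported by RTPOConfig (only linear, cosine, and constant are shown), but the exact formulas in AnnealingScheduler may differ.

import math

def keep_fraction(step: int, total_steps: int, schedule_type: str = "linear") -> float:
    """Fraction of the auxiliary CoT kept at a given training step (anneals from 1.0 toward 0.0)."""
    progress = min(step / max(total_steps, 1), 1.0)
    if schedule_type == "linear":
        return 1.0 - progress
    if schedule_type == "cosine":
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    if schedule_type == "constant":
        return 1.0  # keep the full trace throughout (assumed semantics)
    raise ValueError(f"Unsupported schedule_type: {schedule_type}")

# full_reasoning -> partial_reasoning -> short_hint -> empty
print([round(keep_fraction(s, 100, "cosine"), 2) for s in (0, 50, 90, 100)])  # roughly [1.0, 0.5, 0.02, 0.0]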

Interesting Finding: Emergent Shorter Reasoning

Unexpectedly, RTPO also teaches the model to shorten its reasoning:

  • Instead of regenerating the removed tokens, the model sometimes directly outputs the correct final answer.
  • Over training, this leads to consistently shorter, more efficient reasoning chains.

More experiments are ongoing and will be included later.


📦 Files Added / Modified

  • trl/experimental/rtpo/__init__.py
  • trl/experimental/rtpo/rtpo_config.py
  • trl/experimental/rtpo/rtpo_trainer.py

🧪 Example Usage

Adapted from my repo, AVR.

import os
from argparse import ArgumentParser
from peft import LoraConfig
from trl.experimental.rtpo import RTPOConfig, RTPOTrainer
from utils.universal import set_seed, get_dataset, model_processor, reward_fn

# ======= Basic Environment & Paths =======
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
set_seed(42)

# ========== Arguments ===================================
parser = ArgumentParser()
parser.add_argument("--model_path", type=str, default="../Downloads/Models/Qwen/Qwen2.5-VL-7B-Instruct")
parser.add_argument("--loss_type", type=str, default="grpo", choices=["dapo", "grpo", "dr_grpo"])
parser.add_argument("--output_dir", type=str, default="output")
parser.add_argument("--think_process_key", type=str, default="gold_analysis", choices=["gold_analysis", "explanation"])  # Assistant thinking guide (bootstrap)
parser.add_argument("--anneal_schedule", type=str, default="cosine")

args = parser.parse_args()
print(args)

image_root = "../datas/VisuRiddles"
train_json_path = "../datas/VisuRiddles/syndata.json"


# ======= Model and Dataset =======
model, processor = model_processor(args.model_path)
train_ds, eval_ds = get_dataset(image_root, train_json_path, args.think_process_key, True) # RTPO requires chain-of-thought guidance.

# ======= LoRA Configuration (target_modules consistent with your SFT) =======
lora_config = LoraConfig(r=8, target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], init_lora_weights=True)


# ======= RTPO Configuration =======
"""
Multi-GPU: It is recommended to launch with torchrun (see the example command after this script).
You can also switch to FSDP/DeepSpeed.
Set remove_unused_columns=False to keep image columns for internal multimodal input construction.
"""
grpo_config = RTPOConfig(
    # Annealing parameters
    schedule_type=args.anneal_schedule,
    direction="down",
    # GRPO parameters
    loss_type=args.loss_type,
    output_dir=args.output_dir,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    num_generations=8,  # Sample 8 responses per prompt
    top_k=20,
    num_train_epochs=3,  # Alternatively, use max_steps instead of epochs
    learning_rate=5e-5,  # RL typically uses smaller LR; adjust as needed
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    logging_steps=5,
    save_strategy="steps",
    save_steps=50,
    save_only_model=True,
    eval_strategy="steps",
    eval_steps=50,
    report_to="swanlab",
    remove_unused_columns=False,
    fp16=False,
    bf16=True,
    max_prompt_length=8192,
    max_completion_length=8192,
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "mixed_precision": "bf16",
        "forward_prefetch": True,
        "use_orig_params": False,
        "use_cpu": True,
        "offload_params": True,
        "offload_optimizer": True,
        "enable_gradient_checkpointing": True,
    },
)

# ======= Build RTPOTrainer =======
trainer = RTPOTrainer(
    model=model,
    processing_class=processor,  # Allows internal automatic construction of multimodal inputs from prompt+image
    args=grpo_config,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    reward_funcs=[reward_fn],  # Multiple rewards can also be stacked
    peft_config=lora_config,  # LoRA low-rank fine-tuning
)

# ======= Training =======
if __name__ == "__main__":
    # Before starting training, manually convert LoRA module parameters
    from peft.peft_model import PeftModel

    if isinstance(trainer.model, PeftModel):
        for param in trainer.model.parameters():
            if param.requires_grad:
                param.data = param.data.bfloat16()
        print("Successfully converted trainable LoRA parameters to bf16.")

    trainer.train()
    # Save LoRA weights at the end
    trainer.save_model()
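
For multi-GPU training, the script can be launched with torchrun as mentioned in the configuration comment. For example, assuming the script above is saved as train_rtpo.py (filename hypothetical) and 4 GPUs are available:

torchrun --nproc_per_node=4 train_rtpo.py --loss_type grpo --anneal_schedule cosine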

✅ Status

  • Core RTPO algorithm implemented
  • Fully integrated with TRL
  • Supports distributed training
  • Configurable annealing schedules
  • Documentation and examples
  • Unit tests

🙌 Request for Review

Comment on lines +28 to +30
schedule_type (`str`,defaults to linear):
Choose a schedule type for AnnealingScheduler to control thinking guidance length.
Supports: linear, cosine, exponential, piecewise, constant
Member

Suggested change
schedule_type (`str`,defaults to linear):
Choose a schedule type for AnnealingScheduler to control thinking guidance length.
Supports: linear, cosine, exponential, piecewise, constant
schedule_type (`str`, *optional*, defaults to `"linear"`):
Choose a schedule type for AnnealingScheduler to control thinking guidance length. Supports: `"linear"`,
`"cosine, `"exponential"`, `"piecewise"`, `"constant"`.

Can you try to align the docstring with the rest of the codebase? The suggestion above is an example.

Constant value for constant schedule.
"""

schedule_type: str = "linear"
Member

Same here, we usually use dataclasses.field. You can take GRPOConfig as an example.
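
For illustration, a hedged sketch of what that could look like, following the field(default=..., metadata={"help": ...}) pattern used in GRPOConfig; subclassing GRPOConfig here is an assumption about how this PR defines RTPOConfig.

from dataclasses import dataclass, field

from trl import GRPOConfig


@dataclass
class RTPOConfig(GRPOConfig):  # base class assumed for the sketch
    schedule_type: str = field(
        default="linear",
        metadata={
            "help": "Schedule type for the AnnealingScheduler controlling thinking guidance length. "
            "Supported: 'linear', 'cosine', 'exponential', 'piecewise', 'constant'."
        },
    )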

@qgallouedec
Member

Thanks for the PR!

Can you:

  • Add a tiny subsection in the Paper Index
  • Add a section to the documentation
  • Apply the style (`make precommit`)

You can take your other PR #4334 as an example
