Changes from all commits (23 commits)
4590205
feature: rtpo trainer
SolarWindRider Nov 4, 2025
7e9e121
Merge branch 'huggingface:main' into trpo
SolarWindRider Nov 5, 2025
b224652
Merge branch 'huggingface:main' into trpo
SolarWindRider Nov 5, 2025
42d6fdf
Merge branch 'huggingface:main' into rtpo
SolarWindRider Nov 18, 2025
0a85fd8
Merge branch 'huggingface:main' into rtpo
SolarWindRider Dec 5, 2025
28d77da
del paper info
SolarWindRider Dec 9, 2025
0cb2666
Merge branch 'main' into rtpo
SolarWindRider Dec 9, 2025
4745541
Merge branch 'main' into rtpo
SolarWindRider Dec 10, 2025
50fab70
rtpo doc paper_index toctree
SolarWindRider Dec 11, 2025
039be42
Merge branch 'rtpo' of https://github.com/SolarWindRider/trl-my into …
SolarWindRider Dec 11, 2025
5f4cfbf
Merge branch 'main' into rtpo
SolarWindRider Dec 11, 2025
b12f56e
make precommit edit
SolarWindRider Dec 11, 2025
5ba394b
Merge branch 'rtpo' of https://github.com/SolarWindRider/trl-my into …
SolarWindRider Dec 11, 2025
ffb0aaa
make precommit
SolarWindRider Dec 11, 2025
368f84c
Merge branch 'main' into rtpo
SolarWindRider Dec 12, 2025
71469e9
Merge branch 'main' into rtpo
SolarWindRider Dec 15, 2025
77d937b
Merge branch 'main' into rtpo
SolarWindRider Dec 16, 2025
76f2756
add super().__post_init__()
SolarWindRider Dec 16, 2025
f2a73d4
Merge branch 'huggingface:main' into rtpo
SolarWindRider Dec 16, 2025
3edd37d
Merge branch 'main' into rtpo
SolarWindRider Dec 17, 2025
2ac2d24
Merge branch 'main' into rtpo
SolarWindRider Dec 18, 2025
e8d7f3a
Merge branch 'main' into rtpo
SolarWindRider Dec 18, 2025
dda13c3
Merge branch 'main' into rtpo
SolarWindRider Dec 18, 2025
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -125,6 +125,8 @@
title: PPO
- local: prm_trainer
title: PRM
- local: rtpo_trainer
title: RTPO
- local: winrate_callback
title: WinRateCallback
- local: xpo_trainer
15 changes: 14 additions & 1 deletion docs/source/paper_index.md
@@ -352,7 +352,7 @@ See [Experimental - GFPO](experimental#gfpo).

**📜 Paper**: https://huggingface.co/papers/2507.06448

A novel policy gradient algorithm that encourages VLMs to learn to perceive while learning to reason. This is a TRL adaptation. The TRL implementation is not the official one provided by the authors.
A novel policy gradient algorithm that encourages VLMs to learn to perceive while learning to reason.
This is a TRL adaptation of PAPO. Note that this is not the official implementation. The official code can be found in [MikeWangWZHL/PAPO](https://github.com/MikeWangWZHL/PAPO).

```python
@@ -412,6 +412,19 @@ training_args = GRPOConfig(
...
)
```
### Reverse Thinking Policy Optimization

**📰 Blog**: https://github.com/SolarWindRider/avr/blob/main/README.md

Current GRPO-based RL methods require the model to autonomously generate a full chain-of-thought before producing the final answer.
However, many training datasets already contain **complete, high-quality reasoning traces** that the model could benefit from.

RTPO is designed to:

* Utilize existing reasoning traces as **auxiliary CoT** to support early-stage rollouts.
* Force the model to gradually reconstruct its own reasoning by **shortening the auxiliary CoT** step by step.
* Enable a *reverse learning schedule*: the model first learns to output correct answers, then progressively learns how to reason.

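A minimal usage sketch of the `RTPOConfig` and `RTPOTrainer` added in this PR. The dataset name, model checkpoint, and reward function below are placeholders for illustration; RTPO expects conversational prompts whose last message is an assistant turn holding the reference reasoning trace.

```python
from datasets import load_dataset

from trl.experimental.rtpo import RTPOConfig, RTPOTrainer

# Placeholder dataset: conversational prompts ending with an assistant message that
# contains the reference reasoning trace to be annealed away during training.
dataset = load_dataset("your-dataset-with-reasoning-traces", split="train")


def reward_func(completions, **kwargs):
    # Placeholder reward function.
    return [0.0 for _ in completions]


training_args = RTPOConfig(
    output_dir="Qwen2.5-0.5B-RTPO",
    schedule_type="linear",  # shorten the auxiliary CoT linearly over training
    direction="down",  # guidance goes from the full trace down to nothing
)
trainer = RTPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_func,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```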

### DeepSeek-V3.2: Pushing the Frontier of Open Large Language Models

17 changes: 17 additions & 0 deletions trl/experimental/rtpo/__init__.py
@@ -0,0 +1,17 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .rtpo_config import RTPOConfig
from .rtpo_trainer import RTPOTrainer
73 changes: 73 additions & 0 deletions trl/experimental/rtpo/rtpo_config.py
@@ -0,0 +1,73 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from ...trainer.grpo_config import GRPOConfig


@dataclass
class RTPOConfig(GRPOConfig):
"""
Configuration class for the [`RTPOTrainer`].

RTPO (Reverse Thinking Policy Optimization) extends GRPO by letting the prompt end with an auxiliary reasoning
trace and annealing the length of that trace over training, so the model gradually learns to produce the full
reasoning on its own.

Args:
schedule_type (`str`, *optional*, defaults to `"linear"`):
Schedule type used by the `AnnealingScheduler` to control the length of the thinking guidance. Supported
values: `"linear"`, `"cosine"`, `"exponential"`, `"piecewise"`, `"constant"`.

direction (`str`, *optional*, defaults to `"down"`):
Direction of the annealing schedule.
- `"down"`: Schedule value goes from 1.0 → 0.0
- `"up"`: Schedule value goes from 0.0 → 1.0
Supports: `"up"`, `"down"`

decay_rate (`float`, *optional*, defaults to 5.0):
The decay rate used when `schedule_type` is set to `"exponential"`. Higher values result in faster decay.

milestones (`list[float]`, *optional*, defaults to `[0.3, 0.6, 0.9]`):
Milestones (progress points between 0 and 1) for piecewise linear schedule. Only used when `schedule_type`
is set to `"piecewise"`. Must be in ascending order and within [0, 1] range.

values (`list[float]`, *optional*, defaults to `[0.2, 0.5, 0.8, 1.0]`):
Schedule values corresponding to the milestones and boundaries. Only used when `schedule_type` is set to
`"piecewise"`. Length must be `len(milestones) + 1`. For `direction="down"`, values typically decrease; for
`direction="up"`, values typically increase.

value (`float`, *optional*, defaults to 1.0):
Constant value for constant schedule. Only used when `schedule_type` is set to `"constant"`.
"""

schedule_type: str = field(default="linear")
direction: str = field(default="down")
decay_rate: float | None = field(default=None)
milestones: list[float] | None = field(default=None)
values: list[float] | None = field(default=None)
value: float | None = field(default=None)

def __post_init__(self):
# Resolve schedule-specific defaults based on `schedule_type`
if self.schedule_type == "exponential" and self.decay_rate is None:
self.decay_rate = 5.0
elif self.schedule_type == "piecewise":
if self.milestones is None:
self.milestones = [0.3, 0.6, 0.9]
if self.values is None:
self.values = [0.2, 0.5, 0.8, 1.0]
elif self.schedule_type == "constant" and self.value is None:
self.value = 1.0
super().__post_init__()
249 changes: 249 additions & 0 deletions trl/experimental/rtpo/rtpo_trainer.py
@@ -0,0 +1,249 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This class should be added to trl.data_utils when the RTPO method is stable enough to join trl.trainer
import math
from typing import Any

import torch
from datasets import Dataset, IterableDataset
from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin

from ...data_utils import is_conversational
from ...extras.profiling import profiling_decorator
from ...trainer.grpo_trainer import GRPOTrainer, RewardFunc
from ...trainer.utils import (
shuffle_sequence_dict,
split_pixel_values_by_grid,
split_tensor_dict,
unsplit_pixel_values_by_grid,
)
from .rtpo_config import RTPOConfig


class AnnealingScheduler:
"""
General annealing scheduler that can be used like a learning rate scheduler. Supports: linear, cosine,
exponential, piecewise, and constant schedules. The direction can be "up" (0→1) or "down" (1→0).
"""

def __init__(
self,
total_steps: int | float,
schedule_type: str = "linear",
direction: str = "down", # "down" means 1→0 annealing
**kwargs: Any,
) -> None:
self.total_steps = total_steps
self.schedule_type = schedule_type
self.direction = direction
self.kwargs = kwargs # extra params like decay_rate, milestones, etc.

def __call__(self, step: int | float) -> float:
"""Return annealing factor based on current step."""
ratio = min(step / max(self.total_steps, 1), 1.0)

if self.schedule_type == "linear":
value = ratio

elif self.schedule_type == "cosine":
# Cosine annealing: starts fast, slows down later
value = 0.5 * (1 - math.cos(math.pi * ratio))

elif self.schedule_type == "exponential":
# Exponential annealing: f(t)=1 - exp(-k*t)
k = self.kwargs.get("decay_rate", 5.0)
value = 1 - math.exp(-k * ratio)

elif self.schedule_type == "piecewise":
milestones = self.kwargs.get("milestones", [0.3, 0.6, 0.9])
values = self.kwargs.get("values", [0.2, 0.5, 0.8, 1.0])
for i, m in enumerate(milestones):
if ratio < m:
value = values[i]
break
else:
value = values[-1]

elif self.schedule_type == "constant":
value = self.kwargs.get("value", 1.0)

else:
raise ValueError(f"Unknown schedule_type: {self.schedule_type}")

# Apply direction: up (0→1) or down (1→0)
if self.direction == "down":
return 1.0 - value
elif self.direction == "up":
return value
else:
raise ValueError(f"Invalid direction: {self.direction}")


def think_guidance_anneal(
generation_batch: list[dict[str, Any]],
anneal_factor: float,
tokenizer_or_processor: PreTrainedTokenizerBase | ProcessorMixin,
) -> list[dict[str, Any]]:
"""Truncate the trailing assistant (auxiliary CoT) message to a fraction `anneal_factor` of its tokens."""
tokenizer = getattr(tokenizer_or_processor, "tokenizer", tokenizer_or_processor)
if is_conversational(generation_batch[0]):
for i in range(len(generation_batch)):
if generation_batch[i]["prompt"][-1]["role"] == "assistant":
think = generation_batch[i]["prompt"][-1]["content"]
tokens = tokenizer.encode(think)
anneal = tokenizer.decode(tokens[: int(anneal_factor * len(tokens))], skip_special_tokens=True)
generation_batch[i]["prompt"][-1]["content"] = anneal
return generation_batch


def drop_assistant_content(generation_batch: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Remove the trailing assistant (auxiliary CoT) message from each conversational prompt, if present."""
if is_conversational(generation_batch[0]):
for i in range(len(generation_batch)):
if generation_batch[i]["prompt"][-1]["role"] == "assistant":
generation_batch[i]["prompt"].pop()
return generation_batch


class RTPOTrainer(GRPOTrainer):
"""
Trainer for Reverse Thinking Policy Optimization (RTPO). Example:

```python
from datasets import load_dataset
from trl.experimental.rtpo import RTPOConfig, RTPOTrainer

dataset = load_dataset("your-vlm-dataset", split="train")


def reward_func(completions, **kwargs):
# Your reward function for multimodal reasoning
return [compute_reward(c) for c in completions]


config = RTPOConfig(
loss_type="grpo",  # use the GRPO objective as the base
schedule_type="linear",  # shorten the auxiliary CoT linearly over training
direction="down",
)

trainer = RTPOTrainer(
model="Qwen/Qwen2-VL-2B-Instruct",
reward_funcs=reward_func,
args=config,
train_dataset=dataset,
)

trainer.train()
```

Args:
model (`Union[str, PreTrainedModel]`):
Model to be trained (must be a vision-language model).
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
Reward functions for computing rewards (same as GRPO).
args ([`RTPOConfig`], *optional*, defaults to `None`):
Configuration for this trainer. If `None`, a default configuration is used.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
Dataset to use for training. Must include a "prompt" column (and an "image" column for vision-language
inputs); to benefit from guidance annealing, conversational prompts should end with an assistant message
containing the reference reasoning trace.
eval_dataset: Same requirements as train_dataset.
processing_class: Processing class (tokenizer/processor) for the model.
reward_processing_classes: Processing classes for reward models.
callbacks: Training callbacks.
optimizers: Optimizer and scheduler tuple.
peft_config: PEFT configuration if using parameter-efficient fine-tuning.
"""

_tag_names = ["trl", "rtpo"]
_name = "RTPO"

def __init__(
self,
model: str | PreTrainedModel,
reward_funcs: RewardFunc | list[RewardFunc],
args: RTPOConfig | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
callbacks=None,
optimizers=(None, None),
peft_config=None,
):
# Initialize with default RTPO config if not provided
if args is None:
model_name = model if isinstance(model, str) else model.config._name_or_path
model_name = model_name.split("/")[-1]
args = RTPOConfig(f"{model_name}-RTPO")

# Initialize parent GRPO trainer
super().__init__(
model=model,
reward_funcs=reward_funcs,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
reward_processing_classes=reward_processing_classes,
callbacks=callbacks,
optimizers=optimizers,
peft_config=peft_config,
)

# Build the annealing scheduler after the parent trainer is initialized, so the train dataloader is available
# for estimating the total number of steps when `max_steps` is not set.
total_steps = args.max_steps if args.max_steps > 0 else len(self.get_train_dataloader()) * args.num_train_epochs
self.anneal_scheduler = AnnealingScheduler(
total_steps=total_steps,
schedule_type=args.schedule_type,  # linear / cosine / exponential / piecewise / constant
direction=args.direction,  # up: 0 -> 1 / down: 1 -> 0
decay_rate=args.decay_rate,  # used by the exponential schedule
milestones=args.milestones,  # used by the piecewise schedule
values=args.values,  # used by the piecewise schedule
value=args.value,  # used by the constant schedule
)

@profiling_decorator
def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
# Prepares inputs for model training/evaluation by managing completion generation and batch handling.
# During training:
# - Receives the local generation batch (Per-GPU batch size × steps per generation)
# from the modified training dataloader instead of the standard local batch
# - Generates completions once for the entire generation batch and splits it into batches of size
# `per_device_train_batch_size`
# - Buffers these completions and returns the appropriate slice for the current accumulation step
# - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations)
# During evaluation:
# - The input is treated as a standard local batch (no accumulation, no multiple iterations)
# - Completions are generated for each batch without buffering or reuse
# Returns a single local batch in both cases.

mode = "train" if self.model.training else "eval"
if mode == "train":
generation_batch = think_guidance_anneal(
generation_batch, self.anneal_scheduler(self.state.global_step), self.processing_class
)
generate_every = self.args.steps_per_generation * self.num_iterations
if self._step % generate_every == 0 or self._buffered_inputs is None:
# self._buffered_inputs=None can occur when resuming from a checkpoint
generation_batch = self._generate_and_score_completions(generation_batch)
generation_batch = split_pixel_values_by_grid(generation_batch)
generation_batch = shuffle_sequence_dict(generation_batch)
generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation)
self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches]
inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
self._step += 1
else:
generation_batch = drop_assistant_content(generation_batch)
# In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence
# local generation batch == local eval batch
inputs = self._generate_and_score_completions(generation_batch)  # no thinking guidance during eval.
return inputs
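A standalone sketch (not part of the diff) of how the annealing factor from `AnnealingScheduler` progressively truncates a reasoning trace, mirroring what `think_guidance_anneal` applies to the last assistant message. The tokenizer checkpoint is chosen purely for illustration.

```python
from transformers import AutoTokenizer

from trl.experimental.rtpo.rtpo_trainer import AnnealingScheduler

# Linear "down" schedule: factor goes from 1.0 at step 0 to 0.0 at the final step.
scheduler = AnnealingScheduler(total_steps=100, schedule_type="linear", direction="down")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
trace = "First compute 2 + 3 = 5, then double the result to obtain 10."
tokens = tokenizer.encode(trace)

for step in (0, 25, 50, 75, 100):
    factor = scheduler(step)
    # Keep only the leading fraction of the trace, as done inside think_guidance_anneal.
    kept = tokenizer.decode(tokens[: int(factor * len(tokens))], skip_special_tokens=True)
    print(f"step={step:3d}  factor={factor:.2f}  guidance={kept!r}")
```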