diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index cfc24e3053..0ae23cc804 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -125,6 +125,8 @@
     title: PPO
   - local: prm_trainer
     title: PRM
+  - local: rtpo_trainer
+    title: RTPO
   - local: winrate_callback
     title: WinRateCallback
   - local: xpo_trainer
diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md
index 5d8c38472d..4fd6954a85 100644
--- a/docs/source/paper_index.md
+++ b/docs/source/paper_index.md
@@ -352,7 +352,7 @@ See [Experimental - GFPO](experimental#gfpo).
 
 **📜 Paper**: https://huggingface.co/papers/2507.06448
 
-A novel policy gradient algorithm that encourages VLMs to learn to perceive while learning to reason. This is a TRL adaptation. The TRL implementation is not the official one provided by the authors.
+A novel policy gradient algorithm that encourages VLMs to learn to perceive while learning to reason. This is a TRL adaptation of PAPO. Note that this is not the official implementation; the official code can be found at [MikeWangWZHL/PAPO](https://github.com/MikeWangWZHL/PAPO).
 
 ```python
@@ -412,6 +412,19 @@ training_args = GRPOConfig(
     ...
 )
 ```
 
+### Reverse Thinking Policy Optimization
+
+**📰 Blog**: https://github.com/SolarWindRider/avr/blob/main/README.md
+
+Current GRPO-based RL methods require the model to generate a full chain-of-thought on its own before producing the final answer.
+However, many training datasets already contain **complete, high-quality reasoning traces** that these methods leave unused.
+
+RTPO is designed to:
+
+* Use existing reasoning traces as **auxiliary CoT** to support early-stage rollouts.
+* Force the model to gradually reconstruct its own reasoning by **shortening the auxiliary CoT** step by step.
+* Enable a *reverse learning schedule*: the model first learns to output correct answers, then progressively learns how to reason.
+
 ### DeepSeek-V3.2: Pushing the Frontier of Open Large Language Models
diff --git a/trl/experimental/rtpo/__init__.py b/trl/experimental/rtpo/__init__.py
new file mode 100644
index 0000000000..e53399125f
--- /dev/null
+++ b/trl/experimental/rtpo/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .rtpo_config import RTPOConfig
+from .rtpo_trainer import RTPOTrainer
diff --git a/trl/experimental/rtpo/rtpo_config.py b/trl/experimental/rtpo/rtpo_config.py
new file mode 100644
index 0000000000..c8c3e272c4
--- /dev/null
+++ b/trl/experimental/rtpo/rtpo_config.py
@@ -0,0 +1,73 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+
+from ...trainer.grpo_config import GRPOConfig
+
+
+@dataclass
+class RTPOConfig(GRPOConfig):
+    """
+    Configuration class for [`RTPOTrainer`].
+
+    RTPO (Reverse Thinking Policy Optimization) extends GRPO by seeding rollouts with an existing reasoning trace
+    (provided as a trailing assistant turn in the prompt) and annealing its length over training, so the model first
+    learns to output correct answers and then progressively learns to produce the reasoning on its own.
+
+    Args:
+        schedule_type (`str`, *optional*, defaults to `"linear"`):
+            Schedule type used by the `AnnealingScheduler` to control the thinking-guidance length. Supported values:
+            `"linear"`, `"cosine"`, `"exponential"`, `"piecewise"`, `"constant"`.
+        direction (`str`, *optional*, defaults to `"down"`):
+            Direction of the annealing schedule. Supported values:
+
+            - `"down"`: schedule value goes from 1.0 → 0.0
+            - `"up"`: schedule value goes from 0.0 → 1.0
+        decay_rate (`float`, *optional*, defaults to `5.0`):
+            Decay rate used when `schedule_type` is `"exponential"`. Higher values result in faster decay.
+        milestones (`list[float]`, *optional*, defaults to `[0.3, 0.6, 0.9]`):
+            Milestones (progress points between 0 and 1) for the piecewise linear schedule. Only used when
+            `schedule_type` is `"piecewise"`. Must be in ascending order and within the [0, 1] range.
+        values (`list[float]`, *optional*, defaults to `[0.2, 0.5, 0.8, 1.0]`):
+            Schedule values corresponding to the milestones and boundaries. Only used when `schedule_type` is
+            `"piecewise"`. Length must be `len(milestones) + 1`. For `direction="down"`, values typically decrease;
+            for `direction="up"`, values typically increase.
+        value (`float`, *optional*, defaults to `1.0`):
+            Constant value for the constant schedule. Only used when `schedule_type` is `"constant"`.
+    """
+
+    schedule_type: str = field(default="linear")
+    direction: str = field(default="down")
+    decay_rate: float | None = field(default=None)
+    milestones: list[float] | None = field(default=None)
+    values: list[float] | None = field(default=None)
+    value: float | None = field(default=None)
+
+    def __post_init__(self):
+        # Set schedule-specific defaults based on `schedule_type`
+        if self.schedule_type == "exponential" and self.decay_rate is None:
+            self.decay_rate = 5.0
+        elif self.schedule_type == "piecewise":
+            if self.milestones is None:
+                self.milestones = [0.3, 0.6, 0.9]
+            if self.values is None:
+                self.values = [0.2, 0.5, 0.8, 1.0]
+        elif self.schedule_type == "constant" and self.value is None:
+            self.value = 1.0
+        super().__post_init__()
diff --git a/trl/experimental/rtpo/rtpo_trainer.py b/trl/experimental/rtpo/rtpo_trainer.py
new file mode 100644
index 0000000000..52d13c6dc7
--- /dev/null
+++ b/trl/experimental/rtpo/rtpo_trainer.py
@@ -0,0 +1,249 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# These helpers should be added to trl.data_utils when the RTPO method is stable enough to join trl.trainer
+import math
+from typing import Any
+
+import torch
+from datasets import Dataset, IterableDataset
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin
+
+from ...data_utils import is_conversational
+from ...extras.profiling import profiling_decorator
+from ...trainer.grpo_trainer import GRPOTrainer, RewardFunc
+from ...trainer.utils import (
+    shuffle_sequence_dict,
+    split_pixel_values_by_grid,
+    split_tensor_dict,
+    unsplit_pixel_values_by_grid,
+)
+from .rtpo_config import RTPOConfig
+
+
+class AnnealingScheduler:
+    """
+    General annealing scheduler that can be used like a learning rate scheduler. Supports `"linear"`, `"cosine"`,
+    `"exponential"`, `"piecewise"`, and `"constant"` schedules, in either direction: `"up"` (0 → 1) or
+    `"down"` (1 → 0).
+    """
+
+    def __init__(
+        self,
+        total_steps: int | float,
+        schedule_type: str = "linear",
+        direction: str = "down",  # "down" means 1 → 0 annealing
+        **kwargs: Any,
+    ) -> None:
+        self.total_steps = total_steps
+        self.schedule_type = schedule_type
+        self.direction = direction
+        self.kwargs = kwargs  # extra params like decay_rate, milestones, etc.
+
+    def __call__(self, step: int | float) -> float:
+        """Return the annealing factor for the current step."""
+        ratio = min(step / max(self.total_steps, 1), 1.0)
+
+        if self.schedule_type == "linear":
+            value = ratio
+
+        elif self.schedule_type == "cosine":
+            # Cosine annealing: starts fast, slows down later
+            value = 0.5 * (1 - math.cos(math.pi * ratio))
+
+        elif self.schedule_type == "exponential":
+            # Exponential annealing: f(t) = 1 - exp(-k * t)
+            k = self.kwargs.get("decay_rate", 5.0)
+            value = 1 - math.exp(-k * ratio)
+
+        elif self.schedule_type == "piecewise":
+            milestones = self.kwargs.get("milestones", [0.3, 0.6, 0.9])
+            values = self.kwargs.get("values", [0.2, 0.5, 0.8, 1.0])
+            for i, m in enumerate(milestones):
+                if ratio < m:
+                    value = values[i]
+                    break
+            else:
+                value = values[-1]
+
+        elif self.schedule_type == "constant":
+            value = self.kwargs.get("value", 1.0)
+
+        else:
+            raise ValueError(f"Unknown schedule_type: {self.schedule_type}")
+
+        # Apply direction: up (0 → 1) or down (1 → 0)
+        if self.direction == "down":
+            return 1.0 - value
+        elif self.direction == "up":
+            return value
+        else:
+            raise ValueError(f"Invalid direction: {self.direction}")
+
+
+def think_guidance_anneal(
+    generation_batch: list[dict[str, Any]],
+    anneal_factor: float,
+    tokenizer_or_processor: PreTrainedTokenizerBase | ProcessorMixin,
+) -> list[dict[str, Any]]:
+    # Truncate the auxiliary reasoning trace (the trailing assistant turn) to the first `anneal_factor` fraction of
+    # its tokens, so the model has to complete more of the chain-of-thought itself as training progresses.
+    tokenizer = getattr(tokenizer_or_processor, "tokenizer", tokenizer_or_processor)
+    if is_conversational(generation_batch[0]):
+        for i in range(len(generation_batch)):
+            if generation_batch[i]["prompt"][-1]["role"] == "assistant":
+                think = generation_batch[i]["prompt"][-1]["content"]
+                tokens = tokenizer.encode(think)
+                anneal = tokenizer.decode(tokens[: int(anneal_factor * len(tokens))], skip_special_tokens=True)
+                generation_batch[i]["prompt"][-1]["content"] = anneal
+    return generation_batch
+
+
+def drop_assistant_content(generation_batch: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    # Remove the trailing assistant turn entirely (used at evaluation time, where no thinking guidance is given)
+    if is_conversational(generation_batch[0]):
+        for i in range(len(generation_batch)):
+            if generation_batch[i]["prompt"][-1]["role"] == "assistant":
+                generation_batch[i]["prompt"].pop()
+    return generation_batch
+
+
+class RTPOTrainer(GRPOTrainer):
+    """
+    Trainer for Reverse Thinking Policy Optimization (RTPO). Example:
+
+    ```python
+    from datasets import load_dataset
+
+    from trl.experimental.rtpo import RTPOConfig, RTPOTrainer
+
+    dataset = load_dataset("your-dataset-with-reasoning-traces", split="train")
+
+
+    def reward_func(completions, **kwargs):
+        # Your reward function
+        return [compute_reward(c) for c in completions]
+
+
+    config = RTPOConfig(
+        schedule_type="linear",  # how fast the auxiliary CoT is shortened
+        direction="down",  # 1.0 → 0.0: start with the full trace, end with none
+    )
+
+    trainer = RTPOTrainer(
+        model="Qwen/Qwen2-VL-2B-Instruct",
+        reward_funcs=reward_func,
+        args=config,
+        train_dataset=dataset,
+    )
+
+    trainer.train()
+    ```
+
+    Args:
+        model (`Union[str, PreTrainedModel]`):
+            Model to be trained.
+        reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
+            Reward functions for computing rewards (same as GRPO).
+        args ([`RTPOConfig`], *optional*, defaults to `None`):
+            Configuration for this trainer. If `None`, a default configuration is used.
+        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+            Dataset to use for training. Prompts should be conversational; when a prompt ends with an assistant turn,
+            its content is used as the auxiliary reasoning trace.
+        eval_dataset: Same requirements as `train_dataset`.
+        processing_class: Processing class (tokenizer/processor) for the model.
+        reward_processing_classes: Processing classes for reward models.
+        callbacks: Training callbacks.
+        optimizers: Optimizer and scheduler tuple.
+        peft_config: PEFT configuration if using parameter-efficient fine-tuning.
+    """
+
+    _tag_names = ["trl", "rtpo"]
+    _name = "RTPO"
+
+    def __init__(
+        self,
+        model: str | PreTrainedModel,
+        reward_funcs: RewardFunc | list[RewardFunc],
+        args: RTPOConfig | None = None,
+        train_dataset: Dataset | IterableDataset | None = None,
+        eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
+        processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
+        reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
+        callbacks=None,
+        optimizers=(None, None),
+        peft_config=None,
+    ):
+        # Initialize with a default RTPO config if none is provided
+        if args is None:
+            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model_name.split("/")[-1]
+            args = RTPOConfig(f"{model_name}-RTPO")
+
+        # Initialize the parent GRPO trainer first, so that the training arguments and dataloader are available
+        super().__init__(
+            model=model,
+            reward_funcs=reward_funcs,
+            args=args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            reward_processing_classes=reward_processing_classes,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            peft_config=peft_config,
+        )
+
+        # Total number of optimizer steps used to drive the annealing schedule
+        if args.max_steps > 0:
+            total_steps = args.max_steps
+        else:
+            total_steps = len(self.get_train_dataloader()) * args.num_train_epochs
+        self.anneal_scheduler = AnnealingScheduler(
+            total_steps=total_steps,
+            schedule_type=args.schedule_type,  # linear / cosine / exponential / piecewise / constant
+            direction=args.direction,  # up: 0 → 1, down: 1 → 0
+            decay_rate=args.decay_rate,  # used by the exponential schedule
+            milestones=args.milestones,  # used by the piecewise schedule
+            values=args.values,  # used by the piecewise schedule
+            value=args.value,  # used by the constant schedule
+        )
+
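+    # Illustrative worked example of the resulting guidance schedule (a comment only, not additional logic): with
+    # the default `schedule_type="linear"` and `direction="down"`, `self.anneal_scheduler(step)` decays from 1.0 to
+    # 0.0 over `total_steps`. At 25% of training it returns 0.75, so `think_guidance_anneal` keeps only the first
+    # 75% of the reference reasoning tokens in the trailing assistant turn; by the end of training the auxiliary CoT
+    # is fully removed and the model must produce the entire chain-of-thought on its own.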
+    @profiling_decorator
+    def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
+        # Prepares inputs for model training/evaluation by managing completion generation and batch handling.
+        # During training:
+        #   - Receives the local generation batch (per-GPU batch size × steps per generation) from the modified
+        #     training dataloader instead of the standard local batch
+        #   - Generates completions once for the entire generation batch and splits it into batches of size
+        #     `per_device_train_batch_size`
+        #   - Buffers these completions and returns the appropriate slice for the current accumulation step
+        #   - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations)
+        # During evaluation:
+        #   - The input is treated as a standard local batch (no accumulation, no multiple iterations)
+        #   - Completions are generated for each batch without buffering or reuse
+        # Returns a single local batch in both cases.
+
+        mode = "train" if self.model.training else "eval"
+        if mode == "train":
+            # Shorten the auxiliary CoT (trailing assistant turn) according to the current annealing factor
+            generation_batch = think_guidance_anneal(
+                generation_batch, self.anneal_scheduler(self.state.global_step), self.processing_class
+            )
+            generate_every = self.args.steps_per_generation * self.num_iterations
+            if self._step % generate_every == 0 or self._buffered_inputs is None:
+                # self._buffered_inputs=None can occur when resuming from a checkpoint
+                generation_batch = self._generate_and_score_completions(generation_batch)
+                generation_batch = split_pixel_values_by_grid(generation_batch)
+                generation_batch = shuffle_sequence_dict(generation_batch)
+                generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation)
+                self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches]
+            inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
+            self._step += 1
+        else:
+            # No thinking guidance during evaluation: drop the trailing assistant turn entirely
+            generation_batch = drop_assistant_content(generation_batch)
+            # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence
+            # local generation batch == local eval batch
+            inputs = self._generate_and_score_completions(generation_batch)
+        return inputs
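Below is a minimal usage sketch for the new trainer, included here for reference only (it is not part of the patch). It assumes a conversational dataset whose `prompt` ends with an assistant turn carrying the reference reasoning trace; the dataset rows, model name, reward logic, and hyperparameters are illustrative placeholders.

```python
# Usage sketch for the RTPO trainer added in this PR (illustrative only).
from datasets import Dataset

from trl.experimental.rtpo import RTPOConfig, RTPOTrainer

# Each prompt ends with an assistant turn holding the reference reasoning trace ("auxiliary CoT").
# RTPOTrainer truncates this turn more and more aggressively as training progresses.
train_dataset = Dataset.from_list(
    [
        {
            "prompt": [
                {"role": "user", "content": "What is 17 * 3? Think step by step, then answer."},
                {"role": "assistant", "content": "17 * 3 = 17 + 17 + 17 = 51. The answer is 51."},
            ],
            "answer": "51",
        },
        # ... more rows with the same structure
    ]
)


def accuracy_reward(completions, answer, **kwargs):
    # Toy reward: 1.0 if the reference answer appears in the completion, else 0.0.
    rewards = []
    for completion, ans in zip(completions, answer):
        text = completion[0]["content"] if isinstance(completion, list) else completion
        rewards.append(1.0 if ans in text else 0.0)
    return rewards


config = RTPOConfig(
    output_dir="rtpo-sketch",
    schedule_type="linear",  # auxiliary CoT keeps 100% of the trace at step 0 ...
    direction="down",  # ... shrinking to 0% by the end of training
    max_steps=100,
)

trainer = RTPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    args=config,
    train_dataset=train_dataset,
)
trainer.train()
```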