Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
7baf106
Minimalistic GRPO optimizer step implementation. (to be numerically v…
jacobthebanana Sep 8, 2025
acbba68
GRPO for Openai-Agents implementation. (Missing vec-inf component.)
jacobthebanana Sep 15, 2025
658c4c2
Added LangFuse integration to GRPO Agent RL.
jacobthebanana Sep 15, 2025
c1a088f
Added AGENTS.md for programming agents.
jacobthebanana Sep 15, 2025
195bcdc
Added submitit vllm
jacobthebanana Oct 27, 2025
c53f1bc
Added submitit vllm
jacobthebanana Oct 27, 2025
d2cb117
Updated trainer setup.
jacobthebanana Oct 28, 2025
2fdda56
Made PerTokenProbs torch-only and cleaned up
jacobthebanana Oct 28, 2025
b7f14b4
Connected GRPO pipeline and eliminated log_prob host transfer
jacobthebanana Oct 29, 2025
cc64f35
Made served_model_name follow model name for observability.
jacobthebanana Oct 29, 2025
2e2d777
Cleaned up test entrypoints and usernames.
jacobthebanana Oct 29, 2025
47ff5bb
Cleaned up and added documentations.
jacobthebanana Oct 29, 2025
32a1430
Cleaned up documentations.
jacobthebanana Oct 29, 2025
4500427
Deleted .vscode and AGENTS.md
jacobthebanana Oct 29, 2025
3500556
Cleaned up run_in_venv.sh
jacobthebanana Oct 29, 2025
3230e11
Cleaned up TODO and duplicate gather_with_progress
jacobthebanana Oct 29, 2025
25d1bd6
Cleaned up references to vec-inf.
jacobthebanana Oct 29, 2025
1121b42
Cleaned up GRPOMetrics.
jacobthebanana Oct 29, 2025
79bcede
Deleted unused `rate_limited` in async_utils and `indexed` in submiti…
jacobthebanana Oct 29, 2025
d3b38bb
Merge branch 'main' into jjt-rlvr-grpo
jacobthebanana Nov 23, 2025
f9f0d1f
Moved starter -> templates
jacobthebanana Nov 23, 2025
2a644bd
Refactored GRPO config- replaced argparse with hydra/omegaconf
jacobthebanana Nov 24, 2025
24bcdc6
Implemented support for local executors.
jacobthebanana Nov 26, 2025
034ba04
Moved backprop step into local executor to free up GPU memory.
jacobthebanana Nov 26, 2025
66139fa
Applied linting fixes.
jacobthebanana Nov 26, 2025
d02bd4a
Added instructions for RLVR
jacobthebanana Nov 26, 2025
ba0ea9b
Implemented support for RLVR GRPO tracking via LangFuse dataset.
jacobthebanana Nov 27, 2025
9a862b6
- Fixed race conditions in vLLM worker and submitit batch
jacobthebanana Dec 2, 2025
a0d2141
Refactored to use relative imports and move into rl/rlvr
jacobthebanana Dec 9, 2025
6ce0130
Config and checkpoint clean-up fixes
jacobthebanana Dec 10, 2025
50fc334
Merge remote-tracking branch 'origin/main' into jjt-rlvr-grpo
jacobthebanana Dec 11, 2025
2066506
Renamed vaughan -> bon_echo
jacobthebanana Dec 11, 2025
5d053df
Fixed OmegaConf integration for GRPO
jacobthebanana Dec 13, 2025
fdf2d9a
Added comments regarding moving agents into GRPOTrainer class.
jacobthebanana Dec 13, 2025
fc71b1e
Fixed tokenizer handling in rollout
jacobthebanana Dec 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,5 @@ cache/
.Trashes
ehthumbs.db
Thumbs.db

/_submitit_logs/
31 changes: 23 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,32 @@ description = "Starter templates that launch on Slurm via Hydra + Submitit"
authors = [{name = "Farnaz Kohankhaki", email = "[email protected]"}]
requires-python = "==3.12.*"
dependencies = [
"hydra-core>=1.3.2,<1.4",
"hydra-submitit-launcher>=1.2.0,<1.5",
"submitit>=1.5.0,<2.0",
"torch>=2.0.0,<2.6.0",
"transformers>=4.36.0,<4.52.0",
"datasets>=2.16.0,<3.6.0",
"accelerate>=0.26.0,<1.0.0",
"hydra-core>=1.3.2",
'hydra-submitit-launcher @ git+https://github.com/facebookresearch/hydra.git@8018b394532bae4cc78638159a848414072f4c6d#subdirectory=plugins/hydra_submitit_launcher',
"submitit>=1.5.0",
"torch>=2.0.0",
"transformers>=4.36.0",
"datasets>=2.16.0",
"accelerate>=0.26.0",
"pyarrow==16.1.0",
"pillow>=10.0.0,<11.0.0",
"pillow>=10.0.0",
"ruff>=0.1.0,<1.0.0",
"pre-commit>=3.0.0,<4.0.0",
"pydantic>=2.11.7",
"rich>=14.1.0",
"openai-agents>=0.2.11",
"basedpyright>=1.31.4",
"langfuse>=3.3.4",
"nest-asyncio>=1.6.0",
"pydantic-ai[logfire]>=1.0.6",
"vllm>=0.11.0",
]

[dependency-groups]
dev = [
"basedpyright>=1.31.4",
"pytest-asyncio>=1.3.0",
"pytest>=9.0.1",
]

[tool.ruff]
Expand Down
28 changes: 28 additions & 0 deletions templates/configs/compute/bon_echo/run_in_container.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/bin/bash

# Run an arbitrary command inside the Singularity container on Bon Echo
# (typically used to launch vLLM).
# Example:
# bash run_in_container.sh uv run vllm serve /model-weights/Qwen3-8B
source ~/.bashrc
source /opt/lmod/lmod/init/bash
export MODULEPATH=/opt/modulefiles:/pkgs/modulefiles:/pkgs/environment-modules

module load singularity-ce
# SINGULARITYENV_* variables are exported so the Slurm config, binaries and
# libraries on the host are visible from inside the container.
export SINGULARITYENV_SLURM_CONF=/opt/slurm/etc/slurm.conf
export SINGULARITYENV_PATH="/opt/slurm/bin:$PATH"
export SINGULARITYENV_LD_LIBRARY_PATH="/opt/slurm/lib:/opt/slurm/lib64:/opt/munge/lib:/opt/munge/lib64:${LD_LIBRARY_PATH:-}"

# Drop any inherited virtualenv, then activate the uv project environment
# if one is configured.
unset VIRTUAL_ENV
unset VIRTUAL_ENV_PROMPT
if [[ -n "$UV_PROJECT_ENVIRONMENT" ]]; then
    # Quoted so an environment path containing spaces still resolves.
    source "${UV_PROJECT_ENVIRONMENT}/bin/activate"
fi

# "$@" (quoted) forwards each argument to the container exactly as received;
# the unquoted form would re-split arguments containing spaces or globs.
singularity exec \
    --nv \
    --bind /model-weights:/model-weights \
    --bind /projects/llm:/projects/llm \
    --bind "$HOME:$HOME" \
    --bind "$SCRATCH:$SCRATCH" \
    /projects/llm/unsloth-vllm-trl-latest.sif \
    "$@"
16 changes: 16 additions & 0 deletions templates/configs/compute/bon_echo/run_in_venv.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash

# Run an arbitrary command inside the uv virtual environment on Bon Echo
# (typically used to launch vLLM), without a container.
# Example:
# bash run_in_venv.sh uv run vllm serve /model-weights/Qwen3-8B
source ~/.bashrc

# Drop any inherited virtualenv, then activate the uv project environment
# if one is configured.
unset VIRTUAL_ENV
unset VIRTUAL_ENV_PROMPT
if [[ -n "$UV_PROJECT_ENVIRONMENT" ]]; then
    # Quoted so an environment path containing spaces still resolves.
    source "${UV_PROJECT_ENVIRONMENT}/bin/activate"
fi

# Print the visible GPUs to the job log before starting the command.
nvidia-smi

# "$@" (quoted) runs the command with each argument exactly as received;
# the unquoted form would re-split arguments containing spaces or globs.
"$@"
6 changes: 6 additions & 0 deletions templates/configs/user.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# user:
# slurm:
# account: vector
# additional_parameters:
# qos: m2

user:
slurm:
account: vector
4 changes: 4 additions & 0 deletions templates/src/mlp/single/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def main(cfg: DictConfig):
OmegaConf.resolve(hydra_config)
OmegaConf.save(hydra_config, save_path)

import json

print(json.dumps(hydra_config.__dict__, indent=2, default=str))

# Run the trainer with the run config
checkpointable_trainer = CheckpointableMLPTrainer()
return checkpointable_trainer(cfg)
Expand Down
58 changes: 58 additions & 0 deletions templates/src/rl/rlvr/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# RL with Verifiable Reward (RLVR) Reference Implementations

This folder contains scripts for running RLVR algorithms on LLMs on the Vector cluster.

Supported algorithms:

- GRPO

Features:

- Compatibility with Chat Completion models.
- LLM-as-a-judge for more involved reward verifications.
- Optimized for heterogeneous compute environments: run backpropagation on H100/A100, and use L40S/A40/RTX8000 for rollout and LLM judge via dedicated SLURM jobs (see [submitit_vllm.py](submitit_vllm.py) for details).

Current limitations and TODO items:

- Single-GPU finetuning only.
- Backprop GPU does not participate in rollouts.
- Rollout GPUs might sit idle when all rollouts are done and only eval is pending.
- Verify support for function calling via Chat Completion.
- Integrate LangFuse datasets to track eval traces across steps.

## Vaughan (Bon Echo) User? Extra setup steps required

(Skip this section if you are running on Killarney)

This reference implementation depends on `vllm>=0.11.0`. vLLM does not provide pre-built wheels for older Linux distros, such as the one running on Vaughan (Bon Echo). Follow the steps in [vaughan_setup.md](vaughan_setup.md) before continuing.

## Run Trainer

```bash
uv run python \
-m rl.rlvr.grpo.launch \
--multirun \
compute=killarney/l40s_1x \
requeue=off \
trainer.num_epochs=10 \
trainer.data.train_split="train\[:100\]" \
trainer.run_name="grpo_gsm8k_dry_run"
```

## Adapting to your workflow

Configurable options:

- GRPO hyperparameters
- dataset (the example uses `openai/gsm8k`)
- evaluation scheme and LLM judge setup

## Optional: Observability Integration via LangFuse

Set up LangFuse to track the output of your models as training proceeds.

```bash
export LANGFUSE_SECRET_KEY="sk-lf-..."
export LANGFUSE_PUBLIC_KEY="pk-lf-..."
export LANGFUSE_HOST="https://us.cloud.langfuse.com"
```
178 changes: 178 additions & 0 deletions templates/src/rl/rlvr/agents_integration/examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""
OpenAI Agents SDK demo: function tool `get_weather` + minimal agent.

Setup (with astral-uv):
uv venv && uv pip install -U pip
uv add openai-agents pydantic
# Or, if you prefer pip: `pip install openai-agents pydantic`

Environment:
export OPENAI_API_KEY=... # Required for the default OpenAI client

Run:
    uv run python examples.py
"""

from __future__ import annotations

from datetime import date, datetime, timedelta, timezone
from typing import Literal

from agents import Agent, Runner, function_tool
from pydantic import BaseModel, Field


class WeatherReport(BaseModel):
    """Structured output for a weather report.

    Schema of one mock weather reading. `get_weather` builds an instance of
    this model and returns it serialized as a JSON string.

    Attributes
    ----------
    city:
        Echo of the requested city (surrounding whitespace stripped).
    unit:
        "c" for Celsius or "f" for Fahrenheit.
    temperature:
        Air temperature in the requested unit.
    feels_like:
        Apparent temperature in the requested unit.
    condition:
        One of: "clear", "partly cloudy", "cloudy", "rain", "snow", "windy".
    humidity:
        Relative humidity percentage (0–100).
    wind_kph:
        Wind speed in kilometers per hour.
    observation_time:
        UTC timestamp when the reading was generated.
    """

    city: str
    unit: Literal["c", "f"]
    temperature: float
    feels_like: float
    condition: Literal["clear", "partly cloudy", "cloudy", "rain", "snow", "windy"]
    # Validation rejects humidity values outside the inclusive range 0..100.
    humidity: int = Field(ge=0, le=100)
    # Validation rejects negative wind speeds.
    wind_kph: float = Field(ge=0)
    observation_time: datetime


@function_tool
def get_weather(
    city: str,
    unit: Literal["c", "f"] = "c",
    when: Literal["now", "today", "tomorrow"] = "now",
) -> str:
    """Return a deterministic, mock weather report for demos.

    The function is *offline* and *stable across runs* for a given `(city, date)`
    so it's ideal for showcasing **function-tool** calls without network flakiness.

    Parameters
    ----------
    city : str
        Human-readable city name (e.g., "Vancouver").
    unit : str
        Temperature unit: "c" for Celsius, "f" for Fahrenheit. Defaults to "c".
    when : str
        Time window for the report: "now", "today", or "tomorrow".
        Defaults to "now".

    Returns
    -------
    str
        JSON string representing a `WeatherReport`.
    """
    # Function-scope import: only this function needs it.
    import hashlib

    canonical = city.strip()

    # City baselines (°C). Extend this mapping to taste.
    baselines: dict[str, float] = {
        "vancouver": 14.0,
        "new york": 12.0,
        "london": 11.0,
        "singapore": 28.0,
        "shanghai": 28.0,
        "auckland": 20.0,
        "tokyo": 17.0,
        "paris": 13.0,
        "san francisco": 16.0,
        "berlin": 12.0,
        "mexico city": 19.0,
    }

    key = canonical.lower()
    # Cities missing from the table fall back to a 15 °C baseline.
    base_c = baselines.get(key, 15.0)

    # Reference date for deterministic seeding ("tomorrow" shifts by one day).
    today = date.today()
    ref_date = today if when in ("now", "today") else today + timedelta(days=1)

    # Seed derived from (city, date). The built-in hash() of a str is salted
    # per interpreter process, so it would yield a different report on every
    # run; a SHA-256 digest gives the same seed in every process.
    digest = hashlib.sha256(f"{key}|{ref_date.isoformat()}".encode("utf-8")).digest()
    seed = int.from_bytes(digest[:8], "big")

    def prand(a: float, b: float, salt: int) -> float:
        # Deterministic pseudo-random in [a, b]: one of ten evenly spaced
        # values, selected by the salted seed modulo 10.
        return a + (seed ^ salt) % 10 / 9.0 * (b - a)

    # NOTE(review): the constant -0.5 offset is kept from the original; its
    # intent is not evident from this file.
    temp_c = base_c + prand(-4.0, 4.0, 0xA5A5) - 0.5
    humidity = int(round(prand(40, 90, 0xB6B6)))
    wind_kph = round(prand(0.0, 30.0, 0xC7C7), 1)

    # Map the seed onto 100 buckets and pick the condition by bucket range.
    band = seed % 100
    if band < 20:
        condition = "clear"
    elif band < 45:
        condition = "partly cloudy"
    elif band < 65:
        condition = "cloudy"
    elif band < 85:
        condition = "rain"
    elif band < 95:
        condition = "windy"
    else:
        condition = "snow"

    # Simple apparent-temperature adjustment: wind lowers it, humidity above
    # 50% raises it.
    feels_c = temp_c - 0.1 * wind_kph + 0.02 * (humidity - 50)

    def to_unit(tc: float, u: Literal["c", "f"]) -> float:
        # Convert from Celsius when Fahrenheit is requested; round to 0.1.
        return round(tc if u == "c" else (tc * 9 / 5 + 32), 1)

    report = WeatherReport(
        city=canonical,
        unit=unit,
        temperature=to_unit(temp_c, unit),
        feels_like=to_unit(feels_c, unit),
        condition=condition,  # type: ignore[arg-type]
        humidity=humidity,
        wind_kph=wind_kph,
        observation_time=datetime.now(timezone.utc),
    )

    # Agents SDK tools should return a string (or something that stringifies cleanly).
    return report.model_dump_json()


# --- Minimal agent wiring ----------------------------------------------------
# Module-level agent with `get_weather` registered as its only tool; the
# instructions tell the model to call the tool and summarise its JSON reply.
weather_agent = Agent(
    name="Weather Helper",
    tools=[get_weather],  # register the function tool
    instructions=(
        "You answer weather questions. "
        "When the user asks about weather, call the `get_weather` tool. "
        "If it returns JSON, parse it and reply concisely with temperature, "
        "feels-like, condition, and units."
    ),
)


def main() -> None:
    """Send each demo prompt to the weather agent and print its final output."""
    # Prompts phrased so that the model is strongly encouraged to use the tool.
    demo_prompts: list[str] = [
        "What's the weather in Vancouver today in celsius?",
        "NYC now, in Fahrenheit — include feels-like and wind, please.",
    ]

    for turn, question in enumerate(demo_prompts, start=1):
        print(f"\n=== Demo turn {turn} ===")
        # Each prompt is an independent, synchronous agent run.
        outcome = Runner.run_sync(weather_agent, question)
        print(outcome.final_output)


if __name__ == "__main__":
    main()
32 changes: 32 additions & 0 deletions templates/src/rl/rlvr/agents_integration/logging_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Utils for silencing OpenAI Python SDK error 401."""

import logging


class IgnoreOpenAI401Filter(logging.Filter):
    """
    A logging filter that excludes specific OpenAI client error messages.

    Filters out: 'ERROR:openai.agents:[non-fatal] Tracing client error 401'

    Every other record passes through unchanged.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        """Return False only for the OpenAI tracing 401 error record.

        Parameters
        ----------
        record : logging.LogRecord
            The record being considered for emission.

        Returns
        -------
        bool
            False when the record is an ERROR from the "openai.agents" logger
            whose formatted message contains the tracing-401 text; True
            otherwise.
        """
        # Check the cheap attributes first. Formatting the message of every
        # record up front would let a malformed `msg % args` from any
        # unrelated logger raise out of this filter and into the caller.
        if record.levelname != "ERROR" or record.name != "openai.agents":
            return True
        return "[non-fatal] Tracing client error 401" not in record.getMessage()


def set_up_logging() -> None:
    """Attach `IgnoreOpenAI401Filter` to every handler of the root logger.

    If the root logger has no handlers yet, `logging.basicConfig()` is called
    first so that there is at least one handler to filter. Only handlers
    present at call time are affected. Calling this function repeatedly is
    safe: a handler that already carries the filter is left alone.
    """
    root_logger = logging.getLogger()
    filter_ = IgnoreOpenAI401Filter()

    if not root_logger.handlers:
        logging.basicConfig()

    for handler in root_logger.handlers:
        # A fresh filter instance never compares equal to an earlier one, so
        # check by type to avoid stacking duplicates on repeated calls.
        if any(isinstance(existing, IgnoreOpenAI401Filter) for existing in handler.filters):
            continue
        handler.addFilter(filter_)
Loading