diff --git a/examples/vl_dpo_demo_gemma3.ipynb b/examples/vl_dpo_demo_gemma3.ipynb new file mode 100644 index 000000000..872b90b80 --- /dev/null +++ b/examples/vl_dpo_demo_gemma3.ipynb @@ -0,0 +1,667 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "79e6e5f1" + }, + "source": [ + "# Fine-tuning a Visual Language Model (VLM) using DPO\n", + "\n", + "This notebook demonstrates how to fine-tune a Visual Language Model (VLM), specifically the Gemma 3-4B-it model, using the Direct Preference Optimization (DPO) algorithm.\n", + "\n", + "The key steps involved are:\n", + "\n", + "1. **Setup and Installations**: Install necessary libraries and dependencies.\n", + "2. **Model Loading**: Load the pre-trained Gemma 3-4B-it model.\n", + "3. **LoRA Application**: Apply Low-Rank Adaptation (LoRA) to the model for efficient fine-tuning.\n", + "4. **Data Loading and Preprocessing**: Load the RLAIF-V dataset and preprocess it for VLM training, including handling images and tokenizing text.\n", + "5. **DPO Training**: Set up and run the DPO training loop to fine-tune the model based on preference data (chosen and rejected responses).\n", + "6. **Logging and Visualization**: Log training metrics and visualize the training progress.\n", + "\n", + "The goal is to train the VLM to better align with human preferences by optimizing directly on pairs of preferred and dispreferred responses." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bjhtP5lgJOSg" + }, + "outputs": [], + "source": [ + "!pip install -q kagglehub\n", + "\n", + "!pip install -q tensorflow\n", + "!pip install -q tensorboardX\n", + "!pip install -q grain\n", + "!pip install -q git+https://github.com/google/tunix\n", + "!pip install -q git+https://github.com/google/qwix\n", + "\n", + "!pip uninstall -q -y flax\n", + "!pip install -q git+https://github.com/google/flax.git\n", + "\n", + "!pip install -q huggingface_hub\n", + "!pip install -q datasets\n", + "!pip3 install jaxtyping" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TezTyGV-Kgpi" + }, + "outputs": [], + "source": [ + "import dataclasses\n", + "import json\n", + "import os\n", + "import types\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import optax\n", + "import qwix\n", + "from datasets import load_dataset\n", + "from flax import nnx\n", + "from huggingface_hub import snapshot_download\n", + "from PIL import Image\n", + "\n", + "from tunix.generate import tokenizer_adapter as tokenizer_lib\n", + "from tunix.models.gemma3 import model as gemma3_model_lib\n", + "from tunix.models.gemma3 import params_safetensors as params_safetensors_lib\n", + "from tunix.sft.dpo.dpo_trainer import DPOTrainer, DPOTrainingConfig, TrainingInput" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "p6TZeaJMKmgt" + }, + "outputs": [], + "source": [ + "GEMMA_TOKENIZER_PATH = \"gs://gemma-data/tokenizers/tokenizer_gemma3.model\"\n", + "model_id = \"google/gemma-3-4b-it\"\n", + "IMAGE_SIZE = 896\n", + "# ====== Data ======\n", + "TRAIN_DATA_DIR = \"./data/train\"\n", + "TEST_DATA_DIR = \"./data/test\"\n", + "TRAIN_FRACTION = 1.0\n", + "\n", + "INTERMEDIATE_CKPT_DIR = \"/content/intermediate_ckpt/\"\n", + "# ====== LoRA ======\n", + "RANK = 32\n", + "ALPHA = 16.0\n", + "\n", + "# ====== Sharding ======\n", + "MESH = [(1, 1), (\"fsdp\", \"tp\")]\n", + "\n", + "MAX_PROMPT_LENGTH = 512\n", + 
"TOTAL_GENERATION_STEPS = 256\n", + "TEMPERATURE = 0.7\n", + "TOP_P = 1.0\n", + "TOP_K = 50\n", + "BETA = 0.1\n", + "\n", + "# === AdamW, warmup, cosine scheduler ===\n", + "LEARNING_RATE = 5e-6\n", + "B1 = 0.9\n", + "B2 = 0.99\n", + "WEIGHT_DECAY = 0.1\n", + "\n", + "# == Cosine decay with warmup scheduler ==\n", + "# Linearly increase learning rate from 0. to 5e-6 in the first 10% training\n", + "# steps, and then gradually decrease the learning rate to 0 using cosine\n", + "# scheduler.\n", + "EVAL_EVERY_N_STEPS = 50\n", + "MAX_STEPS = 5000\n", + "BATCH_SIZE = 1\n", + "EPOCHS = 10\n", + "\n", + "WARMUP_STEPS = 0.1 * MAX_STEPS\n", + "# == Grad clipping ==\n", + "# Grad clipping to prevent large gradients. Found this\n", + "# important to keep KL divergence in check.\n", + "MAX_GRAD_NORM = 0.1\n", + "\n", + "# ====== Inference ======\n", + "GENERATION_CONFIGS = {\n", + " # greedy search\n", + " \"greedy\": {\"temperature\": 1e-4, \"top_k\": 1, \"top_p\": 1.0},\n", + " # some randomness\n", + " \"standard\": {\"temperature\": 0.7, \"top_k\": 50, \"top_p\": 0.95},\n", + " # liberal\n", + " \"liberal\": {\"temperature\": 0.85, \"top_k\": 2000, \"top_p\": 1.0},\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lUbAc5WPKwYh" + }, + "outputs": [], + "source": [ + "!huggingface-cli login" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zhZCozOTKnmZ" + }, + "outputs": [], + "source": [ + "ignore_patterns = [\n", + " \"*.pth\", # Ignore PyTorch .pth weight files\n", + "]\n", + "print(f\"Downloading {model_id} from Hugging Face...\")\n", + "local_model_path = snapshot_download(\n", + " repo_id=model_id, ignore_patterns=ignore_patterns\n", + ")\n", + "print(f\"Model successfully downloaded to: {local_model_path}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "McqkpUNbKnja" + }, + "outputs": [], + "source": [ + "MODEL_CP_PATH = local_model_path\n", + "\n", + "model_config = gemma3_model_lib.ModelConfig.gemma3_4b()\n", + "model_config = dataclasses.replace(\n", + " model_config, multimodal=True, num_embed=262208\n", + ")\n", + "\n", + "MESH = [(1, 1), (\"fsdp\", \"tp\")]\n", + "mesh = jax.make_mesh(*MESH)\n", + "with mesh:\n", + " gemma3 = params_safetensors_lib.create_model_from_safe_tensors(\n", + " MODEL_CP_PATH, model_config, mesh\n", + " )\n", + " nnx.display(gemma3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wT6HC4URKnaQ" + }, + "outputs": [], + "source": [ + "gemma_tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "a-0swaMJPSPb" + }, + "outputs": [], + "source": [ + "def get_lora_model(base_model, mesh):\n", + " lora_provider = qwix.LoraProvider(\n", + " module_path=(\n", + " \".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|\"\n", + " \".*attn_vec_einsum\"\n", + " ),\n", + " rank=RANK,\n", + " alpha=ALPHA,\n", + " )\n", + "\n", + " rngs = nnx.Rngs(params=jax.random.PRNGKey(42))\n", + " model_input = dict(base_model.get_model_input())\n", + "\n", + " # HARD-CODED for Gemma3 + SigLIP in Tunix\n", + " if base_model.config.multimodal:\n", + " bsz = model_input[\"last_tokens\"].shape[0]\n", + " model_input[\"pixel_values\"] = jnp.zeros(\n", + " (bsz, 896, 896, 3), dtype=jnp.float32 # NHWC\n", + " )\n", + "\n", + " lora_model = qwix.apply_lora_to_model(\n", + " base_model, lora_provider, rngs=rngs, 
**model_input\n", + " )\n", + "\n", + " with mesh:\n", + " state = nnx.state(lora_model)\n", + " pspecs = nnx.get_partition_spec(state)\n", + " nnx.update(lora_model, jax.lax.with_sharding_constraint(state, pspecs))\n", + "\n", + " return lora_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tEH8AioSPSOR" + }, + "outputs": [], + "source": [ + "# Policy model\n", + "lora_gemma = get_lora_model(gemma3, mesh=mesh)\n", + "nnx.display(lora_gemma)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "efDhGwrtPVZi" + }, + "outputs": [], + "source": [ + "SPLIT = \"train[:5000]\"\n", + "import jax.image as jimg\n", + "ds = load_dataset(\"openbmb/RLAIF-V-Dataset\", split=SPLIT)\n", + "cols = [\"image\", \"question\", \"chosen\", \"rejected\"]\n", + "ds = ds.remove_columns([c for c in ds.column_names if c not in cols])\n", + "\n", + "\n", + "def _pick_one_image(img_field):\n", + " \"\"\"Return a single PIL.Image from the dataset's image field.\"\"\"\n", + " x = img_field\n", + " if isinstance(x, list):\n", + " if not x: # empty list, skip later\n", + " return None\n", + " x = x[0]\n", + " if isinstance(x, Image.Image):\n", + " return x.convert(\"RGB\")\n", + " arr = np.array(x)\n", + " if arr.ndim == 3:\n", + " return Image.fromarray(arr).convert(\"RGB\")\n", + " return None\n", + "\n", + "\n", + "_CLIP_MEAN = jnp.array([0.48145466, 0.4578275, 0.40821073], dtype=jnp.float32)\n", + "_CLIP_STD = jnp.array([0.26862954, 0.26130258, 0.27577711], dtype=jnp.float32)\n", + "\n", + "\n", + "def preprocess_image(\n", + " images_uint8: jnp.ndarray,\n", + " image_size: int,\n", + " *,\n", + " mean=_CLIP_MEAN,\n", + " std=_CLIP_STD,\n", + ") -> jnp.ndarray:\n", + " \"\"\"Resize + normalize images for SigLIP.\n", + "\n", + " Args:\n", + " images_uint8: [B,H,W,3] or [H,W,3], dtype uint8.\n", + " image_size: output resolution (image_size x image_size).\n", + " mean/std: per-channel normalization arrays.\n", + "\n", + " Returns:\n", + " float32 array [B, image_size, image_size, 3]\n", + " \"\"\"\n", + " x = images_uint8\n", + " if x.dtype != jnp.uint8:\n", + " raise ValueError(f\"Expected uint8 images, got {x.dtype}\")\n", + "\n", + " # Add batch if needed.\n", + " if x.ndim == 3:\n", + " x = x[None, ...] # [1,H,W,3]\n", + " if x.ndim != 4 or x.shape[-1] != 3:\n", + " raise ValueError(f\"Expected [B,H,W,3], got shape {x.shape}\")\n", + "\n", + " b, h, w, c = x.shape\n", + " # Resize to target square (simple bilinear; if you prefer center-crop+resize, do that here)\n", + " x = jimg.resize(\n", + " x, (b, image_size, image_size, c), method=\"bilinear\", antialias=True\n", + " )\n", + "\n", + " # [0,1] -> normalize\n", + " x = x.astype(jnp.float32) / 255.0\n", + " mean = jnp.asarray(mean, dtype=jnp.float32).reshape((1, 1, 1, 3))\n", + " std = jnp.asarray(std, dtype=jnp.float32).reshape((1, 1, 1, 3))\n", + " x = (x - mean) / std\n", + " return x\n", + "\n", + "\n", + "def preprocess_item(ex):\n", + " img = _pick_one_image(ex[\"image\"])\n", + " if img is None:\n", + " return {\n", + " \"pixel_values\": None,\n", + " \"question\": ex[\"question\"],\n", + " \"chosen\": ex[\"chosen\"],\n", + " \"rejected\": ex[\"rejected\"],\n", + " }\n", + " arr = np.array(img, dtype=np.uint8)[None, ...] 
# [1,H,W,3]\n", + " px = preprocess_image(jnp.asarray(arr), IMAGE_SIZE) # [1,S,S,3] float32\n", + " return {\n", + " \"pixel_values\": np.asarray(px[0]), # [S,S,3]\n", + " \"question\": ex[\"question\"],\n", + " \"chosen\": ex[\"chosen\"],\n", + " \"rejected\": ex[\"rejected\"],\n", + " }\n", + "\n", + "\n", + "PAD = gemma_tokenizer.pad_id()\n", + "EOS = gemma_tokenizer.eos_id()\n", + "\n", + "\n", + "def _left_pad_np(ids, L, pad=PAD):\n", + " ids = ids[-L:] if len(ids) > L else [pad] * (L - len(ids)) + ids\n", + " return np.asarray(ids, dtype=np.int32)\n", + "\n", + "\n", + "def _right_pad_np(ids, L, pad=PAD):\n", + " ids = ids[:L]\n", + " ids = ids + [pad] * (L - len(ids))\n", + " return np.asarray(ids, dtype=np.int32)\n", + "\n", + "\n", + "def _make_mask(ids, pad=PAD):\n", + " return (ids != pad).astype(np.int32)\n", + "\n", + "\n", + "IMG = 262144\n", + "N_IMG = 256\n", + "def pack_prompt_with_image(text_ids):\n", + " keep = MAX_PROMPT_LENGTH - N_IMG\n", + " text_ids = text_ids[-keep:] if keep > 0 else []\n", + " ids = text_ids + [IMG] * N_IMG\n", + " return _left_pad_np(ids, MAX_PROMPT_LENGTH)\n", + "\n", + "\n", + "def numpy_batches_vlm(dataset, batch_size=1, shuffle=True, seed=0, epochs=None):\n", + " rng = np.random.default_rng(seed)\n", + " epoch = 0\n", + " while True:\n", + " idx = np.arange(len(dataset))\n", + " if shuffle:\n", + " rng.shuffle(idx)\n", + "\n", + " buf = []\n", + " for i in idx:\n", + " ex = dataset[int(i)]\n", + "\n", + " # build pixel_values HERE (not in ds.with_transform)\n", + " img = _pick_one_image(ex[\"image\"])\n", + " if img is None:\n", + " continue\n", + " arr = np.array(img.convert(\"RGB\"), dtype=np.uint8) # [H,W,3]\n", + " px = preprocess_image(jnp.asarray(arr), IMAGE_SIZE) # [1,S,S,3]\n", + " px0 = np.asarray(px[0], dtype=np.float32) # [S,S,3]\n", + "\n", + " buf.append({\n", + " \"pixel_values\": px0,\n", + " \"question\": ex[\"question\"],\n", + " \"chosen\": ex[\"chosen\"],\n", + " \"rejected\": ex[\"rejected\"],\n", + " })\n", + "\n", + " if len(buf) == batch_size:\n", + " qs = [b[\"question\"] for b in buf]\n", + " chs = [b[\"chosen\"] for b in buf]\n", + " rjs = [b[\"rejected\"] for b in buf]\n", + "\n", + " \n", + " q_tok = [pack_prompt_with_image(gemma_tokenizer.encode(x)) for x in qs]\n", + " ch_tok = [gemma_tokenizer.encode(x) + [EOS] for x in chs]\n", + " rj_tok = [gemma_tokenizer.encode(x) + [EOS] for x in rjs]\n", + "\n", + " Q = np.stack(q_tok, axis=0)\n", + " CH = np.stack([_right_pad_np(ids, TOTAL_GENERATION_STEPS) for ids in ch_tok], axis=0)\n", + " RJ = np.stack([_right_pad_np(ids, TOTAL_GENERATION_STEPS) for ids in rj_tok], axis=0)\n", + "\n", + " PX = np.stack([b[\"pixel_values\"] for b in buf], axis=0).astype(np.float32)\n", + " assert PX.ndim == 4 and PX.shape[-1] == 3, f\"Bad PX shape: {PX.shape}\"\n", + "\n", + " Q_mask = np.stack([_make_mask(ids, PAD) for ids in Q], axis=0)\n", + " CH_mask = np.stack([_make_mask(ids, PAD) for ids in CH], axis=0)\n", + " RJ_mask = np.stack([_make_mask(ids, PAD) for ids in RJ], axis=0)\n", + "\n", + " yield TrainingInput(\n", + " prompt_ids=jnp.asarray(Q),\n", + " prompt_mask=jnp.asarray(Q_mask),\n", + " chosen_ids=jnp.asarray(CH),\n", + " chosen_mask=jnp.asarray(CH_mask),\n", + " rejected_ids=jnp.asarray(RJ),\n", + " rejected_mask=jnp.asarray(RJ_mask),\n", + " pixel_values=jnp.asarray(PX),\n", + " )\n", + " buf = []\n", + "\n", + " epoch += 1\n", + " if epochs is not None and epoch >= epochs:\n", + " break\n", + "\n", + "\n", + "# Smoke one batch\n", + "\n", + "b0 = 
next(numpy_batches_vlm(ds, batch_size=4))\n", + "print(\"Batch pixels:\", b0.pixel_values.shape, \"| B:\", b0.prompt_ids.shape[0])\n", + "print(\"Batch prompt_ids:\", b0.prompt_ids.shape)\n", + "print(\"Batch prompt_mask:\", b0.prompt_mask.shape)\n", + "print(\"Batch chosen_ids:\", b0.chosen_ids.shape)\n", + "print(\"Batch chosen_mask:\", b0.chosen_mask.shape)\n", + "print(\"Batch rejected_ids:\", b0.rejected_ids.shape)\n", + "print(\"Batch rejected_mask:\", b0.rejected_mask.shape)\n", + "\n", + "print(\"Dataset size:\", len(ds))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XKNb6B_IT09R" + }, + "outputs": [], + "source": [ + "INTERMEDIATE_CKPT_DIR = \"/content/intermediate_ckpt_vlm\"\n", + "os.makedirs(INTERMEDIATE_CKPT_DIR, exist_ok=True)\n", + "HIST_PATH = os.path.join(INTERMEDIATE_CKPT_DIR, \"train_history.json\")\n", + "\n", + "HISTORY = {\n", + " \"step\": [],\n", + " \"loss\": [],\n", + " \"rewards/chosen\": [],\n", + " \"rewards/rejected\": [],\n", + " \"rewards/margin\": [],\n", + " \"rewards/accuracy\": [],\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GBv6XLfLT9sG" + }, + "outputs": [], + "source": [ + "config = DPOTrainingConfig(\n", + " eval_every_n_steps=EVAL_EVERY_N_STEPS,\n", + " max_steps=MAX_STEPS,\n", + " beta=BETA,\n", + " label_smoothing=0.0,\n", + ")\n", + "optimizer = optax.adamw(learning_rate=LEARNING_RATE)\n", + "train_batches = numpy_batches_vlm(\n", + " ds, batch_size=BATCH_SIZE, shuffle=False, seed=42, epochs=EPOCHS\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7PDdv1t8Aom2" + }, + "outputs": [], + "source": [ + "with mesh:\n", + " trainer = DPOTrainer(\n", + " model=lora_gemma,\n", + " ref_model=gemma3,\n", + " optimizer=optimizer,\n", + " training_config=config,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JnfCXXR_AqQa" + }, + "outputs": [], + "source": [ + "_orig_post = getattr(trainer, \"_post_process_train_step\", None)\n", + "\n", + "\n", + "def _patched_post_process_train_step(self, aux):\n", + " if _orig_post is not None:\n", + " _orig_post(aux)\n", + "\n", + " s = int(getattr(self, \"_train_steps\", 0))\n", + "\n", + " loss_val = float(\"nan\")\n", + " bm = getattr(self, \"_buffered_train_metrics\", None)\n", + " if bm is not None and getattr(bm, \"losses\", None):\n", + " try:\n", + " loss_val = float(bm.losses[-1])\n", + " except Exception:\n", + " pass\n", + "\n", + " HISTORY[\"step\"].append(s)\n", + " HISTORY[\"loss\"].append(loss_val)\n", + " HISTORY[\"rewards/chosen\"].append(float(aux[\"rewards/chosen\"]))\n", + " HISTORY[\"rewards/rejected\"].append(float(aux[\"rewards/rejected\"]))\n", + " HISTORY[\"rewards/margin\"].append(float(aux[\"rewards/margin\"]))\n", + " HISTORY[\"rewards/accuracy\"].append(float(aux[\"rewards/accuracy\"]))\n", + "\n", + " if s % EVAL_EVERY_N_STEPS == 0:\n", + " print(\n", + " \"[metric]\"\n", + " f\" step={s} loss={loss_val:.4f} margin={float(aux['rewards/margin']):.4f}\"\n", + " )\n", + "\n", + "\n", + "trainer._post_process_train_step = types.MethodType(\n", + " _patched_post_process_train_step, trainer\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dDdXwCutA9bB" + }, + "outputs": [], + "source": [ + "with mesh:\n", + " trainer.train(train_batches)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9B4qmjLxBH2B" + }, + "outputs": 
[], + "source": [ + "# optional: persist\n", + "with open(HIST_PATH, \"w\") as f:\n", + " json.dump(HISTORY, f)\n", + "\n", + "\n", + "# --- plotting ---\n", + "def _safe_xy(hist, key):\n", + " x = np.array(hist.get(\"step\", []), dtype=float)\n", + " y = np.array(hist.get(key, []), dtype=float)\n", + " return x, y\n", + "\n", + "\n", + "def moving_average(data, window_size):\n", + " \"\"\"Calculates the moving average of a list or numpy array.\"\"\"\n", + " if len(data) < window_size:\n", + " return (\n", + " data # Return original data if window size is larger than data length\n", + " )\n", + " return np.convolve(data, np.ones(window_size) / window_size, mode=\"valid\")\n", + "\n", + "\n", + "def _plot_series(x, y, title, ylabel, window_size=5):\n", + " if len(x) == 0:\n", + " print(f\"[plot] no data for {title}\")\n", + " return\n", + " plt.figure()\n", + " # Apply moving average\n", + " y_smooth = moving_average(y, window_size)\n", + " x_smooth = x[\n", + " window_size - 1 :\n", + " ] # Adjust x to match the length of the smoothed data\n", + " plt.plot(x_smooth, y_smooth)\n", + " plt.title(title)\n", + " plt.xlabel(\"step\")\n", + " plt.ylabel(ylabel)\n", + " plt.grid(True)\n", + " plt.show()\n", + "\n", + "\n", + "x, y = _safe_xy(HISTORY, \"loss\")\n", + "_plot_series(x, y, \"Training Loss (Smoothed)\", \"loss\")\n", + "\n", + "x, y = _safe_xy(HISTORY, \"rewards/margin\")\n", + "_plot_series(x, y, \"Rewards Margin (chosen - rejected) (Smoothed)\", \"margin\")\n", + "\n", + "x, ch = _safe_xy(HISTORY, \"rewards/chosen\")\n", + "_, rj = _safe_xy(HISTORY, \"rewards/rejected\")\n", + "if len(x):\n", + " plt.figure()\n", + " window_size = 10\n", + " ch_smooth = moving_average(ch, window_size)\n", + " rj_smooth = moving_average(rj, window_size)\n", + " x_smooth = x[window_size - 1 :]\n", + " plt.plot(x_smooth, ch_smooth, label=\"chosen (Smoothed)\")\n", + " plt.plot(x_smooth, rj_smooth, label=\"rejected (Smoothed)\")\n", + " plt.title(\"Chosen vs Rejected Rewards (Smoothed)\")\n", + " plt.xlabel(\"step\")\n", + " plt.ylabel(\"reward\")\n", + " plt.legend()\n", + " plt.grid(True)\n", + " plt.show()\n", + "\n", + "x, y = _safe_xy(HISTORY, \"rewards/accuracy\")\n", + "_plot_series(x, y, \"Rewards Accuracy (Smoothed)\", \"accuracy\")" + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "gpuType": "V6E1", + "machine_shape": "hm", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tunix/generate/tokenizer_adapter.py b/tunix/generate/tokenizer_adapter.py index edb134ff3..aa0fd0281 100644 --- a/tunix/generate/tokenizer_adapter.py +++ b/tunix/generate/tokenizer_adapter.py @@ -27,6 +27,7 @@ class TokenizerType(enum.Enum): SP: str = 'sp' # sentencepiece tokenizer HF: str = 'hf' # huggingface tokenizer + HFP: str = 'hfp' # huggingface processor NONE: str = 'none' # Represents no tokenizer @@ -42,6 +43,8 @@ def __init__(self, tokenizer: Any): self._tokenizer_type = TokenizerType.SP elif self._is_hf_tokenizer(): self._tokenizer_type = TokenizerType.HF + elif self._is_hf_processor(): + self._tokenizer_type = TokenizerType.HFP elif not missing_methods: self._tokenizer_type = TokenizerType.NONE else: @@ -54,11 +57,16 @@ def __init__(self, tokenizer: Any): f'{missing_methods}.' 
) - def encode(self, text: str, **kwargs) -> list[int]: + def encode(self, text: str, **kwargs) -> list[int] | tuple[list[int], Any]: if self._tokenizer_type == TokenizerType.SP: return self._tokenizer.EncodeAsIds(text, **kwargs) elif self._tokenizer_type == TokenizerType.HF: return self._tokenizer.encode(text, **kwargs) + elif self._tokenizer_type == TokenizerType.HFP: + inputs = self._tokenizer(text=text, **kwargs) + if 'images' in kwargs: + return inputs['input_ids'], inputs['pixel_values'] + return inputs['input_ids'] else: return self._tokenizer.encode(text, **kwargs) @@ -67,6 +75,8 @@ def decode(self, ids: list[int], **kwargs) -> str: return self._tokenizer.DecodeIds(ids, **kwargs) elif self._tokenizer_type == TokenizerType.HF: return self._tokenizer.decode(ids, **kwargs) + elif self._tokenizer_type == TokenizerType.HFP: + return self._tokenizer.tokenizer.decode(ids, **kwargs) else: return self._tokenizer.decode(ids, **kwargs) @@ -75,6 +85,8 @@ def bos_id(self) -> int: return self._tokenizer.bos_id() elif self._tokenizer_type == TokenizerType.HF: return self._tokenizer.bos_token_id + elif self._tokenizer_type == TokenizerType.HFP: + return self._tokenizer.tokenizer.bos_token_id else: return self._tokenizer.bos_id() @@ -83,6 +95,8 @@ def eos_id(self) -> int: return self._tokenizer.eos_id() elif self._tokenizer_type == TokenizerType.HF: return self._tokenizer.eos_token_id + elif self._tokenizer_type == TokenizerType.HFP: + return self._tokenizer.tokenizer.eos_token_id else: return self._tokenizer.eos_id() @@ -98,6 +112,8 @@ def pad_id(self) -> int: if self._tokenizer.pad_token_id is None: self._tokenizer.pad_token = self._tokenizer.eos_token return self._tokenizer.pad_token_id + elif self._tokenizer_type == TokenizerType.HFP: + return self._tokenizer.tokenizer.pad_token_id else: return self._tokenizer.pad_id() @@ -124,12 +140,19 @@ def _is_hf_tokenizer(self) -> bool: baseclass.__module__ + '.' + baseclass.__name__ for baseclass in baseclasses ] - if ( + return ( 'transformers.tokenization_utils_base.PreTrainedTokenizerBase' in baseclass_names - ): - return True - return False + ) + + def _is_hf_processor(self) -> bool: + """Checks if the tokenizer is a huggingface Processor.""" + baseclasses = inspect.getmro(type(self._tokenizer)) + baseclass_names = [ + baseclass.__module__ + '.' 
+ baseclass.__name__ + for baseclass in baseclasses + ] + return 'transformers.processing_utils.ProcessorMixin' in baseclass_names @property def tokenizer(self) -> Any: diff --git a/tunix/generate/utils.py b/tunix/generate/utils.py index 0c8963b55..3c329c7a0 100644 --- a/tunix/generate/utils.py +++ b/tunix/generate/utils.py @@ -25,6 +25,7 @@ from flax import nnx import jax from jax import lax +import jax.image as jimg import jax.numpy as jnp import numpy as np diff --git a/tunix/models/gemma3/model.py b/tunix/models/gemma3/model.py index 7da83b02f..303c4266e 100644 --- a/tunix/models/gemma3/model.py +++ b/tunix/models/gemma3/model.py @@ -98,6 +98,7 @@ class ModelConfig: num_heads: int head_dim: int num_kv_heads: int + multimodal: bool = False sliding_window_size: int | None = None local_base_frequency: int = 10_000 global_base_frequency: int = 10_000 @@ -877,11 +878,86 @@ def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array: return normed_inputs +class MultimodalProjector(nnx.Module): + """Image soft token pooling + projection.""" + + IMAGE_SOFT_TOKEN_ID: int = 262144 + + def __init__( + self, + vision_embed_dim: int, + text_embed_dim: int, + patches_per_side: int, + output_tokens_per_side=16, + *, + rngs: nnx.Rngs, + shd_config: ShardingConfig = ShardingConfig.get_default_sharding(), + ): + self.patches_per_side = patches_per_side + self.output_tokens_per_side = output_tokens_per_side + self.output_tokens_total = output_tokens_per_side * output_tokens_per_side + self.kernel_size = patches_per_side // output_tokens_per_side + + self.mm_soft_emb_norm = RMSNorm( + vision_embed_dim, + rngs=rngs, + sharding=shd_config.rms_norm_weight, + ) + self.mm_input_projection = nnx.Linear( + in_features=vision_embed_dim, + out_features=text_embed_dim, + use_bias=False, + rngs=rngs, + kernel_init=nnx.with_partitioning( + nnx.initializers.zeros_init(), shd_config.ffw_weight_df + ), + ) + + @jax.named_scope('multimodal_projector') + def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array: + B, _, D = x.shape + x = x.reshape(B, self.patches_per_side, self.patches_per_side, D) + x = nnx.avg_pool( + x, + window_shape=(self.kernel_size, self.kernel_size), + strides=(self.kernel_size, self.kernel_size), + ) + x = x.reshape(B, self.output_tokens_total, D) + x = self.mm_soft_emb_norm(x) + x = self.mm_input_projection(x) + return x + + class Gemma3(nnx.Module): - """Gemma3 transformer.""" + """Gemma transformer.""" def __init__(self, config: ModelConfig, *, rngs: nnx.Rngs): self.config = config + + if config.multimodal: + from tunix.models.siglip.model import SigLIPConfig, SigLIPEngine # pylint: disable=g-import-not-at-top + + self.siglip = SigLIPEngine( + cfg=SigLIPConfig( + image_size=896, + patch_size=14, + embed_dim=1152, + depth=27, + num_heads=16, + mlp_hidden_dim=4304, + use_cls_token=False, + use_abs_pos_emb=True, + ), + rngs=rngs, + ) + self.projector = MultimodalProjector( + 1152, + config.embed_dim, + 64, + rngs=rngs, + shd_config=config.shd_config, + ) + self.embedder = Embedder( vocab_size=config.num_embed, embed_dim=config.embed_dim, @@ -927,6 +1003,7 @@ def __call__( positions: jaxtyping.Array, # [B, L] cache: Cache | None, # (sequence length L') attention_mask: jaxtyping.Array, # [B, L, L'] + pixel_values: jaxtyping.Array | None = None, # [B, H, W, C] output_hidden_states: bool = False, ) -> tuple[jaxtyping.Array, Cache | None]: """Transformer forward pass. @@ -934,11 +1011,18 @@ def __call__( You can run this forward pass two ways: with or without an attention kv cache. 
+ Note: for multimodal (image + text) inputs: last_tokens is expected to be + already preprocessed to contain exactly 256 (id=262144) + per tokenized input, and attention_mask is expected to already have been + adjusted for image tokens, i.e. image tokens attend to all tokens in the + (same) image bidirectionally in addition to attending to all previous tokens + Args: last_tokens: input sequence of tokens. positions: input absolute positions. cache: Attention KV cache or None. attention_mask: transformer input mask. + pixel_values: (preprocessed) images for multimodal, None for text-only. output_hidden_states: whether to output the hidden states. Returns: @@ -947,8 +1031,29 @@ def __call__( predicted_logits: output logits predicted by the model new_cache: updated cache if the input cache is not None, None elsewhere. """ + new_cache = None if cache is None else {} - x = self.embedder.encode(last_tokens) + + if self.config.multimodal: + assert pixel_values is not None + image_mask = last_tokens == 262144 # 262144: + + vision_outputs = self.siglip(pixel_values) # B, 4096, 1152 + image_features = self.projector(vision_outputs) # B, 256, embed_dim + + last_tokens = jnp.where(image_mask, 0, last_tokens) + x = self.embedder.encode(last_tokens) + image_features = image_features.astype(x.dtype) + + # Write image features to embedded input + idx = jnp.cumsum(image_mask, axis=1) - 1 + idx = jnp.where(image_mask, idx, 0) + gathered = jnp.take_along_axis(image_features, idx[..., None], axis=1) + x = jnp.where(image_mask[..., None], gathered, x) + + else: + x = self.embedder.encode(last_tokens) + for i, layer in enumerate(self.layers): layer_name = f'layer_{i}' layer_cache = cache[layer_name] if cache else None @@ -989,7 +1094,36 @@ def get_model_input(self): (dummy_batch_size, 1, dummy_seq_len), dtype=jnp.bool ), } + + @staticmethod + def make_mm_attention_mask( + input_ids: jaxtyping.Array, # [B, L] + input_mask: jaxtyping.Array, # [B, L] (1 for valid tokens, 0 for pad) + ) -> jaxtyping.Array: + """Builds Gemma3 multimodal attention mask. 
+ + - Base causal attention + - Text can attend to image tokens + - Image tokens attend bidirectionally to other image tokens + - Padding respected + """ + + # Base causal mask (already handles pad keys) + from tunix.rl import common # local import avoids circular deps + + attn = common.make_causal_attn_mask(input_mask) # [B, L, L] + + image_mask = (input_ids == Gemma3.IMAGE_SOFT_TOKEN_ID) # [B, L] + + # Allow any query to attend to image keys + attn = attn | image_mask[:, None, :] + + # Fully open image <-> image attention + attn = attn | (image_mask[:, :, None] & image_mask[:, None, :]) + + + return attn @property def embed_dim(self) -> int: return self.embedder.embed_dim diff --git a/tunix/models/gemma3/params.py b/tunix/models/gemma3/params.py index 15cbc07be..9d2c3eecc 100644 --- a/tunix/models/gemma3/params.py +++ b/tunix/models/gemma3/params.py @@ -59,7 +59,7 @@ def create_model_from_checkpoint( lambda: model_lib.Gemma3(model_config, rngs=nnx.Rngs(0)) ) params = ocp.StandardCheckpointer().restore(checkpoint_path) - params = map_from_upstream_checkpoint(params) + params = map_from_upstream_checkpoint(params, multimodal=model_config.multimodal) if mesh is not None: params = jax.tree.map( lambda x, shd: jnp.asarray(x, device=shd, dtype=dtype), @@ -88,7 +88,9 @@ def create_tokenizer( return spm_processor -def map_from_upstream_checkpoint(params, model_type: str = 'gemma3'): +def map_from_upstream_checkpoint( + params, model_type: str = 'gemma3', multimodal: bool = False +): """Map from upstream checkpoint to our implementation.""" # From: # @@ -127,13 +129,70 @@ def map_from_upstream_checkpoint(params, model_type: str = 'gemma3'): module_path, param_name = key_path module_path = module_path.split('/')[1:] # Remove the leading 'transformer' if module_path[0] == 'siglip_encoder': - continue # We don't support MM input yet. - if module_path[0] == 'embedder': - if len(module_path) > 1 and module_path[1].startswith('mm_'): - continue # We don't support MM input yet. 
+ if not multimodal: + continue + if param_name == 'pos_embedding': + new_params[('siglip', 'pos_embed')] = value + continue + elif module_path[1] == 'embedding': + new_params[('siglip', 'patch', 'proj', param_name)] = value + continue + elif module_path[2] == 'encoder_norm': + new_params[('siglip', 'norm', param_name)] = value + continue + + assert module_path[2].startswith('encoderblock_') + siglip_layer = ( + 'siglip', + 'blocks', + int(module_path[2].removeprefix('encoderblock_')), + ) + + if module_path[3] == 'LayerNorm_0': + new_params[(*siglip_layer, 'ln1', param_name)] = value + elif module_path[3] == 'LayerNorm_1': + new_params[(*siglip_layer, 'ln2', param_name)] = value + elif module_path[3] == 'MultiHeadDotProductAttention_0': + if module_path[4] == 'out': + if value.ndim == 3: + value = value.reshape(-1, value.shape[-1]) + else: + value = value.reshape(-1) + new_params[(*siglip_layer, 'attn', 'o', param_name)] = value + else: + if value.ndim == 3: + value = value.reshape(value.shape[0], -1) + else: + value = value.reshape(-1) + if module_path[4] == 'query': + new_params[(*siglip_layer, 'attn', 'q', param_name)] = value + elif module_path[4] == 'key': + new_params[(*siglip_layer, 'attn', 'k', param_name)] = value + else: + assert module_path[4] == 'value' + new_params[(*siglip_layer, 'attn', 'v', param_name)] = value + elif module_path[3:] == ['MlpBlock_0', 'Dense_0']: + new_params[(*siglip_layer, 'mlp', 'fc1', param_name)] = value + else: + assert module_path[3:] == ['MlpBlock_0', 'Dense_1'] + new_params[(*siglip_layer, 'mlp', 'fc2', param_name)] = value + continue + + if ( + module_path[0] == 'embedder' + and len(module_path) > 1 + and module_path[1].startswith('mm_') + ): + if multimodal: + if module_path[1] == 'mm_soft_embedding_norm': + new_params[('projector', 'mm_soft_emb_norm', param_name)] = value + elif module_path[1] == 'mm_input_projection': + new_params[('projector', 'mm_input_projection', 'kernel')] = value + continue if module_path[0] in ('embedder', 'final_norm'): new_params[(module_path[0], param_name)] = value continue + # module_path should now look like ('layer_0', 'attn', '_key_norm') layer_idx = ('layers', int(module_path[0].removeprefix('layer_'))) if module_path[1:] == ['mlp', 'gating_einsum']: diff --git a/tunix/models/gemma3/params_safetensors.py b/tunix/models/gemma3/params_safetensors.py index 0177869e8..3a6b92664 100644 --- a/tunix/models/gemma3/params_safetensors.py +++ b/tunix/models/gemma3/params_safetensors.py @@ -14,89 +14,164 @@ def _get_key_and_transform_mapping(cfg: model_lib.ModelConfig): """Mapping of torch_keys to (nnx_keys, (permute_rule, reshape_rule)).""" return { - r"(?:language_model\.)?model\.embed_tokens\.weight": ( - "embedder.input_embedding", None + r"(language_model\.)?model\.embed_tokens\.weight": ( + "embedder.input_embedding", + None, ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.self_attn\.q_proj\.weight": ( - r"tmp.layers.\1.attn.q", + r"(language_model\.)?model\.layers\.([0-9]+)\.self_attn\.q_proj\.weight": ( + r"tmp.layers.\2.attn.q", ((1, 0), (cfg.embed_dim, cfg.num_heads, cfg.head_dim)), ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.self_attn\.k_proj\.weight": ( - r"tmp.layers.\1.attn.k", + r"(language_model\.)?model\.layers\.([0-9]+)\.self_attn\.k_proj\.weight": ( + r"tmp.layers.\2.attn.k", ((1, 0), (cfg.embed_dim, cfg.num_kv_heads, cfg.head_dim)), ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.self_attn\.v_proj\.weight": ( - r"tmp.layers.\1.attn.v", + 
r"(language_model\.)?model\.layers\.([0-9]+)\.self_attn\.v_proj\.weight": ( + r"tmp.layers.\2.attn.v", ((1, 0), (cfg.embed_dim, cfg.num_kv_heads, cfg.head_dim)), ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.self_attn\.o_proj\.weight": ( - r"layers.\1.attn.attn_vec_einsum.w", + r"(language_model\.)?model\.layers\.([0-9]+)\.self_attn\.o_proj\.weight": ( + r"layers.\2.attn.attn_vec_einsum.w", ((1, 0), (cfg.num_heads, cfg.head_dim, cfg.embed_dim)), ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.mlp\.gate_proj\.weight": ( - r"layers.\1.mlp.gate_proj.kernel", + r"(language_model\.)?model\.layers\.([0-9]+)\.mlp\.gate_proj\.weight": ( + r"layers.\2.mlp.gate_proj.kernel", ((1, 0), None), ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.mlp\.up_proj\.weight": ( - r"layers.\1.mlp.up_proj.kernel", + r"(language_model\.)?model\.layers\.([0-9]+)\.mlp\.up_proj\.weight": ( + r"layers.\2.mlp.up_proj.kernel", ((1, 0), None), ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.mlp\.down_proj\.weight": ( - r"layers.\1.mlp.down_proj.kernel", + r"(language_model\.)?model\.layers\.([0-9]+)\.mlp\.down_proj\.weight": ( + r"layers.\2.mlp.down_proj.kernel", ((1, 0), None), ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.input_layernorm\.weight": ( - r"layers.\1.pre_attention_norm.scale", + r"(language_model\.)?model\.layers\.([0-9]+)\.input_layernorm\.weight": ( + r"layers.\2.pre_attention_norm.scale", None, ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.post_attention_layernorm\.weight": ( - r"layers.\1.post_attention_norm.scale", + r"(language_model\.)?model\.layers\.([0-9]+)\.post_attention_layernorm\.weight": ( + r"layers.\2.post_attention_norm.scale", None, ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.(post_feedforward_layernorm|post_ffn_layernorm|post_ffw_layernorm)\.weight": ( - r"layers.\1.post_ffw_norm.scale", + r"(language_model\.)?model\.layers\.([0-9]+)\.(post_feedforward_layernorm|post_ffn_layernorm|post_ffw_layernorm)\.weight": ( + r"layers.\2.post_ffw_norm.scale", None, ), - r"(?:language_model\.)?model\.norm\.weight": ("final_norm.scale", None), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.self_attn\.q_norm\.weight": ( - r"layers.\1.attn._query_norm.scale", + r"(language_model\.)?model\.norm\.weight": ("final_norm.scale", None), + r"(language_model\.)?model\.layers\.([0-9]+)\.self_attn\.q_norm\.weight": ( + r"layers.\2.attn._query_norm.scale", None, ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.self_attn\.k_norm\.weight": ( - r"layers.\1.attn._key_norm.scale", + r"(language_model\.)?model\.layers\.([0-9]+)\.self_attn\.k_norm\.weight": ( + r"layers.\2.attn._key_norm.scale", None, ), r"lm_head\.weight": ("unused.lm_head.weight", None), r"lm_head\.bias": ("unused.lm_head.bias", None), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.self_attn\.(q_proj|k_proj|v_proj|o_proj)\.bias": ( - r"unused.layers.\1.attn.\2.bias", + r"(language_model\.)?model\.layers\.([0-9]+)\.self_attn\.(q_proj|k_proj|v_proj|o_proj)\.bias": ( + r"unused.layers.\2.attn.\3.bias", + None, + ), + r"(language_model\.)?model\.layers\.([0-9]+)\.input_layernorm\.bias": ( + r"unused.layers.\2.input_layernorm.bias", + None, + ), + r"(language_model\.)?model\.layers\.([0-9]+)\.post_attention_layernorm\.bias": ( + r"unused.layers.\2.post_attention_layernorm.bias", + None, + ), + r"(language_model\.)?model\.rotary_emb\..*": ("unused.rotary_emb", None), + r"(language_model\.)?model\.layers\.([0-9]+)\.self_attn\.rotary_emb\..*": ( + r"unused.layers.\2.attn.rotary_emb", + None, + ), + 
r"(language_model\.)?model\.layers\.([0-9]+)\.self_attn\.qkv_proj\.weight": ( + r"unused.layers.\2.attn.qkv_proj.weight", + None, + ), + r"(language_model\.)?model\.layers\.([0-9]+)\.pre_feedforward_layernorm\.weight": ( + r"layers.\2.pre_ffw_norm.scale", + None, + ), + r"(language_model\.)?model\.layers\.([0-9]+)\.(pre_ffn_layernorm|pre_ffw_layernorm)\.weight": ( + r"layers.\2.pre_ffw_norm.scale", + None, + ), + r"multi_modal_projector.mm_input_projection_weight": ( + r"projector.mm_input_projection.kernel", + None, + ), + r"multi_modal_projector.mm_soft_emb_norm.weight": ( + r"projector.mm_soft_emb_norm.scale", + None, + ), + r"vision_tower\.vision_model\.embeddings\.patch_embedding\.bias": ( + r"siglip.patch.proj.bias", + None, + ), + r"vision_tower\.vision_model\.embeddings\.patch_embedding\.weight": ( + r"siglip.patch.proj.kernel", + ((2, 3, 1, 0), None), + ), + r"vision_tower\.vision_model\.embeddings\.position_embedding\.weight": ( + r"siglip.pos_embed", + (None, (1, -1, 1152)), + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.layer_norm1\.bias": ( + r"siglip.blocks.\1.ln1.bias", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.layer_norm1\.weight": ( + r"siglip.blocks.\1.ln1.scale", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.layer_norm2\.bias": ( + r"siglip.blocks.\1.ln2.bias", None, ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.input_layernorm\.bias": ( - r"unused.layers.\1.input_layernorm.bias", + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.layer_norm2\.weight": ( + r"siglip.blocks.\1.ln2.scale", None, ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.post_attention_layernorm\.bias": ( - r"unused.layers.\1.post_attention_layernorm.bias", + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.mlp\.fc1\.bias": ( + r"siglip.blocks.\1.mlp.fc1.bias", None, ), - r"(?:language_model\.)?model\.rotary_emb\..*": ( - "unused.rotary_emb", None + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.mlp\.fc1\.weight": ( + r"siglip.blocks.\1.mlp.fc1.kernel", + ((1, 0), None), ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.self_attn\.rotary_emb\..*": ( - r"unused.layers.\1.attn.rotary_emb", + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.mlp\.fc2\.bias": ( + r"siglip.blocks.\1.mlp.fc2.bias", None, ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.self_attn\.qkv_proj\.weight": ( - r"unused.layers.\1.attn.qkv_proj.weight", + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.mlp\.fc2\.weight": ( + r"siglip.blocks.\1.mlp.fc2.kernel", + ((1, 0), None), + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.([qkv])_proj\.bias": ( + r"siglip.blocks.\1.attn.\2.bias", None, ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.pre_feedforward_layernorm\.weight": ( - r"layers.\1.pre_ffw_norm.scale", + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.([qkv])_proj\.weight": ( + r"siglip.blocks.\1.attn.\2.kernel", + ((1, 0), None), + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.bias": ( + r"siglip.blocks.\1.attn.o.bias", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.weight": ( + r"siglip.blocks.\1.attn.o.kernel", + ((1, 0), None), + ), + r"vision_tower\.vision_model\.post_layernorm\.bias": ( + r"siglip.norm.bias", None, ), - r"(?:language_model\.)?model\.layers\.([0-9]+)\.(pre_ffn_layernorm|pre_ffw_layernorm)\.weight": ( - r"layers.\1.pre_ffw_norm.scale", + 
r"vision_tower\.vision_model\.post_layernorm\.weight": ( + r"siglip.norm.scale", None, ), } diff --git a/tunix/models/siglip/model.py b/tunix/models/siglip/model.py new file mode 100644 index 000000000..b97f3bb70 --- /dev/null +++ b/tunix/models/siglip/model.py @@ -0,0 +1,306 @@ +"""SigLIP vision encoder (ViT-style) implemented with Flax NNX.""" + +from __future__ import annotations + +import dataclasses +from typing import Tuple + +from flax import nnx +import jax +from jax.interpreters import pxla +import jax.numpy as jnp +import jax.sharding as shd +import jaxtyping + + +def shard(x: jnp.ndarray, s: Tuple[str | None, ...]): + """Apply named sharding if a mesh is present; no-op on CPU.""" + mesh = pxla.thread_resources.env.physical_mesh + if mesh.empty or jax.devices()[0].platform == "cpu": + return x + return jax.lax.with_sharding_constraint( + x, shd.NamedSharding(mesh, shd.PartitionSpec(*s)) + ) + + +@dataclasses.dataclass(slots=True, frozen=True) +class ShardingConfig: + """Sharding configuration for SigLIP encoder.""" + + # weight shardings + patch_kernel_hwci: Tuple[str | None, ...] # Conv: [H, W, C, D] + attn_qkvo_dd: Tuple[str | None, ...] # Linear: [D, D] + mlp_df: Tuple[str | None, ...] # Linear: [D, F] + mlp_fd: Tuple[str | None, ...] # Linear: [F, D] + ln_weight: Tuple[str | None, ...] # LayerNorm scale/bias + + # activations + act_bnd: Tuple[str | None, ...] # [B, N, D] + act_bnf: Tuple[str | None, ...] # [B, N, F] + act_bnhd: Tuple[str | None, ...] # [B, N, H, Dh] + + @staticmethod + def get_default_sharding(is_sampling: bool = False): + fsdp = "fsdp" if not is_sampling else None + return ShardingConfig( + patch_kernel_hwci=(None, None, None, "tp"), + attn_qkvo_dd=("tp", fsdp), + mlp_df=("tp", fsdp), + mlp_fd=("tp", fsdp), + ln_weight=("tp",), + act_bnd=("fsdp", None, None if is_sampling else "tp"), + act_bnf=("fsdp", None, "tp"), + act_bnhd=("fsdp", None, "tp", None), + ) + + +@dataclasses.dataclass(frozen=True) +class SigLIPConfig: + image_size: int = 224 + patch_size: int = 16 + embed_dim: int = 768 + depth: int = 12 + num_heads: int = 12 + mlp_ratio: float = 4.0 + # NEW: explicit hidden size if provided + mlp_hidden_dim: int | None = None + drop_rate: float = 0.0 + attn_drop_rate: float = 0.0 + use_cls_token: bool = False + use_abs_pos_emb: bool = True + shd_config: ShardingConfig = ShardingConfig.get_default_sharding() + + @property + def head_dim(self) -> int: + if self.embed_dim % self.num_heads != 0: + raise ValueError("embed_dim must be divisible by num_heads") + return self.embed_dim // self.num_heads + + @property + def num_patches(self) -> int: + g = self.image_size // self.patch_size + return g * g + + @classmethod + def so400m_patch14_384(cls): + return cls( + image_size=384, + patch_size=14, + embed_dim=1152, + depth=27, + num_heads=16, + mlp_ratio=3.0, # keep whatever; it’ll be ignored + mlp_hidden_dim=4304, # THIS drives the shapes + use_cls_token=False, + use_abs_pos_emb=True, + shd_config=ShardingConfig.get_default_sharding(), + ) + + +class PatchEmbed(nnx.Module): + """Patchify with a Conv2D (stride=patch_size), then flatten to tokens.""" + + def __init__(self, cfg: SigLIPConfig, *, rngs: nnx.Rngs): + self.cfg = cfg + # NNX Conv uses [H,W,C_in,C_out] kernel layout by default. 
+ self.proj = nnx.Conv( + in_features=3, + out_features=cfg.embed_dim, + kernel_size=(cfg.patch_size, cfg.patch_size), + strides=(cfg.patch_size, cfg.patch_size), + padding="VALID", + kernel_init=nnx.with_partitioning( + nnx.initializers.lecun_normal(), cfg.shd_config.patch_kernel_hwci + ), + bias_init=nnx.initializers.zeros_init(), + rngs=rngs, + ) + + @jax.named_scope("patch_embed") + def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array: + # x: [B,H,W,3] -> conv -> [B,H/P,W/P,D] -> [B,N,D] + x = self.proj(x) + b, h, w, d = x.shape + x = x.reshape(b, h * w, d) + x = shard(x, self.cfg.shd_config.act_bnd) + return x + + +class MLP(nnx.Module): + """Standard ViT MLP block with GELU.""" + + def __init__(self, cfg: SigLIPConfig, *, rngs: nnx.Rngs): + self.cfg = cfg + hidden = cfg.mlp_hidden_dim or int(cfg.embed_dim * cfg.mlp_ratio) + self.fc1 = nnx.Linear( + in_features=cfg.embed_dim, + out_features=hidden, + use_bias=True, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), cfg.shd_config.mlp_df + ), + bias_init=nnx.initializers.zeros_init(), + rngs=rngs, + ) + self.fc2 = nnx.Linear( + in_features=hidden, + out_features=cfg.embed_dim, + use_bias=True, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), cfg.shd_config.mlp_fd + ), + bias_init=nnx.initializers.zeros_init(), + rngs=rngs, + ) + + @jax.named_scope("mlp") + def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array: + h = jax.nn.gelu(self.fc1(x)) + h = shard(h, self.cfg.shd_config.act_bnf) + return self.fc2(h) + + +class MultiHeadSelfAttention(nnx.Module): + """MHA with separate Q/K/V projections and output projection.""" + + def __init__(self, cfg: SigLIPConfig, *, rngs: nnx.Rngs): + self.cfg = cfg + d = cfg.embed_dim + self.q = nnx.Linear( + in_features=d, + out_features=d, + use_bias=True, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), cfg.shd_config.attn_qkvo_dd + ), + rngs=rngs, + ) + self.k = nnx.Linear( + in_features=d, + out_features=d, + use_bias=True, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), cfg.shd_config.attn_qkvo_dd + ), + rngs=rngs, + ) + self.v = nnx.Linear( + in_features=d, + out_features=d, + use_bias=True, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), cfg.shd_config.attn_qkvo_dd + ), + rngs=rngs, + ) + self.o = nnx.Linear( + in_features=d, + out_features=d, + use_bias=True, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), cfg.shd_config.attn_qkvo_dd + ), + rngs=rngs, + ) + self.scale = (cfg.head_dim) ** -0.5 + + @jax.named_scope("mhsa") + def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array: + b, n, d = x.shape + h = self.cfg.num_heads + dh = self.cfg.head_dim + + q = self.q(x).reshape(b, n, h, dh) + k = self.k(x).reshape(b, n, h, dh) + v = self.v(x).reshape(b, n, h, dh) + + q = shard(q, self.cfg.shd_config.act_bnhd) + k = shard(k, self.cfg.shd_config.act_bnhd) + v = shard(v, self.cfg.shd_config.act_bnhd) + + attn = jnp.einsum("bnhd,bmhd->bhnm", q * self.scale, k) # [B,H,N,N] + attn = jax.nn.softmax(attn, axis=-1) + out = jnp.einsum("bhnm,bmhd->bnhd", attn, v).reshape(b, n, d) + out = self.o(out) + out = shard(out, self.cfg.shd_config.act_bnd) + return out + + +class EncoderBlock(nnx.Module): + """(LN -> MHA -> residual) + (LN -> MLP -> residual).""" + + def __init__(self, cfg: SigLIPConfig, *, rngs: nnx.Rngs): + self.cfg = cfg + self.ln1 = nnx.LayerNorm( + cfg.embed_dim, use_bias=True, param_dtype=jnp.float32, rngs=rngs + ) + self.attn = 
MultiHeadSelfAttention(cfg, rngs=rngs) + self.ln2 = nnx.LayerNorm( + cfg.embed_dim, use_bias=True, param_dtype=jnp.float32, rngs=rngs + ) + self.mlp = MLP(cfg, rngs=rngs) + + @jax.named_scope("encoder_block") + def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array: + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + + +class SigLIPEngine(nnx.Module): + + def __init__(self, cfg: SigLIPConfig, *, rngs: nnx.Rngs): + self.cfg = cfg + self.patch = PatchEmbed(cfg, rngs=rngs) + self.blocks = nnx.List( + [EncoderBlock(cfg, rngs=rngs) for _ in range(cfg.depth)] + ) + self.norm = nnx.LayerNorm( + cfg.embed_dim, use_bias=True, param_dtype=jnp.float32, rngs=rngs + ) + + # Create params only if enabled; do NOT pre-assign None. + if cfg.use_abs_pos_emb: + pe_shape = ( + 1, + cfg.num_patches + (1 if cfg.use_cls_token else 0), + cfg.embed_dim, + ) + self.pos_embed = nnx.Param( + jax.random.normal(rngs.params(), pe_shape) * 0.02 + ) + + if cfg.use_cls_token: + self.cls_token = nnx.Param( + jax.random.normal(rngs.params(), (1, 1, cfg.embed_dim)) * 0.02 + ) + + def get_model_input(self): + """Dummy input (compatible with sharding) — used by Qwix/rollout.""" + b = 2 + return { + "images": jnp.ones( + (b, self.cfg.image_size, self.cfg.image_size, 3), jnp.float32 + ) + } + + @jax.named_scope("siglip_encoder") + def __call__(self, images): + x = self.patch(images) # [B, N, D] + b, n, d = x.shape + + if hasattr(self, "cls_token"): + cls = jnp.tile(self.cls_token.value, (b, 1, 1)) + x = jnp.concatenate([cls, x], axis=1) + + if hasattr(self, "pos_embed"): + x = x + self.pos_embed.value[:, : x.shape[1], :] + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + if hasattr(self, "cls_token"): + return x[:, 1:, :] + return x diff --git a/tunix/models/siglip/params.py b/tunix/models/siglip/params.py new file mode 100644 index 000000000..65fd96e95 --- /dev/null +++ b/tunix/models/siglip/params.py @@ -0,0 +1,267 @@ +# Copyright ... 
+# Licensed under the Apache License, Version 2.0 + +"""Checkpoint loader for SigLIP encoder.""" + +from __future__ import annotations + +import json +import re +from typing import Any, Dict, Tuple + +from etils import epath +from flax import nnx +import jax +import jax.numpy as jnp +import safetensors.numpy as stnp # or safetensors.flax if you prefer +from tunix.models.siglip import model as model_lib + + +def _get_key_and_transform_mapping(cfg: model_lib.SigLIPConfig): + D = cfg.embed_dim + F = int(cfg.embed_dim * cfg.mlp_ratio) + + return { + # Patch projection (Conv2D as linear projection for patches) + r"vision_model\.embeddings\.patch_embedding\.projection\.weight": ( + "patch.proj.kernel", + ((2, 3, 1, 0), None), # [Co,Ci,Kh,Kw] -> [Kh,Kw,Ci,Co] + ), + r"vision_model\.embeddings\.patch_embedding\.projection\.bias": ( + "patch.proj.bias", + (None, None), + ), + # Encoder layer norms + r"vision_model\.encoder\.layers\.([0-9]+)\.layernorm_before\.weight": ( + r"blocks.\1.ln1.scale", + (None, None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.layernorm_before\.bias": ( + r"blocks.\1.ln1.bias", + (None, None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.layernorm_after\.weight": ( + r"blocks.\1.ln2.scale", + (None, None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.layernorm_after\.bias": ( + r"blocks.\1.ln2.bias", + (None, None), + ), + # Attention proj (separate q/k/v/out) + r"vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.weight": ( + r"blocks.\1.attn.q.kernel", + ((1, 0), None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.bias": ( + r"blocks.\1.attn.q.bias", + (None, None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.weight": ( + r"blocks.\1.attn.k.kernel", + ((1, 0), None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.bias": ( + r"blocks.\1.attn.k.bias", + (None, None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.weight": ( + r"blocks.\1.attn.v.kernel", + ((1, 0), None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.bias": ( + r"blocks.\1.attn.v.bias", + (None, None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.weight": ( + r"blocks.\1.attn.o.kernel", + ((1, 0), None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.bias": ( + r"blocks.\1.attn.o.bias", + (None, None), + ), + # MLP (GELU) + r"vision_model\.encoder\.layers\.([0-9]+)\.mlp\.fc1\.weight": ( + r"blocks.\1.mlp.fc1.kernel", + ((1, 0), None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.mlp\.fc1\.bias": ( + r"blocks.\1.mlp.fc1.bias", + (None, None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.mlp\.fc2\.weight": ( + r"blocks.\1.mlp.fc2.kernel", + ((1, 0), None), + ), + r"vision_model\.encoder\.layers\.([0-9]+)\.mlp\.fc2\.bias": ( + r"blocks.\1.mlp.fc2.bias", + (None, None), + ), + # Final norm (Transformers often: `vision_model.post_layernorm.*` or `vision_model.layernorm.*`) + r"vision_model\.(post_layernorm|layernorm)\.weight": ( + "norm.scale", + (None, None), + ), + r"vision_model\.(post_layernorm|layernorm)\.bias": ( + "norm.bias", + (None, None), + ), + # (Some SigLIP variants may include absolute pos embed; if present and shapes match, map here) + r"vision_model\.embeddings\.position_embedding(?:\.weight)?": ( + "pos_embed", + (None, None), + ), + # If CLS token exists in the checkpoint (many SigLIP variants don’t use it): + r"vision_model\.embeddings\.cls_token": ("cls_token", (None, None)), + } + + +def 
_torch_key_to_jax_key(mapping, source_key): + subs = [ + (re.sub(pat, repl, source_key), transform) + for pat, (repl, transform) in mapping.items() + if re.match(pat, source_key) + ] + if len(subs) != 1: + raise KeyError(f"Ambiguous or missing mapping for: {source_key} -> {subs}") + return subs[0] + + +def _transpose_and_reshape(x, transform): + if transform is None: + return x + permute, reshape = transform + if permute: + x = x.transpose(permute) + if reshape: + x = x.reshape(reshape) + return x + + +def _siglip_cfg_from_hf(dir_path: str) -> model_lib.SigLIPConfig: + """Read HF config.json (vision_config) and build a SigLIPConfig.""" + cfg_path = epath.Path(dir_path).expanduser() / "config.json" + data = json.loads(cfg_path.read_text()) + vc = data.get("vision_config", {}) # transformers puts vision params here + + image_size = int(vc.get("image_size", 384)) + patch_size = int(vc.get("patch_size", 14)) + hidden_size = int(vc.get("hidden_size", 768)) + intermediate_size = int(vc.get("intermediate_size", 3072)) + num_layers = int(vc.get("num_hidden_layers", 12)) + num_heads = int(vc.get("num_attention_heads", 12)) + mlp_ratio = float(intermediate_size) / float(hidden_size) + + return model_lib.SigLIPConfig( + image_size=image_size, + patch_size=patch_size, + embed_dim=hidden_size, + depth=num_layers, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + # keep defaults: abs pos = True, no cls token, default sharding + ) + + +def create_model_from_safe_tensors( + file_dir: str, + config: model_lib.SigLIPConfig | None = None, + mesh: jax.sharding.Mesh | None = None, +) -> model_lib.SigLIPEngine: + """Load SigLIP encoder from a folder of safetensors (HF-style).""" + + files = list(epath.Path(file_dir).expanduser().glob("*.safetensors")) + if not files: + raise ValueError(f"No safetensors in {file_dir}") + + # 1) Infer config if not provided + if config is None: + config = _siglip_cfg_from_hf(file_dir) + + # 2) Build a CONCRETE model/state (not eval_shape) so params are real arrays + enc_concrete = model_lib.SigLIPEngine(config, rngs=nnx.Rngs(params=0)) + graph_def, state = nnx.split(enc_concrete) + state_dict = state.to_pure_dict() + + # Optional: get sharding PartitionSpecs for later + pspecs = nnx.get_partition_spec(state) if mesh is not None else None + + key_map = _get_key_and_transform_mapping(config) + + def path_to_key(path): + parts = [] + for p in path: + parts.append(str(p.key if hasattr(p, "key") else p)) + return ".".join(parts) + + # Helpful diagnostics + loaded_keys = set() + missing_param_keys = set() + + for f in files: + current_file_tensors: Dict[str, jnp.ndarray] = {} + with stnp.safe_open(f, framework="numpy") as sf: + for torch_key in sf.keys(): + arr = sf.get_tensor(torch_key) + try: + jax_key_mapped, transform = _torch_key_to_jax_key(key_map, torch_key) + except KeyError: + # Skip unknown keys (e.g., text tower, optimizer states) + continue + arr = _transpose_and_reshape(arr, transform) + current_file_tensors[jax_key_mapped] = jax.numpy.array(arr) + + def update_tensor(path, param): + k = path_to_key(path) + if k in current_file_tensors: + v = current_file_tensors[k] + + # ---- shape fixups (HF -> NNX) ---- + # pos_embed: [N,D] -> [1,N,D] + if ( + v.ndim + 1 == param.shape.__len__() + and getattr(param, "shape", None) is not None + and param.shape[0] == 1 + and tuple(v.shape) == tuple(param.shape[1:]) + ): + v = v[None, ...] 
+ # cls_token: [1,D] -> [1,1,D] + if k.endswith("cls_token") and v.ndim == 2 and len(param.shape) == 3: + if ( + param.shape[0] == 1 + and param.shape[1] == 1 + and v.shape[-1] == param.shape[-1] + ): + v = v[:, None, :] + + if v.shape != param.shape: + raise ValueError( + f"Shape mismatch for {k}: got {v.shape}, expected {param.shape}" + ) + + loaded_keys.add(k) + return v + # Not found in safetensors — keep initialized param + missing_param_keys.add(k) + return param + + state_dict = jax.tree.map_with_path(update_tensor, state_dict) + + # Re-merge concrete state + enc = nnx.merge(graph_def, state_dict) + + # If you want sharding, apply after merge + if mesh is not None: + with mesh: + st = nnx.state(enc) + st = jax.lax.with_sharding_constraint(st, pspecs) + nnx.update(enc, st) + + # (Optional) print a brief summary of missing/loaded + print( + f"Loaded {len(loaded_keys)} params; left {len(missing_param_keys)} as" + " initialized." + ) + + return enc diff --git a/tunix/rl/common.py b/tunix/rl/common.py index 563da1d7f..9c5e00941 100644 --- a/tunix/rl/common.py +++ b/tunix/rl/common.py @@ -164,10 +164,18 @@ def get_per_token_logps( positions: jax.Array, attn_mask: jax.Array, logits_to_keep: int, + pixel_values=None, ) -> jax.Array | tuple[jax.Array, jax.Array]: """Computes the per-token log probabilities.""" + # logits, _ = model( + # input_tokens, positions=positions, attention_mask=attn_mask, cache=None + # ) logits, _ = model( - input_tokens, positions=positions, attention_mask=attn_mask, cache=None + input_tokens, + positions=positions, + attention_mask=attn_mask, + cache=None, + pixel_values=pixel_values, # ✅ add ) logits = logits[:, -logits_to_keep - 1 : -1, :] input_tokens = input_tokens[:, -logits_to_keep:] @@ -209,6 +217,7 @@ def compute_per_token_logps( completion_tokens: jax.Array, pad_id: int, eos_id: int, + pixel_values= None, completion_mask: jax.Array | None = None, stop_gradient: bool = True, return_logits: bool = False, @@ -217,8 +226,15 @@ def compute_per_token_logps( input_tokens, positions, attn_mask = process_ids( prompt_tokens, completion_tokens, pad_id, eos_id, completion_mask ) + # logits, _ = model( + # input_tokens, positions=positions, attention_mask=attn_mask, cache=None + # ) logits, _ = model( - input_tokens, positions=positions, attention_mask=attn_mask, cache=None + input_tokens, + positions=positions, + attention_mask=attn_mask, + cache=None, + pixel_values=pixel_values, ) logits_to_keep = completion_tokens.shape[1] logits = logits[:, -logits_to_keep - 1 : -1, :] diff --git a/tunix/sft/dpo/dpo_trainer.py b/tunix/sft/dpo/dpo_trainer.py index ba903da48..886a0c8dd 100644 --- a/tunix/sft/dpo/dpo_trainer.py +++ b/tunix/sft/dpo/dpo_trainer.py @@ -25,12 +25,15 @@ import jax.numpy as jnp import numpy as np import optax +from PIL import Image # TODO(abheesht): We should move TokenizerAdapter outside `generate`. from tunix.generate import tokenizer_adapter from tunix.rl import common from tunix.sft import peft_trainer from typing_extensions import override +ImageType = np.ndarray | jax.Array | Image.Image + @flax.struct.dataclass(frozen=True) class DataInput: @@ -40,12 +43,12 @@ class DataInput: preprocessing is taken care of by `DPOTrainer`. Attributes: - prompts: A list of prompts. + prompts: A list of either strings, or dicts with "text" and "image" keys. chosen_responses: A list of chosen responses. rejected_responses: A list of rejected responses. 
""" - prompts: list[str] + prompts: list[str | dict[str, str | ImageType]] chosen_responses: list[str] rejected_responses: list[str] @@ -59,6 +62,8 @@ class TrainingInput: Attributes: prompt_ids: Prompt IDs. Should be left-padded. prompt_mask: Prompt mask. Should be left-padded. + pixel_values: Optional pixels for multimodal inputs. Assumed same size + across batch if provided. chosen_ids: Chosen response IDs. Should be right-padded. chosen_mask: Chosen response mask. Should be right-padded. rejected_ids: Rejected response IDs. Should be right-padded. @@ -75,6 +80,8 @@ class TrainingInput: rejected_ids: jax.Array | np.ndarray rejected_mask: jax.Array | np.ndarray + pixel_values: jax.Array | np.ndarray | None = None + @flax.struct.dataclass(frozen=True) class TrainExample: @@ -85,6 +92,7 @@ class TrainExample: ref_rejected_logps: jax.Array | None completion_mask: jax.Array logits_to_keep: int = flax.struct.field(pytree_node=False) + pixel_values: jax.Array | None = None @dataclasses.dataclass(slots=True, kw_only=True) @@ -111,6 +119,7 @@ def compute_logps( attention_mask, logits_to_keep, completion_mask, + pixel_values=None, ): """Computes the log probabilities for chosen and rejected tokens.""" token_logps = common.get_per_token_logps( @@ -119,6 +128,7 @@ def compute_logps( positions=positions, attn_mask=attention_mask, logits_to_keep=logits_to_keep, + pixel_values=pixel_values, ) token_logps = (token_logps * completion_mask).sum(axis=-1) @@ -235,6 +245,14 @@ def __init__( if self.algorithm == "orpo": self._aux_metrics_to_log["odds_ratio"] = np.mean + + + _IMAGE_SOFT_TOKEN_ID = 262144 + _NUM_IMAGE_TOKENS = 256 + + + + @override def _prepare_inputs( self, @@ -291,7 +309,19 @@ def _prepare_inputs( # Compute positions, attention mask, etc., to be fed to the model. mask = jnp.concat([prompt_mask, completion_mask], axis=1) + + + # Pixel values (for multimodal): duplicate for chosen+rejected forward. + pixel_values = training_input.pixel_values + if pixel_values is not None: + pixel_values = jnp.concatenate([pixel_values, pixel_values], axis=0) + attention_mask = common.make_causal_attn_mask(mask) + # If we have pixel_values, assume multimodal and enable Gemma3 image-token rules. 
+ if pixel_values is not None: + attention_mask = self.model.make_mm_attention_mask(input_ids, mask) + + logits_to_keep = completion_ids.shape[1] positions = common.build_positions_from_mask(mask) @@ -306,6 +336,7 @@ def _prepare_inputs( attention_mask, logits_to_keep, completion_mask, + pixel_values=pixel_values, ) return TrainExample( input_ids=input_ids, @@ -315,6 +346,7 @@ def _prepare_inputs( ref_rejected_logps=ref_rejected_logps, completion_mask=completion_mask, logits_to_keep=logits_to_keep, + pixel_values=pixel_values, ) @override @@ -374,6 +406,7 @@ def dpo_loss_fn( train_example.attention_mask, train_example.logits_to_keep, train_example.completion_mask, + pixel_values=train_example.pixel_values, ) if algorithm == "orpo": @@ -459,36 +492,51 @@ def dpo_loss_fn( def _generate_ids_and_masks( - input_strings: list[str], + inputs: list[str | dict[str, str | ImageType]], tokenizer: Any, max_length: int, left_pad: bool = True, ) -> tuple[jax.Array, jax.Array]: """Generates ids and masks for a list of strings.""" - tokens = [_tokenize(x, tokenizer) for x in input_strings] + tokens, pixel_values = zip(*[_tokenize(x, tokenizer) for x in inputs]) all_input_ids = jnp.array([ common.pad_to_length( - x[:max_length], + input_ids[:max_length], target_length=max_length, pad_value=tokenizer.pad_id(), left=left_pad, axis=-1, ) - for x in tokens + for input_ids in tokens ]) + if pixel_values[0] is not None: + assert all(pv.shape == pixel_values[0].shape for pv in pixel_values) + all_pixel_values = jnp.concat(pixel_values) + else: + all_pixel_values = None # generate masks all_input_mask = (all_input_ids != tokenizer.pad_id()).astype("int32") - return all_input_ids, all_input_mask + return all_input_ids, all_input_mask, all_pixel_values -def _tokenize(input_string: str, tokenizer: Any) -> jax.Array: +def _tokenize( + inp: str | dict[str, str | ImageType], tokenizer: Any +) -> tuple[jax.Array, jax.Array | None]: """Tokenizes the input string.""" - input_ids = tokenizer.encode(input_string) + if isinstance(inp, str): + input_ids = tokenizer.encode(inp) + pixel_values = None + elif "text" in inp.keys() and "image" in inp.keys(): + input_ids, pixel_values = tokenizer.encode(inp["text"], images=inp["image"]) + else: + raise ValueError( + "expected either str input or dict with 'text' and 'image' keys." + ) bos_tok = [tokenizer.bos_id()] if tokenizer.bos_id() else [] input_ids = jnp.array( tokenizer.dedup_bos_ids(bos_tok + input_ids), dtype=jnp.int32 ) - return input_ids + return input_ids, pixel_values def _preprocess_dict( @@ -496,33 +544,42 @@ def _preprocess_dict( ) -> DataInput | TrainingInput: """Wraps input dict with either DataInput or TrainingInput.""" - training_input_fields = [ - field.name for field in dataclasses.fields(DataInput) - ] + data_input_fields = [field.name for field in dataclasses.fields(DataInput)] tokenized_input_fields = [ field.name for field in dataclasses.fields(TrainingInput) ] # If the dict contains tokenized fields, we should wrap it with # TrainingInput. 
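+  # `pixel_values` is optional, so it is deliberately excluded from the check
+  # below: text-only tokenized inputs arrive without it and should still be
+  # wrapped as a TrainingInput.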
- if all(field in training_input for field in tokenized_input_fields): - return TrainingInput( - **{field: training_input[field] for field in tokenized_input_fields} - ) - elif all(field in training_input for field in training_input_fields): + if all( + field in training_input + for field in tokenized_input_fields + if field != "pixel_values" + ): + return TrainingInput(**{ + field: training_input.get(field, None) + for field in tokenized_input_fields + }) + elif all(field in training_input for field in data_input_fields): return DataInput( - **{field: training_input[field] for field in training_input_fields} + **{field: training_input[field] for field in data_input_fields} ) else: raise ValueError( "Training input must contain either tokenized fields " - f"({training_input_fields}) or raw string fields " - f"({training_input_fields}). Received: {training_input.keys()}." + f"({tokenized_input_fields}) or raw string fields " + f"({data_input_fields}). Received: {training_input.keys()}." ) def process_dpo_record( - record: dict[str, str | list[str]], + record: dict[ + str, + str + | list[str] + | dict[str, str | ImageType] + | list[dict[str, str | ImageType]], + ], tokenizer: Any, max_prompt_length: int, max_response_length: int, @@ -538,9 +595,11 @@ def process_dpo_record( Args: record: A dictionary, containing "prompts", "chosen_responses", and - "rejected_responses" as keys. The values can be a single string or a - list of strings. - tokenizer: The tokenizer to use for converting text into token IDs. + "rejected_responses". Each field can be a single string or a list of + strings, and prompts can additionally be a single dict or list of dicts + with "text" and "image" keys for multimodal inputs. + tokenizer: The tokenizer or processor to use for converting text into + token IDs. max_prompt_length: The maximum length for the tokenized prompts. Any sequence longer than this will be truncated. max_response_length: The maximum length for the tokenized responses. Any @@ -554,7 +613,7 @@ def process_dpo_record( chosen_responses = record["chosen_responses"] rejected_responses = record["rejected_responses"] - unbatched = isinstance(prompts, str) + unbatched = isinstance(prompts, (str, dict)) if unbatched: prompts = [prompts] @@ -564,16 +623,16 @@ def process_dpo_record( rejected_responses = [rejected_responses] # Only prompt is left padded, others are right padded. - prompt_ids, prompt_mask = _generate_ids_and_masks( + prompt_ids, prompt_mask, pixel_values = _generate_ids_and_masks( prompts, tokenizer, max_prompt_length, left_pad=True, ) - chosen_ids, chosen_mask = _generate_ids_and_masks( + chosen_ids, chosen_mask, _ = _generate_ids_and_masks( chosen_responses, tokenizer, max_response_length, left_pad=False ) - rejected_ids, rejected_mask = _generate_ids_and_masks( + rejected_ids, rejected_mask, _ = _generate_ids_and_masks( rejected_responses, tokenizer, max_response_length, left_pad=False ) @@ -584,6 +643,8 @@ def process_dpo_record( prompt_mask = jnp.squeeze(prompt_mask, axis=0) chosen_mask = jnp.squeeze(chosen_mask, axis=0) rejected_mask = jnp.squeeze(rejected_mask, axis=0) + if pixel_values is not None: + pixel_values = jnp.squeeze(pixel_values, axis=0) return TrainingInput( prompt_ids=prompt_ids, @@ -592,6 +653,7 @@ def process_dpo_record( chosen_mask=chosen_mask, rejected_ids=rejected_ids, rejected_mask=rejected_mask, + pixel_values=pixel_values, )
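
For reference, a minimal sketch of how the SigLIP loader above could be exercised end to end. The Hugging Face repo id, the `allow_patterns` filter, and the `tunix.models.siglip.params_safetensors` module path (assumed to mirror the gemma3 layout) are illustrative assumptions, not something this change fixes:

# Sketch only: load a SigLIP vision tower from an HF-style safetensors folder.
from huggingface_hub import snapshot_download

# Assumed module path for the loader added above.
from tunix.models.siglip import params_safetensors as siglip_params_lib

siglip_dir = snapshot_download(
    repo_id="google/siglip-so400m-patch14-384",  # any HF SigLIP vision checkpoint
    allow_patterns=["*.safetensors", "config.json"],
)

# With config=None the loader infers SigLIPConfig from config.json's
# `vision_config`; pass `mesh=` to apply the model's partition specs after
# the state is merged back into the module.
siglip_encoder = siglip_params_lib.create_model_from_safe_tensors(siglip_dir)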
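
And a sketch of the new multimodal data path in the DPO trainer: a prompt may now be a dict with "text" and "image" keys, in which case the tokenizer is expected to behave like a processor whose `encode(text, images=...)` returns both token ids and pixel values (as assumed by `_tokenize` above). The processor object, image path, and prompt/response strings below are placeholders for illustration:

import numpy as np
from PIL import Image

from tunix.sft.dpo import dpo_trainer as dpo_lib

# Placeholder: any tokenizer/processor implementing pad_id/bos_id/dedup_bos_ids
# and encode(text, images=...) -> (token_ids, pixel_values).
vlm_processor = ...

image = np.asarray(Image.open("example.jpg").convert("RGB"))  # illustrative path

record = {
    "prompts": {"text": "Describe the image.", "image": image},
    "chosen_responses": "A dog sleeping on a red couch.",
    "rejected_responses": "A cat sitting on a table.",
}

training_input = dpo_lib.process_dpo_record(
    record,
    tokenizer=vlm_processor,
    max_prompt_length=512,
    max_response_length=256,
)
# For this unbatched record, ids and masks come back squeezed to 1-D and the
# leading batch axis of pixel_values is squeezed away as well; DPOTrainer
# later re-duplicates pixel_values for the chosen/rejected forward pass.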