
Uiuc vlm pr compressed fixed #511

Open

immuntasir wants to merge 35 commits into google:main from PLAN-Lab:uiuc-vlm-pr-compressed-fixed

Conversation

@immuntasir
Contributor

Resolves #510

This PR introduces multimodal support to Tunix’s Gemma3 model and adds a new vision-language DPO demonstration notebook (vl_dpo_demo_gemma3.ipynb), extending the framework to handle image-text reasoning and multimodal alignment. Key changes include:

  • Multimodal Gemma3 Model:
    1. Added SigLIP vision encoder and MultimodalProjector to Gemma3, enabling joint image–text forward passes.
    2. Updated ModelConfig and parameter mapping to support multimodal checkpoints (multimodal=True).
  • DPO Example Notebook: Introduced examples/vl_dpo_demo_gemma3.ipynb, demonstrating multimodal DPO fine-tuning with image-conditioned prompts using multimodal Gemma3.
  • Gemma3 Tokenizer Adapter: Extended TokenizerAdapter to support Hugging Face Processor objects.
  • VLM Sampler and Utils:
    1. Added VLMSampler for multimodal generation (tunix/generate/vlm_sampler.py)
    2. Added preprocess_image() utility in tunix/generate/utils.py for SigLIP/CLIP-style normalization.
  • DPO Pipeline Updates: Modified tunix/sft/dpo/dpo_trainer.py to handle multimodal data ({"text", "image"} prompts) and propagate image tensors through training (see the sketch after this list).
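As a concrete reference, here is a hypothetical sketch of a single training example in the multimodal prompt format described above. Only the {"text", "image"} prompt structure comes from this PR; every other field name and the image shape are illustrative assumptions, not the actual schema.

import numpy as np

# Hypothetical multimodal DPO example. The {"text", "image"} prompt layout is
# what the updated dpo_trainer.py is described as handling; the remaining keys
# and the image shape are illustrative only.
example = {
    "prompts": {
        "text": "Describe what is happening in the image.",
        # Preprocessed pixel tensor, e.g. the output of preprocess_image().
        "image": np.zeros((896, 896, 3), dtype=np.float32),
    },
    "chosen": "A dog is catching a frisbee in a park.",
    "rejected": "The image shows an empty room.",
}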

@Tianjiao-Yu led this effort and @jxiong21029 contributed to the Gemma3 integration. Please also tag @Tianjiao-Yu with any questions, comments, or feedback.

Colab Notebook
vl_dpo_demo_gemma3.ipynb

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed the Contribution Guidelines.

@google-cla

google-cla bot commented Oct 6, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@jxiong21029 jxiong21029 force-pushed the uiuc-vlm-pr-compressed-fixed branch from 052a484 to 03a1804 Compare October 9, 2025 15:37
@Tianjiao-Yu Tianjiao-Yu force-pushed the uiuc-vlm-pr-compressed-fixed branch from 03a1804 to 0b1d000 Compare October 10, 2025 02:48
@abheesht17
Collaborator

@immuntasir - let me know when this is ready for review!

@abheesht17 abheesht17 self-requested a review October 10, 2025 04:05
@immuntasir
Contributor Author

> @immuntasir - let me know when this is ready for review!

@Tianjiao-Yu confirmed that this is ready for review.

Contributor Author


@Tianjiao-Yu I think this should be removed from this PR.

Collaborator


+1

Collaborator

@abheesht17 abheesht17 left a comment


Quick review, I'll do another pass tomorrow

"source": [
"# Fine-tuning a Visual Language Model (VLM) using DPO\n",
"\n",
"This notebook demonstrates how to fine-tune a Visual Language Model (VLM), specifically the Gemma 3-1B-it model, using the Direct Preference Optimization (DPO) algorithm.\n",
Collaborator


Gemma 3-1B-it model

This is a text-only model though. 4B onwards are VLMs


This can be used when inputs are raw strings. Tokenization, padding and
preprocessing is taken care of by `DPOTrainer`.
preprocessing is taken care of by `DpoTrainer`.
Collaborator


Revert?

Collaborator


+1

elif self._tokenizer_type == TokenizerType.HFP:
  inputs = self._tokenizer(text=text, **kwargs)
  if 'images' in kwargs:
    return inputs['input_ids'], inputs['pixel_values']
Collaborator


Better to return a dictionary here rather than a tuple (in case we add more modalities later)?
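A minimal sketch of that suggestion, assuming the same processor call as in the snippet above; the helper name is hypothetical and the actual keys may differ:

def _encode_multimodal(processor, text, **kwargs):
  # Hypothetical helper sketching the dict-return idea: additional modalities
  # become extra keys, so the return type never changes.
  inputs = processor(text=text, **kwargs)
  outputs = {'input_ids': inputs['input_ids']}
  if 'images' in kwargs:
    outputs['pixel_values'] = inputs['pixel_values']
  return outputs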

Comment on lines 29 to +30
HF: str = 'hf' # huggingface tokenizer
HFP: str = 'hfp' # huggingface processor
Collaborator


Is the only difference between these two that the processor can take images, and other modalities too? If yes, do you think we should just use HF processor everywhere (and remove HF tokeniser)?

Because if processor(text) works, we can just use processor everywhere


I don't think every tokenizer has an associated processor definition, so it probably makes sense to have both.

Comment on lines 30 to 32
# Defaults compatible with CLIP / many SigLIP configs; override if needed.
_CLIP_MEAN = jnp.array([0.48145466, 0.4578275, 0.40821073], dtype=jnp.float32)
_CLIP_STD = jnp.array([0.26862954, 0.26130258, 0.27577711], dtype=jnp.float32)
Collaborator


Do you think we can move it to models/siglip?

    mean: Iterable[float] = _CLIP_MEAN,
    std: Iterable[float] = _CLIP_STD,
) -> jnp.ndarray:
  """Resize + normalize images for SigLIP.
Collaborator


Just SigLIP? Does it not work for other vision models? In generate/utils.py, we should have generic functions (as much as possible)
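For illustration, a generic resize-and-normalize sketch of what such a utility could look like, assuming inputs scaled to [0, 1]; this is not the PR's actual preprocess_image implementation, and the default image size is illustrative:

import jax
import jax.numpy as jnp

_CLIP_MEAN = jnp.array([0.48145466, 0.4578275, 0.40821073], dtype=jnp.float32)
_CLIP_STD = jnp.array([0.26862954, 0.26130258, 0.27577711], dtype=jnp.float32)

def preprocess_image(images, image_size=224, mean=_CLIP_MEAN, std=_CLIP_STD):
  # images: [..., H, W, 3] floats in [0, 1]. Resize to a square resolution and
  # apply per-channel normalization. The CLIP-style defaults can be overridden
  # for other vision encoders, which is what keeps the utility generic.
  target_shape = (*images.shape[:-3], image_size, image_size, images.shape[-1])
  images = jax.image.resize(images, target_shape, method="bilinear")
  return (images - mean) / std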


if self.config.multimodal:
  assert pixel_values is not None
  image_mask = last_tokens == 262144  # 262144: <image_soft_token>
Collaborator


Better to define this somewhere instead of hardcoding
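A minimal sketch of that suggestion; only the 262144 value comes from the diff above, while the constant name and helper are hypothetical:

import jax.numpy as jnp

# Name the special id once instead of hardcoding it at the call site.
IMAGE_SOFT_TOKEN_ID = 262144  # Gemma3 <image_soft_token>

def make_image_mask(last_tokens: jnp.ndarray) -> jnp.ndarray:
  # Boolean mask marking positions that hold image soft tokens.
  return last_tokens == IMAGE_SOFT_TOKEN_ID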

@jxiong21029

jxiong21029 commented Nov 1, 2025

Oh, I didn't mean to request so many reviews. Not sure how that happened. Maybe from the CLA failing?

@jxiong21029 jxiong21029 force-pushed the uiuc-vlm-pr-compressed-fixed branch 7 times, most recently from 5afdf7b to bed8fa7 Compare November 1, 2025 22:53
@jxiong21029

Rebased to fix emails for CLA.

Collaborator

@abheesht17 abheesht17 left a comment


Sorry, getting back to this. This LGTM! Could you please resolve the merge conflicts?

@Tianjiao-Yu

Hi, Abheesht, I made a small update to switch the demo to use Gemma-3-4B-IT, which is a true VLM, instead of Gemma-3-1B-IT. Could you please take a quick look when you have a moment? After that, I’ll proceed with resolving the remaining merge conflicts. Thanks!

@ridcl

ridcl commented Jan 4, 2026

Hi guys! First of all, thanks for adding multimodal support - it works great in my use case, and I can't wait to see it merged!

I noticed that the number of images is currently restricted to one per input, i.e. pixel_values is assumed to be of size [B, H, W, C], while the original Gemma implementation actually allows inputs of size [B, N, H, W, C]. More precisely, in the original Gemma implementation there are two versions of patchify_images:

  1. gemma.multimodal.image.patchify_images with signature (legacy?):
def patchify_images(
    images: typing.Float["B H W C"],
    patch_size: int = _DEFAULT_PATCH_SIZE,
    padding: str = "VALID",
) -> typing.Float["P D"]: ...
  2. gemma.multimodal.vision.patchify_images with signature (actually used in the transformer):
def patchify_images(self, images: Float["*B H W C"]) -> Float["*B P D"]: ...

From a quick test, I think that to add support for multiple images here it should be enough to add a batch dimension in a few places, like PatchEmbed.__call__():

  @jax.named_scope("patch_embed")
  def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
    # x: [*B,H,W,3] -> conv -> [*B,H/P,W/P,D] -> [*B,N,D]
    x = self.proj(x)
    *b, h, w, d = x.shape            # was: b, h, w, d = x.shape
    x = x.reshape(*b, h * w, d)      # was: x = x.reshape(b, h * w, d)
    x = shard(x, self.cfg.shd_config.act_bnd)
    return x

But my local branch diverged significantly from this PR, so I can't really test the full fix just yet.

@abheesht17
Collaborator

> Hi, Abheesht, I made a small update to switch the demo to use Gemma-3-4B-IT, which is a true VLM, instead of Gemma-3-1B-IT. Could you please take a quick look when you have a moment? After that, I’ll proceed with resolving the remaining merge conflicts. Thanks!

Looks good. Could you please give me edit access to this branch? I'll resolve merge conflicts and make a few changes (especially regarding the multiple images point). Thanks!

@ridcl

ridcl commented Jan 23, 2026

@abheesht17 For the multiple image support, you may want to look at this commit as a reference point (mostly files tunix/models/gemma3/model.py, tunix/models/siglip/model.py and tunix/models/siglip/preprocess.py, the rest should be formatting).

Another change that you might be interested in is saving LoRA params for multimodal Gemma. Alternatively, I can create a separate pull request for it after the current PR is merged.

Collaborator

@abheesht17 abheesht17 left a comment


I was going through this again, and found a few issues:

  • Image tokens should have bidirectional attention, but I don't see that in the code.
  • We should support multiple images.
  • We have a Hugging Face preprocessor ("hfp"), but we don't seem to be using it. Also, the special tokens in HF preprocessor/tokeniser are different from the upstream GDM implementation.
  • Gemma 3 uses special start of image tokens, end of image tokens, etc., which are not there in the code.

I have a WIP PR for resolving some of these issues. Give me some time.

_CLIP_STD = jnp.array([0.26862954, 0.26130258, 0.27577711], dtype=jnp.float32)


def preprocess(
Collaborator


I don't see this function being used anywhere

},
"outputs": [],
"source": [
"gemma_tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)"
Collaborator


Why can we not use the HF processor directly?

Comment on lines +182 to +183
"model_config = dataclasses.replace(\n",
" model_config, multimodal=True, num_embed=262208\n",
Collaborator


Why don't we just expose multimodal as an arg in gemma3_model_lib.ModelConfig.gemma3_4b(multimodal=True)?
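A sketch of that suggestion, using a heavily simplified stand-in for the real config class; only the 262208 value and the multimodal flag come from the notebook cell above, the text-only number is illustrative:

import dataclasses

@dataclasses.dataclass(frozen=True)
class ModelConfig:  # simplified stand-in for gemma3_model_lib.ModelConfig
  num_embed: int
  multimodal: bool = False

  @classmethod
  def gemma3_4b(cls, multimodal: bool = False) -> "ModelConfig":
    # 262208 is the value the notebook passes with multimodal=True; the
    # text-only number below is illustrative.
    return cls(num_embed=262208 if multimodal else 262144,
               multimodal=multimodal)

Callers could then write ModelConfig.gemma3_4b(multimodal=True) directly instead of wrapping the config in dataclasses.replace.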

@@ -927,18 +1001,26 @@ def __call__(
positions: jaxtyping.Array, # [B, L]
cache: Cache | None, # (sequence length L')
attention_mask: jaxtyping.Array, # [B, L, L']
Collaborator

@abheesht17 abheesht17 Jan 26, 2026


Gemma 3 is supposed to have bidirectional attention for image tokens, but I don't see that here, or in the VLM DPO notebook.
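To make the point concrete, here is an illustrative sketch (not the PR's code, and not necessarily how Gemma 3 builds its masks) that combines a causal mask with bidirectional attention among image tokens. It treats all image positions as a single block, so with multiple images per prompt you would further restrict pairs to tokens of the same image:

import jax.numpy as jnp

def causal_with_bidirectional_images(image_mask: jnp.ndarray) -> jnp.ndarray:
  # image_mask: [B, L] bool, True at image-token positions.
  # Returns a [B, L, L] bool mask: text tokens attend causally, while image
  # tokens additionally attend to one another bidirectionally.
  seq_len = image_mask.shape[-1]
  causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, :, :]
  image_pairs = image_mask[:, :, None] & image_mask[:, None, :]
  return causal | image_pairs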

@ridcl ridcl mentioned this pull request Feb 7, 2026

Development

Successfully merging this pull request may close these issues.

VLM support for DPO
