Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Force-pushed 052a484 to 03a1804
Force-pushed 03a1804 to 0b1d000
@immuntasir - let me know when this is ready for review!
@Tianjiao-Yu confirmed that this is ready for review.
examples/dpo_demo_gemma3.ipynb
Outdated
@Tianjiao-Yu I think this should be removed from this PR.
abheesht17
left a comment
Quick review, I'll do another pass tomorrow
examples/vl_dpo_demo_gemma3.ipynb
Outdated
| "source": [ | ||
| "# Fine-tuning a Visual Language Model (VLM) using DPO\n", | ||
| "\n", | ||
| "This notebook demonstrates how to fine-tune a Visual Language Model (VLM), specifically the Gemma 3-1B-it model, using the Direct Preference Optimization (DPO) algorithm.\n", |
Gemma 3-1B-it model
This is a text-only model, though; 4B onwards are VLMs.
tunix/sft/dpo/dpo_trainer.py
Outdated
This can be used when inputs are raw strings. Tokenization, padding and
preprocessing is taken care of by `DPOTrainer`.
preprocessing is taken care of by `DpoTrainer`.
examples/dpo_demo_gemma3.ipynb
Outdated
elif self._tokenizer_type == TokenizerType.HFP:
  inputs = self._tokenizer(text=text, **kwargs)
  if 'images' in kwargs:
    return inputs['input_ids'], inputs['pixel_values']
Better to return a dictionary here rather than a tuple (in case we add more modalities later)?
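For illustration, the dict-shaped return might look roughly like this (a hedged sketch only; the helper name and its standalone form are hypothetical, not code from this PR):

def tokenize_multimodal(processor, text, **kwargs):
  """Hypothetical helper: always return a dict so new modalities can be added."""
  inputs = processor(text=text, **kwargs)
  out = {'input_ids': inputs['input_ids']}
  if 'images' in kwargs:
    out['pixel_values'] = inputs['pixel_values']
  # Later modalities would just add keys, e.g. out['input_features'] for audio.
  return out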
  HF: str = 'hf'  # huggingface tokenizer
  HFP: str = 'hfp'  # huggingface processor
Is the only difference between these two that the processor can take images, and other modalities too? If yes, do you think we should just use HF processor everywhere (and remove HF tokeniser)?
Because if processor(text) works, we can just use processor everywhere
I don't think every tokenizer has an associated processor definition, so it probably makes sense to have both.
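For context, a rough sketch of the difference (assuming the Hugging Face transformers API; the fallback logic below is illustrative, not what the PR does, and the model id is only an example):

from transformers import AutoProcessor, AutoTokenizer

model_id = "google/gemma-3-4b-it"  # assumption: a checkpoint that ships a processor

try:
    tok = AutoProcessor.from_pretrained(model_id)   # handles text + images
except (OSError, ValueError):
    tok = AutoTokenizer.from_pretrained(model_id)   # text-only fallback

text_batch = tok(text="Describe the image.", return_tensors="np")
# With a processor, images can be passed alongside text, e.g.:
# batch = tok(text="Describe the image.", images=pil_image, return_tensors="np")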
tunix/generate/utils.py
Outdated
# Defaults compatible with CLIP / many SigLIP configs; override if needed.
_CLIP_MEAN = jnp.array([0.48145466, 0.4578275, 0.40821073], dtype=jnp.float32)
_CLIP_STD = jnp.array([0.26862954, 0.26130258, 0.27577711], dtype=jnp.float32)
Do you think we can move it to models/siglip?
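If they did move, a minimal sketch of what a home under models/siglip could look like (the module path and names here are hypothetical; only the values come from the diff above):

# e.g. tunix/models/siglip/constants.py (hypothetical location)
import jax.numpy as jnp

# Normalization statistics commonly used by CLIP / SigLIP-style image towers.
IMAGE_MEAN = jnp.array([0.48145466, 0.4578275, 0.40821073], dtype=jnp.float32)
IMAGE_STD = jnp.array([0.26862954, 0.26130258, 0.27577711], dtype=jnp.float32)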
tunix/generate/utils.py
Outdated
    mean: Iterable[float] = _CLIP_MEAN,
    std: Iterable[float] = _CLIP_STD,
) -> jnp.ndarray:
  """Resize + normalize images for SigLIP.
Just SigLIP? Does it not work for other vision models? In generate/utils.py, we should have generic functions (as much as possible)
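A sketch of a more generic version, where the model-specific statistics are passed in by the caller rather than baked into generate/utils.py (the function name, defaults, and assumed shapes are illustrative, not the PR's code):

import jax
import jax.numpy as jnp

def resize_and_normalize(images, size=(224, 224), mean=0.5, std=0.5):
  """Resize images to `size` and normalize with caller-supplied statistics.

  `images`: float32 in [0, 1], shape [..., H, W, 3]. `mean`/`std` are scalars
  or per-channel arrays, so CLIP, SigLIP, or other stats can be injected.
  """
  *batch, _, _, c = images.shape
  resized = jax.image.resize(images, (*batch, *size, c), method="bilinear")
  return (resized - jnp.asarray(mean)) / jnp.asarray(std)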
if self.config.multimodal:
  assert pixel_values is not None
  image_mask = last_tokens == 262144  # 262144: <image_soft_token>
Better to define this somewhere instead of hardcoding
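For example, a named constant could live next to the other Gemma 3 special-token definitions (the module and names below are hypothetical; only the value comes from the diff above):

import jax.numpy as jnp

IMAGE_SOFT_TOKEN_ID = 262144  # Gemma 3 <image_soft_token>

def image_soft_token_mask(last_tokens: jnp.ndarray) -> jnp.ndarray:
  """True at positions holding the image soft token."""
  return last_tokens == IMAGE_SOFT_TOKEN_ID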
Oh, I didn't mean to request so many reviews. Not sure how that happened. Maybe from the CLA failing?
Force-pushed 5afdf7b to bed8fa7
Rebased to fix emails for CLA.
abheesht17
left a comment
Sorry, getting back to this. This LGTM! Could you please resolve the merge conflicts?
Hi Abheesht, I made a small update to switch the demo to use Gemma-3-4B-IT, which is a true VLM, instead of Gemma-3-1B-IT. Could you please take a quick look when you have a moment? After that, I’ll proceed with resolving the remaining merge conflicts. Thanks!
Hi guys! First of all, thanks for adding multimodal support - it works great in my use case, and I can't wait to see it merged! I noticed that the number of images is currently restricted to one per input, i.e.

def patchify_images(
    images: typing.Float["B H W C"],
    patch_size: int = _DEFAULT_PATCH_SIZE,
    padding: str = "VALID",
) -> typing.Float["P D"]: ...

def patchify_images(self, images: Float["*B H W C"]) -> Float["*B P D"]: ...

From a quick test, I think to add support for multiple images here it should be enough to add a batch dimension in a few places, like:

@jax.named_scope("patch_embed")
def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
  # x: [B,H,W,3] -> conv -> [B,H/P,W/P,D] -> [B,N,D]
  x = self.proj(x)
  *b, h, w, d = x.shape  # was: b, h, w, d = x.shape
  x = x.reshape(*b, h * w, d)  # was: x = x.reshape(b, h * w, d)
  x = shard(x, self.cfg.shd_config.act_bnd)
  return x

But my local branch diverged significantly from this PR, so I can't really test the full fix just yet.
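For what it's worth, the `*b` unpacking that the suggested change relies on handles both the single-image and multi-image layouts; a tiny standalone check (the shapes are made up for illustration):

import jax.numpy as jnp

x = jnp.zeros((2, 4, 14, 14, 64))  # e.g. [batch, num_images, H/P, W/P, D]
*b, h, w, d = x.shape              # b == [2, 4]; with a 4-D input, b == [2]
y = x.reshape(*b, h * w, d)
print(y.shape)                     # (2, 4, 196, 64)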
Looks good. Could you please give me edit access to this branch? I'll resolve merge conflicts and make a few changes (especially regarding the multiple images point). Thanks!
@abheesht17 For the multiple image support, you may want to look at this commit as a reference point (mostly files …). Another change that you might be interested in is saving LoRA params for multimodal Gemma. Alternatively, I can create a separate pull request for it after the current PR is merged.
I was going through this again, and found a few issues:
- Image tokens should have bidirectional attention, but I don't see that in the code (a mask sketch follows below).
- We should support multiple images.
- We have a Hugging Face preprocessor ("hfp"), but we don't seem to be using it. Also, the special tokens in the HF preprocessor/tokeniser are different from the upstream GDM implementation.
- Gemma 3 uses special start of image tokens, end of image tokens, etc., which are not there in the code.
I have a WIP PR for resolving some of these issues. Give me some time.
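On the first point, a minimal sketch of the masking idea (simplified: it lets all image tokens attend to each other and ignores per-image grouping and padding; the function name is hypothetical and this is not the Gemma 3 reference implementation):

import jax.numpy as jnp

def make_attention_mask(input_ids, image_token_id):
  """Causal mask for text tokens, full attention among image tokens.

  input_ids: [B, L] int array. Returns a boolean [B, L, L] mask.
  """
  _, seq_len = input_ids.shape
  causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))  # [L, L]
  is_image = input_ids == image_token_id                       # [B, L]
  bidirectional = is_image[:, :, None] & is_image[:, None, :]  # [B, L, L]
  return causal[None, :, :] | bidirectional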
tunix/models/siglip/preprocess.py
Outdated
_CLIP_STD = jnp.array([0.26862954, 0.26130258, 0.27577711], dtype=jnp.float32)


def preprocess(
I don't see this function being used anywhere
},
"outputs": [],
"source": [
"gemma_tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)"
Why can we not use the HF processor directly?
| "model_config = dataclasses.replace(\n", | ||
| " model_config, multimodal=True, num_embed=262208\n", |
Why don't we just expose multimodal as an arg in gemma3_model_lib.ModelConfig.gemma3_4b(multimodal=True)?
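A sketch of what that could look like (the dataclass below is a simplified stand-in for gemma3_model_lib.ModelConfig, not its real definition; 262208 comes from the notebook diff above, and the text-only default is illustrative):

import dataclasses

@dataclasses.dataclass(frozen=True)
class ModelConfig:  # simplified stand-in, not the real gemma3_model_lib.ModelConfig
  num_embed: int = 262144
  multimodal: bool = False

  @classmethod
  def gemma3_4b(cls, multimodal: bool = False) -> "ModelConfig":
    cfg = cls()
    if multimodal:
      # Enlarged embedding table to cover the image special tokens.
      cfg = dataclasses.replace(cfg, multimodal=True, num_embed=262208)
    return cfg

model_config = ModelConfig.gemma3_4b(multimodal=True)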
@@ -927,18 +1001,26 @@ def __call__(
      positions: jaxtyping.Array,  # [B, L]
      cache: Cache | None,  # (sequence length L')
      attention_mask: jaxtyping.Array,  # [B, L, L']
Gemma 3 is supposed to have bidirectional attention for image tokens, but I don't see that here, or in the VLM DPO notebook.
…ed SigLIP preprocess
Resolves #510
This PR introduces multimodal support to Tunix’s Gemma3 model and adds a new vision-language DPO demonstration notebook (vl_dpo_demo_gemma3.ipynb), extending the framework to handle image-text reasoning and multimodal alignment. Key changes include:
@Tianjiao-Yu led this effort and @jxiong21029 contributed to the Gemma3 integration. Please also mention @Tianjiao-Yu if you have any questions/comments/feedback.
Colab Notebook
vl_dpo_demo_gemma3.ipynb
Checklist