Changes from all commits (35 commits)
73e41a0
Add SigLIP and Paligemma code
Tianjiao-Yu Sep 8, 2025
191dbd7
Add multimodal support for Gemma 3
jxiong21029 Sep 13, 2025
756c3d6
Begin implementing VL-DPO trainer
immuntasir Sep 13, 2025
154c464
Add DPO demo
Tianjiao-Yu Sep 14, 2025
a583fbe
Add dpo_demo_gemma3_v2.ipynb
Tianjiao-Yu Sep 14, 2025
3880a16
Edit VL-DPO trainer
Tianjiao-Yu Sep 15, 2025
d48af92
Add fixes for dpo_demo_gemma3_v2.ipynb and create dpo_demo_gemma3_v3.…
Tianjiao-Yu Sep 16, 2025
969d930
Modify tests for vl_dpo_trainer
immuntasir Sep 22, 2025
100dea3
Begin integrating image inputs support in DPOTrainer
jxiong21029 Sep 25, 2025
611962d
Remove unused files and update vl_dpo_demo_gemma3
immuntasir Sep 29, 2025
f4038a9
restore siglip code (needed for multimodal Gemma models)
jxiong21029 Nov 1, 2025
d45d0d6
Add SigLIP and Paligemma code
Tianjiao-Yu Oct 10, 2025
5413a10
Add multimodal support for Gemma 3
jxiong21029 Sep 13, 2025
74de87a
Begin implementing VL-DPO trainer
immuntasir Sep 13, 2025
1c26c2f
Add DPO demo
Tianjiao-Yu Oct 10, 2025
4327978
Add dpo_demo_gemma3_v2.ipynb
Tianjiao-Yu Oct 10, 2025
9490b0f
Edit VL-DPO trainer
Tianjiao-Yu Oct 10, 2025
dc6510a
Add fixes for dpo_demo_gemma3_v2.ipynb and create dpo_demo_gemma3_v3.…
Tianjiao-Yu Oct 10, 2025
844c7a0
Modify tests for vl_dpo_trainer
immuntasir Sep 22, 2025
583194f
Begin integrating image inputs support in DPOTrainer
jxiong21029 Sep 25, 2025
e5e2c36
Remove unused files and update vl_dpo_demo_gemma3
immuntasir Sep 29, 2025
0a463ee
Fix: gemma3 modelconfig imported with wrong class name
Tianjiao-Yu Oct 10, 2025
f548009
Fix safetensors loading for multimodal Gemma3 models
jxiong21029 Nov 1, 2025
9220769
VL-DPO notebook fixes
jxiong21029 Nov 1, 2025
97fd343
fix
jxiong21029 Nov 1, 2025
6350957
More VL-DPO notebook fixes
jxiong21029 Nov 1, 2025
0c46733
Remove VLMSampler; move preprocess_image from generate/utils.py to VL…
jxiong21029 Nov 1, 2025
5d3ecec
Delete examples/dpo_demo_gemma3.ipynb
jxiong21029 Nov 1, 2025
8c3e370
Revert dpo_demo_gemma3.ipynb
jxiong21029 Nov 1, 2025
bed8fa7
Restore SigLIP
jxiong21029 Nov 1, 2025
afe0d52
Fix multimodal DPO by forwarding pixel_values through logp computation
Tianjiao-Yu Dec 15, 2025
9c118e8
Forward pixel_values in get_per_token_logps for multimodal DPO
Tianjiao-Yu Dec 15, 2025
369ef14
fix notebook example for multimodal demo
Tianjiao-Yu Dec 16, 2025
d170ceb
Fix merge conflicts
abheesht17 Jan 26, 2026
1c4d1ff
Add Gemma3 multimodal attention mask, update DPO trainer, remove unus…
Tianjiao-Yu Feb 6, 2026
667 changes: 667 additions & 0 deletions examples/vl_dpo_demo_gemma3.ipynb

Large diffs are not rendered by default.

33 changes: 28 additions & 5 deletions tunix/generate/tokenizer_adapter.py
@@ -27,6 +27,7 @@
class TokenizerType(enum.Enum):
SP: str = 'sp' # sentencepiece tokenizer
HF: str = 'hf' # huggingface tokenizer
HFP: str = 'hfp' # huggingface processor
Comment on lines 29 to +30

Collaborator: Is the only difference between these two that the processor can take images, and other modalities too? If yes, do you think we should just use HF processor everywhere (and remove HF tokeniser)? Because if processor(text) works, we can just use processor everywhere.

Reply: I don't think every tokenizer has an associated processor definition, so it probably makes sense to have both.
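
For context, a minimal sketch of the usage this thread assumes: a HF processor handling both text-only and text-plus-image calls. The checkpoint id, the '<start_of_image>' placeholder, and the claim that this particular processor accepts text-only input are illustrative assumptions, not something verified in this PR.

from PIL import Image
from transformers import AutoProcessor

# Assumed multimodal Gemma 3 checkpoint; any processor-backed model would do.
processor = AutoProcessor.from_pretrained('google/gemma-3-4b-it')

# Text-only call: behaves like the underlying tokenizer.
text_batch = processor(text='Describe the image.', return_tensors='np')
print(text_batch['input_ids'].shape)

# Text + image call: additionally returns pixel_values.
image = Image.new('RGB', (896, 896))
mm_batch = processor(
    text='<start_of_image> Describe the image.', images=image, return_tensors='np'
)
print(mm_batch['input_ids'].shape, mm_batch['pixel_values'].shape)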

NONE: str = 'none' # Represents no tokenizer


@@ -42,6 +43,8 @@ def __init__(self, tokenizer: Any):
self._tokenizer_type = TokenizerType.SP
elif self._is_hf_tokenizer():
self._tokenizer_type = TokenizerType.HF
elif self._is_hf_processor():
self._tokenizer_type = TokenizerType.HFP
elif not missing_methods:
self._tokenizer_type = TokenizerType.NONE
else:
@@ -54,11 +57,16 @@ def __init__(self, tokenizer: Any):
f'{missing_methods}.'
)

def encode(self, text: str, **kwargs) -> list[int]:
def encode(self, text: str, **kwargs) -> list[int] | tuple[list[int], Any]:
if self._tokenizer_type == TokenizerType.SP:
return self._tokenizer.EncodeAsIds(text, **kwargs)
elif self._tokenizer_type == TokenizerType.HF:
return self._tokenizer.encode(text, **kwargs)
elif self._tokenizer_type == TokenizerType.HFP:
inputs = self._tokenizer(text=text, **kwargs)
if 'images' in kwargs:
return inputs['input_ids'], inputs['pixel_values']
Collaborator: Better to return a dictionary here rather than a tuple (in case we add more modalities later)?
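
A minimal sketch of the dictionary-based return the comment suggests (a hypothetical refactor, not part of this diff; the helper name and key names are assumptions):

def _encode_with_processor(processor, text: str, **kwargs) -> dict:
  """Hypothetical dict-returning variant of the HFP branch above."""
  inputs = processor(text=text, **kwargs)
  out = {'input_ids': inputs['input_ids']}
  if 'images' in kwargs:
    out['pixel_values'] = inputs['pixel_values']
  # Future modalities (e.g. audio) would simply add more keys here.
  return out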

return inputs['input_ids']
else:
return self._tokenizer.encode(text, **kwargs)

@@ -67,6 +75,8 @@ def decode(self, ids: list[int], **kwargs) -> str:
return self._tokenizer.DecodeIds(ids, **kwargs)
elif self._tokenizer_type == TokenizerType.HF:
return self._tokenizer.decode(ids, **kwargs)
elif self._tokenizer_type == TokenizerType.HFP:
return self._tokenizer.tokenizer.decode(ids, **kwargs)
else:
return self._tokenizer.decode(ids, **kwargs)

@@ -75,6 +85,8 @@ def bos_id(self) -> int:
return self._tokenizer.bos_id()
elif self._tokenizer_type == TokenizerType.HF:
return self._tokenizer.bos_token_id
elif self._tokenizer_type == TokenizerType.HFP:
return self._tokenizer.tokenizer.bos_token_id
else:
return self._tokenizer.bos_id()

@@ -83,6 +95,8 @@ def eos_id(self) -> int:
return self._tokenizer.eos_id()
elif self._tokenizer_type == TokenizerType.HF:
return self._tokenizer.eos_token_id
elif self._tokenizer_type == TokenizerType.HFP:
return self._tokenizer.tokenizer.eos_token_id
else:
return self._tokenizer.eos_id()

@@ -98,6 +112,8 @@ def pad_id(self) -> int:
if self._tokenizer.pad_token_id is None:
self._tokenizer.pad_token = self._tokenizer.eos_token
return self._tokenizer.pad_token_id
elif self._tokenizer_type == TokenizerType.HFP:
return self._tokenizer.tokenizer.pad_token_id
else:
return self._tokenizer.pad_id()

@@ -124,12 +140,19 @@ def _is_hf_tokenizer(self) -> bool:
baseclass.__module__ + '.' + baseclass.__name__
for baseclass in baseclasses
]
if (
return (
'transformers.tokenization_utils_base.PreTrainedTokenizerBase'
in baseclass_names
):
return True
return False
)

def _is_hf_processor(self) -> bool:
"""Checks if the tokenizer is a huggingface Processor."""
baseclasses = inspect.getmro(type(self._tokenizer))
baseclass_names = [
baseclass.__module__ + '.' + baseclass.__name__
for baseclass in baseclasses
]
return 'transformers.processing_utils.ProcessorMixin' in baseclass_names

@property
def tokenizer(self) -> Any:
1 change: 1 addition & 0 deletions tunix/generate/utils.py
@@ -25,6 +25,7 @@
from flax import nnx
import jax
from jax import lax
import jax.image as jimg
import jax.numpy as jnp
import numpy as np

138 changes: 136 additions & 2 deletions tunix/models/gemma3/model.py
@@ -98,6 +98,7 @@ class ModelConfig:
num_heads: int
head_dim: int
num_kv_heads: int
multimodal: bool = False
sliding_window_size: int | None = None
local_base_frequency: int = 10_000
global_base_frequency: int = 10_000
@@ -877,11 +878,86 @@ def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
return normed_inputs


class MultimodalProjector(nnx.Module):
"""Image soft token pooling + projection."""

IMAGE_SOFT_TOKEN_ID: int = 262144

def __init__(
self,
vision_embed_dim: int,
text_embed_dim: int,
patches_per_side: int,
output_tokens_per_side=16,
*,
rngs: nnx.Rngs,
shd_config: ShardingConfig = ShardingConfig.get_default_sharding(),
):
self.patches_per_side = patches_per_side
self.output_tokens_per_side = output_tokens_per_side
self.output_tokens_total = output_tokens_per_side * output_tokens_per_side
self.kernel_size = patches_per_side // output_tokens_per_side

self.mm_soft_emb_norm = RMSNorm(
vision_embed_dim,
rngs=rngs,
sharding=shd_config.rms_norm_weight,
)
self.mm_input_projection = nnx.Linear(
in_features=vision_embed_dim,
out_features=text_embed_dim,
use_bias=False,
rngs=rngs,
kernel_init=nnx.with_partitioning(
nnx.initializers.zeros_init(), shd_config.ffw_weight_df
),
)

@jax.named_scope('multimodal_projector')
def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
B, _, D = x.shape
x = x.reshape(B, self.patches_per_side, self.patches_per_side, D)
x = nnx.avg_pool(
x,
window_shape=(self.kernel_size, self.kernel_size),
strides=(self.kernel_size, self.kernel_size),
)
x = x.reshape(B, self.output_tokens_total, D)
x = self.mm_soft_emb_norm(x)
x = self.mm_input_projection(x)
return x


class Gemma3(nnx.Module):
"""Gemma3 transformer."""
"""Gemma transformer."""

def __init__(self, config: ModelConfig, *, rngs: nnx.Rngs):
self.config = config

if config.multimodal:
from tunix.models.siglip.model import SigLIPConfig, SigLIPEngine # pylint: disable=g-import-not-at-top

self.siglip = SigLIPEngine(
cfg=SigLIPConfig(
image_size=896,
patch_size=14,
embed_dim=1152,
depth=27,
num_heads=16,
mlp_hidden_dim=4304,
use_cls_token=False,
use_abs_pos_emb=True,
),
rngs=rngs,
)
self.projector = MultimodalProjector(
1152,
config.embed_dim,
64,
rngs=rngs,
shd_config=config.shd_config,
)

self.embedder = Embedder(
vocab_size=config.num_embed,
embed_dim=config.embed_dim,
@@ -927,18 +1003,26 @@ def __call__(
positions: jaxtyping.Array, # [B, L]
cache: Cache | None, # (sequence length L')
attention_mask: jaxtyping.Array, # [B, L, L']
Collaborator @abheesht17 (Jan 26, 2026): Gemma 3 is supposed to have bidirectional attention for image tokens, but I don't see that here, or in the VLM DPO notebook.
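
For reference, a minimal sketch of the kind of mask adjustment being asked about: a causal mask in which image soft tokens additionally attend to each other bidirectionally. The make_mm_attention_mask helper added further down in this diff follows the same idea (and, like that helper, this sketch treats all image tokens in the sequence as one group rather than grouping them per image).

import jax.numpy as jnp

IMAGE_SOFT_TOKEN_ID = 262144  # <image_soft_token>, as used elsewhere in this diff

def causal_plus_image_bidirectional(input_ids, input_mask):
  """Sketch: input_ids [B, L] int, input_mask [B, L] bool (True = valid token)."""
  seq_len = input_ids.shape[-1]
  causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))  # [L, L]
  attn = causal[None, :, :] & input_mask[:, None, :]  # [B, L, L], pad keys masked
  image = input_ids == IMAGE_SOFT_TOKEN_ID  # [B, L]
  # Image tokens attend bidirectionally to the other image tokens.
  return attn | (image[:, :, None] & image[:, None, :])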

pixel_values: jaxtyping.Array | None = None, # [B, H, W, C]
output_hidden_states: bool = False,
) -> tuple[jaxtyping.Array, Cache | None]:
"""Transformer forward pass.

You can run this forward pass two ways: with or without an attention kv
cache.

Note: for multimodal (image + text) inputs, last_tokens is expected to have
already been preprocessed to contain exactly 256 <image_soft_token> (id=262144)
tokens per tokenized input, and attention_mask is expected to have already been
adjusted for image tokens, i.e. image tokens attend to all tokens in the
(same) image bidirectionally, in addition to attending to all previous tokens.

Args:
last_tokens: input sequence of tokens.
positions: input absolute positions.
cache: Attention KV cache or None.
attention_mask: transformer input mask.
pixel_values: (preprocessed) images for multimodal, None for text-only.
output_hidden_states: whether to output the hidden states.

Returns:
@@ -947,8 +1031,29 @@ def __call__(
predicted_logits: output logits predicted by the model
new_cache: updated cache if the input cache is not None, None elsewhere.
"""

new_cache = None if cache is None else {}
x = self.embedder.encode(last_tokens)

if self.config.multimodal:
assert pixel_values is not None
image_mask = last_tokens == 262144 # 262144: <image_soft_token>
Collaborator: Better to define this somewhere instead of hardcoding.
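
A one-line sketch of the suggestion, reusing the IMAGE_SOFT_TOKEN_ID constant already defined on MultimodalProjector earlier in this diff:

image_mask = last_tokens == MultimodalProjector.IMAGE_SOFT_TOKEN_ID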


vision_outputs = self.siglip(pixel_values) # B, 4096, 1152
image_features = self.projector(vision_outputs) # B, 256, embed_dim

last_tokens = jnp.where(image_mask, 0, last_tokens)
x = self.embedder.encode(last_tokens)
image_features = image_features.astype(x.dtype)

# Write image features to embedded input
idx = jnp.cumsum(image_mask, axis=1) - 1
idx = jnp.where(image_mask, idx, 0)
gathered = jnp.take_along_axis(image_features, idx[..., None], axis=1)
x = jnp.where(image_mask[..., None], gathered, x)

else:
x = self.embedder.encode(last_tokens)

for i, layer in enumerate(self.layers):
layer_name = f'layer_{i}'
layer_cache = cache[layer_name] if cache else None
@@ -989,7 +1094,36 @@ def get_model_input(self):
(dummy_batch_size, 1, dummy_seq_len), dtype=jnp.bool
),
}

@staticmethod
def make_mm_attention_mask(
input_ids: jaxtyping.Array, # [B, L]
input_mask: jaxtyping.Array, # [B, L] (1 for valid tokens, 0 for pad)
) -> jaxtyping.Array:
"""Builds Gemma3 multimodal attention mask.

- Base causal attention
- Text can attend to image tokens
- Image tokens attend bidirectionally to other image tokens
- Padding respected
"""

# Base causal mask (already handles pad keys)
from tunix.rl import common # local import avoids circular deps

attn = common.make_causal_attn_mask(input_mask) # [B, L, L]

image_mask = (input_ids == Gemma3.IMAGE_SOFT_TOKEN_ID) # [B, L]

# Allow any query to attend to image keys
attn = attn | image_mask[:, None, :]

# Fully open image <-> image attention
attn = attn | (image_mask[:, :, None] & image_mask[:, None, :])

return attn

@property
def embed_dim(self) -> int:
return self.embedder.embed_dim
71 changes: 65 additions & 6 deletions tunix/models/gemma3/params.py
@@ -59,7 +59,7 @@ def create_model_from_checkpoint(
lambda: model_lib.Gemma3(model_config, rngs=nnx.Rngs(0))
)
params = ocp.StandardCheckpointer().restore(checkpoint_path)
params = map_from_upstream_checkpoint(params)
params = map_from_upstream_checkpoint(params, multimodal=model_config.multimodal)
if mesh is not None:
params = jax.tree.map(
lambda x, shd: jnp.asarray(x, device=shd, dtype=dtype),
@@ -88,7 +88,9 @@ def create_tokenizer(
return spm_processor


def map_from_upstream_checkpoint(params, model_type: str = 'gemma3'):
def map_from_upstream_checkpoint(
params, model_type: str = 'gemma3', multimodal: bool = False
):
"""Map from upstream checkpoint to our implementation."""
# From:
#
@@ -127,13 +129,70 @@ def map_from_upstream_checkpoint(params, model_type: str = 'gemma3'):
module_path, param_name = key_path
module_path = module_path.split('/')[1:] # Remove the leading 'transformer'
if module_path[0] == 'siglip_encoder':
continue # We don't support MM input yet.
if module_path[0] == 'embedder':
if len(module_path) > 1 and module_path[1].startswith('mm_'):
continue # We don't support MM input yet.
if not multimodal:
continue
if param_name == 'pos_embedding':
new_params[('siglip', 'pos_embed')] = value
continue
elif module_path[1] == 'embedding':
new_params[('siglip', 'patch', 'proj', param_name)] = value
continue
elif module_path[2] == 'encoder_norm':
new_params[('siglip', 'norm', param_name)] = value
continue

assert module_path[2].startswith('encoderblock_')
siglip_layer = (
'siglip',
'blocks',
int(module_path[2].removeprefix('encoderblock_')),
)

if module_path[3] == 'LayerNorm_0':
new_params[(*siglip_layer, 'ln1', param_name)] = value
elif module_path[3] == 'LayerNorm_1':
new_params[(*siglip_layer, 'ln2', param_name)] = value
elif module_path[3] == 'MultiHeadDotProductAttention_0':
if module_path[4] == 'out':
if value.ndim == 3:
value = value.reshape(-1, value.shape[-1])
else:
value = value.reshape(-1)
new_params[(*siglip_layer, 'attn', 'o', param_name)] = value
else:
if value.ndim == 3:
value = value.reshape(value.shape[0], -1)
else:
value = value.reshape(-1)
if module_path[4] == 'query':
new_params[(*siglip_layer, 'attn', 'q', param_name)] = value
elif module_path[4] == 'key':
new_params[(*siglip_layer, 'attn', 'k', param_name)] = value
else:
assert module_path[4] == 'value'
new_params[(*siglip_layer, 'attn', 'v', param_name)] = value
elif module_path[3:] == ['MlpBlock_0', 'Dense_0']:
new_params[(*siglip_layer, 'mlp', 'fc1', param_name)] = value
else:
assert module_path[3:] == ['MlpBlock_0', 'Dense_1']
new_params[(*siglip_layer, 'mlp', 'fc2', param_name)] = value
continue

if (
module_path[0] == 'embedder'
and len(module_path) > 1
and module_path[1].startswith('mm_')
):
if multimodal:
if module_path[1] == 'mm_soft_embedding_norm':
new_params[('projector', 'mm_soft_emb_norm', param_name)] = value
elif module_path[1] == 'mm_input_projection':
new_params[('projector', 'mm_input_projection', 'kernel')] = value
continue
if module_path[0] in ('embedder', 'final_norm'):
new_params[(module_path[0], param_name)] = value
continue

# module_path should now look like ('layer_0', 'attn', '_key_norm')
layer_idx = ('layers', int(module_path[0].removeprefix('layer_')))
if module_path[1:] == ['mlp', 'gating_einsum']: