Uiuc vlm pr compressed fixed #511
Changes from all commits: 73e41a0 … 1c4d1ff (35 commits).
First changed file (tokenizer wrapper):

```diff
@@ -27,6 +27,7 @@
 class TokenizerType(enum.Enum):
   SP: str = 'sp'  # sentencepiece tokenizer
   HF: str = 'hf'  # huggingface tokenizer
+  HFP: str = 'hfp'  # huggingface processor
   NONE: str = 'none'  # Represents no tokenizer
```
```diff
@@ -42,6 +43,8 @@ def __init__(self, tokenizer: Any):
       self._tokenizer_type = TokenizerType.SP
     elif self._is_hf_tokenizer():
       self._tokenizer_type = TokenizerType.HF
+    elif self._is_hf_processor():
+      self._tokenizer_type = TokenizerType.HFP
     elif not missing_methods:
       self._tokenizer_type = TokenizerType.NONE
     else:
```
```diff
@@ -54,11 +57,16 @@ def __init__(self, tokenizer: Any):
           f'{missing_methods}.'
       )
 
-  def encode(self, text: str, **kwargs) -> list[int]:
+  def encode(self, text: str, **kwargs) -> list[int] | tuple[list[int], Any]:
     if self._tokenizer_type == TokenizerType.SP:
       return self._tokenizer.EncodeAsIds(text, **kwargs)
     elif self._tokenizer_type == TokenizerType.HF:
       return self._tokenizer.encode(text, **kwargs)
+    elif self._tokenizer_type == TokenizerType.HFP:
+      inputs = self._tokenizer(text=text, **kwargs)
+      if 'images' in kwargs:
+        return inputs['input_ids'], inputs['pixel_values']
+      return inputs['input_ids']
     else:
       return self._tokenizer.encode(text, **kwargs)
```

Collaborator (on the tuple return in the HFP branch): Better to return a dictionary here rather than a tuple (in case we add more modalities later)?
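A minimal sketch of the dictionary-returning variant the reviewer suggests, shown as a free function for illustration; the helper name `encode_mm` is hypothetical, and only the `input_ids` / `pixel_values` keys come from the diff above:

```python
from typing import Any


def encode_mm(processor: Any, text: str, **kwargs) -> dict[str, Any]:
  """Hypothetical HFP encode that keys outputs by modality instead of a tuple."""
  inputs = processor(text=text, **kwargs)
  out = {'input_ids': inputs['input_ids']}
  if 'images' in kwargs:
    out['pixel_values'] = inputs['pixel_values']
  # Future modalities (audio, video, ...) would just add more keys here.
  return out
```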
```diff
@@ -67,6 +75,8 @@ def decode(self, ids: list[int], **kwargs) -> str:
       return self._tokenizer.DecodeIds(ids, **kwargs)
     elif self._tokenizer_type == TokenizerType.HF:
       return self._tokenizer.decode(ids, **kwargs)
+    elif self._tokenizer_type == TokenizerType.HFP:
+      return self._tokenizer.tokenizer.decode(ids, **kwargs)
     else:
       return self._tokenizer.decode(ids, **kwargs)
```
```diff
@@ -75,6 +85,8 @@ def bos_id(self) -> int:
       return self._tokenizer.bos_id()
     elif self._tokenizer_type == TokenizerType.HF:
       return self._tokenizer.bos_token_id
+    elif self._tokenizer_type == TokenizerType.HFP:
+      return self._tokenizer.tokenizer.bos_token_id
     else:
       return self._tokenizer.bos_id()
```
```diff
@@ -83,6 +95,8 @@ def eos_id(self) -> int:
       return self._tokenizer.eos_id()
     elif self._tokenizer_type == TokenizerType.HF:
       return self._tokenizer.eos_token_id
+    elif self._tokenizer_type == TokenizerType.HFP:
+      return self._tokenizer.tokenizer.eos_token_id
     else:
       return self._tokenizer.eos_id()
```
```diff
@@ -98,6 +112,8 @@ def pad_id(self) -> int:
       if self._tokenizer.pad_token_id is None:
         self._tokenizer.pad_token = self._tokenizer.eos_token
       return self._tokenizer.pad_token_id
+    elif self._tokenizer_type == TokenizerType.HFP:
+      return self._tokenizer.tokenizer.pad_token_id
     else:
       return self._tokenizer.pad_id()
```
```diff
@@ -124,12 +140,19 @@ def _is_hf_tokenizer(self) -> bool:
         baseclass.__module__ + '.' + baseclass.__name__
         for baseclass in baseclasses
     ]
-    if (
+    return (
         'transformers.tokenization_utils_base.PreTrainedTokenizerBase'
         in baseclass_names
-    ):
-      return True
-    return False
+    )
+
+  def _is_hf_processor(self) -> bool:
+    """Checks if the tokenizer is a huggingface Processor."""
+    baseclasses = inspect.getmro(type(self._tokenizer))
+    baseclass_names = [
+        baseclass.__module__ + '.' + baseclass.__name__
+        for baseclass in baseclasses
+    ]
+    return 'transformers.processing_utils.ProcessorMixin' in baseclass_names
 
   @property
   def tokenizer(self) -> Any:
```
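The two `_is_hf_*` checks above are purely MRO-based. A self-contained sketch of the same idea with stand-in classes (the real code matches against `transformers.processing_utils.ProcessorMixin` and `PreTrainedTokenizerBase`):

```python
import inspect


class ProcessorMixinStandIn:  # stands in for transformers.processing_utils.ProcessorMixin
  pass


class MyProcessor(ProcessorMixinStandIn):
  pass


def has_base(obj, qualified_name: str) -> bool:
  """True if any class in obj's MRO has the given 'module.ClassName' name."""
  names = [c.__module__ + '.' + c.__name__ for c in inspect.getmro(type(obj))]
  return qualified_name in names


print(has_base(MyProcessor(), '__main__.ProcessorMixinStandIn'))   # True
print(has_base('plain string', '__main__.ProcessorMixinStandIn'))  # False
```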
Second changed file (Gemma3 model):
```diff
@@ -98,6 +98,7 @@
   num_heads: int
   head_dim: int
   num_kv_heads: int
+  multimodal: bool = False
   sliding_window_size: int | None = None
   local_base_frequency: int = 10_000
   global_base_frequency: int = 10_000
```
```diff
@@ -877,11 +878,86 @@ def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
     return normed_inputs
 
 
+class MultimodalProjector(nnx.Module):
+  """Image soft token pooling + projection."""
+
+  IMAGE_SOFT_TOKEN_ID: int = 262144
+
+  def __init__(
+      self,
+      vision_embed_dim: int,
+      text_embed_dim: int,
+      patches_per_side: int,
+      output_tokens_per_side=16,
+      *,
+      rngs: nnx.Rngs,
+      shd_config: ShardingConfig = ShardingConfig.get_default_sharding(),
+  ):
+    self.patches_per_side = patches_per_side
+    self.output_tokens_per_side = output_tokens_per_side
+    self.output_tokens_total = output_tokens_per_side * output_tokens_per_side
+    self.kernel_size = patches_per_side // output_tokens_per_side
+
+    self.mm_soft_emb_norm = RMSNorm(
+        vision_embed_dim,
+        rngs=rngs,
+        sharding=shd_config.rms_norm_weight,
+    )
+    self.mm_input_projection = nnx.Linear(
+        in_features=vision_embed_dim,
+        out_features=text_embed_dim,
+        use_bias=False,
+        rngs=rngs,
+        kernel_init=nnx.with_partitioning(
+            nnx.initializers.zeros_init(), shd_config.ffw_weight_df
+        ),
+    )
+
+  @jax.named_scope('multimodal_projector')
+  def __call__(self, x: jaxtyping.Array) -> jaxtyping.Array:
+    B, _, D = x.shape
+    x = x.reshape(B, self.patches_per_side, self.patches_per_side, D)
+    x = nnx.avg_pool(
+        x,
+        window_shape=(self.kernel_size, self.kernel_size),
+        strides=(self.kernel_size, self.kernel_size),
+    )
+    x = x.reshape(B, self.output_tokens_total, D)
+    x = self.mm_soft_emb_norm(x)
+    x = self.mm_input_projection(x)
+    return x
+
+
 class Gemma3(nnx.Module):
-  """Gemma3 transformer."""
+  """Gemma transformer."""
 
   def __init__(self, config: ModelConfig, *, rngs: nnx.Rngs):
     self.config = config
 
+    if config.multimodal:
+      from tunix.models.siglip.model import SigLIPConfig, SigLIPEngine  # pylint: disable=g-import-not-at-top
+
+      self.siglip = SigLIPEngine(
+          cfg=SigLIPConfig(
+              image_size=896,
+              patch_size=14,
+              embed_dim=1152,
+              depth=27,
+              num_heads=16,
+              mlp_hidden_dim=4304,
+              use_cls_token=False,
+              use_abs_pos_emb=True,
+          ),
+          rngs=rngs,
+      )
+      self.projector = MultimodalProjector(
+          1152,
+          config.embed_dim,
+          64,
+          rngs=rngs,
+          shd_config=config.shd_config,
+      )
+
     self.embedder = Embedder(
         vocab_size=config.num_embed,
         embed_dim=config.embed_dim,
```
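The dimensions here line up as follows: with `image_size=896` and `patch_size=14`, SigLIP produces 896 / 14 = 64 patches per side, i.e. 64 × 64 = 4096 embeddings of width 1152 per image, and the projector average-pools them with a 64 // 16 = 4 kernel down to 16 × 16 = 256 image soft tokens before projecting to `embed_dim`. A standalone sketch of just the pooling step (zeros in place of real vision features, no Tunix imports):

```python
import jax.numpy as jnp
from flax import nnx

B, patches_per_side, D = 2, 64, 1152       # SigLIP output: 64 * 64 = 4096 patches per image
out_per_side = 16                          # target: 16 * 16 = 256 image soft tokens
kernel = patches_per_side // out_per_side  # 4

x = jnp.zeros((B, patches_per_side * patches_per_side, D))
x = x.reshape(B, patches_per_side, patches_per_side, D)
x = nnx.avg_pool(x, window_shape=(kernel, kernel), strides=(kernel, kernel))
x = x.reshape(B, out_per_side * out_per_side, D)
print(x.shape)  # (2, 256, 1152); a bias-free Linear then maps 1152 -> embed_dim
```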
```diff
@@ -927,18 +1003,26 @@ def __call__(
       positions: jaxtyping.Array,  # [B, L]
       cache: Cache | None,  # (sequence length L')
       attention_mask: jaxtyping.Array,  # [B, L, L']
+      pixel_values: jaxtyping.Array | None = None,  # [B, H, W, C]
       output_hidden_states: bool = False,
   ) -> tuple[jaxtyping.Array, Cache | None]:
     """Transformer forward pass.
 
     You can run this forward pass two ways: with or without an attention kv
     cache.
 
+    Note: for multimodal (image + text) inputs: last_tokens is expected to be
+    already preprocessed to contain exactly 256 <image_soft_token> (id=262144)
+    per tokenized input, and attention_mask is expected to already have been
+    adjusted for image tokens, i.e. image tokens attend to all tokens in the
+    (same) image bidirectionally in addition to attending to all previous tokens
+
     Args:
       last_tokens: input sequence of tokens.
       positions: input absolute positions.
       cache: Attention KV cache or None.
       attention_mask: transformer input mask.
+      pixel_values: (preprocessed) images for multimodal, None for text-only.
       output_hidden_states: whether to output the hidden states.
 
     Returns:
```

Collaborator (on `attention_mask`): Gemma 3 is supposed to have bidirectional attention for image tokens, but I don't see that here, or in the VLM DPO notebook.
```diff
@@ -947,8 +1031,29 @@ def __call__(
       predicted_logits: output logits predicted by the model
       new_cache: updated cache if the input cache is not None, None elsewhere.
     """
     new_cache = None if cache is None else {}
-    x = self.embedder.encode(last_tokens)
+
+    if self.config.multimodal:
+      assert pixel_values is not None
+      image_mask = last_tokens == 262144  # 262144: <image_soft_token>
+
+      vision_outputs = self.siglip(pixel_values)  # B, 4096, 1152
+      image_features = self.projector(vision_outputs)  # B, 256, embed_dim
+
+      last_tokens = jnp.where(image_mask, 0, last_tokens)
+      x = self.embedder.encode(last_tokens)
+      image_features = image_features.astype(x.dtype)
+
+      # Write image features to embedded input
+      idx = jnp.cumsum(image_mask, axis=1) - 1
+      idx = jnp.where(image_mask, idx, 0)
+      gathered = jnp.take_along_axis(image_features, idx[..., None], axis=1)
+      x = jnp.where(image_mask[..., None], gathered, x)
+
+    else:
+      x = self.embedder.encode(last_tokens)
 
     for i, layer in enumerate(self.layers):
       layer_name = f'layer_{i}'
       layer_cache = cache[layer_name] if cache else None
```

Collaborator (on the hardcoded 262144): Better to define this somewhere instead of hardcoding.
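A toy illustration of the cumsum / `take_along_axis` scatter above, small enough to inspect by hand; token id 9 stands in for the real `<image_soft_token>` (262144), which, per the review comment, would ideally live behind a named constant:

```python
import jax.numpy as jnp

IMAGE_TOKEN = 9  # stand-in constant for <image_soft_token> (262144)

tokens = jnp.array([[1, 9, 9, 2]])                       # [B=1, L=4]
x = jnp.arange(12, dtype=jnp.float32).reshape(1, 4, 3)   # [B, L, D] text embeddings
image_features = jnp.full((1, 2, 3), -1.0)               # [B, num image tokens, D]

image_mask = tokens == IMAGE_TOKEN          # [[False, True, True, False]]
idx = jnp.cumsum(image_mask, axis=1) - 1    # [[-1, 0, 1, 1]]: running image-feature index
idx = jnp.where(image_mask, idx, 0)         # [[0, 0, 1, 0]]: clamp text slots to a valid index
gathered = jnp.take_along_axis(image_features, idx[..., None], axis=1)
x = jnp.where(image_mask[..., None], gathered, x)  # only the image slots are overwritten
print(x)  # positions 1 and 2 now hold image features; 0 and 3 keep the text embeddings
```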
```diff
@@ -989,7 +1094,36 @@ def get_model_input(self):
             (dummy_batch_size, 1, dummy_seq_len), dtype=jnp.bool
         ),
     }
 
+  @staticmethod
+  def make_mm_attention_mask(
+      input_ids: jaxtyping.Array,  # [B, L]
+      input_mask: jaxtyping.Array,  # [B, L] (1 for valid tokens, 0 for pad)
+  ) -> jaxtyping.Array:
+    """Builds Gemma3 multimodal attention mask.
+
+    - Base causal attention
+    - Text can attend to image tokens
+    - Image tokens attend bidirectionally to other image tokens
+    - Padding respected
+    """
+    # Base causal mask (already handles pad keys)
+    from tunix.rl import common  # local import avoids circular deps
+
+    attn = common.make_causal_attn_mask(input_mask)  # [B, L, L]
+
+    image_mask = (input_ids == Gemma3.IMAGE_SOFT_TOKEN_ID)  # [B, L]
+
+    # Allow any query to attend to image keys
+    attn = attn | image_mask[:, None, :]
+
+    # Fully open image <-> image attention
+    attn = attn | (image_mask[:, :, None] & image_mask[:, None, :])
+
+    return attn
+
   @property
   def embed_dim(self) -> int:
     return self.embedder.embed_dim
```
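For reference, a self-contained version of the same mask logic; the causal mask is rebuilt with `jnp.tril` here instead of `tunix.rl.common.make_causal_attn_mask` so the snippet runs on its own, and token id 9 again stands in for the image soft token:

```python
import jax.numpy as jnp

IMAGE_TOKEN = 9

input_ids = jnp.array([[1, 9, 9, 2, 0]])            # [B, L]; trailing 0 is padding
input_mask = jnp.array([[1, 1, 1, 1, 0]], bool)     # 1 for valid tokens, 0 for pad

L = input_ids.shape[1]
causal = jnp.tril(jnp.ones((L, L), bool))           # [L, L] lower-triangular causal mask
attn = causal[None, :, :] & input_mask[:, None, :]  # [B, L, L], pad keys masked out

image_mask = input_ids == IMAGE_TOKEN               # [B, L]
attn = attn | image_mask[:, None, :]                # any query may attend to image keys
attn = attn | (image_mask[:, :, None] & image_mask[:, None, :])  # image <-> image, bidirectional
print(attn.astype(jnp.int32))
```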
Review comment: Is the only difference between these two (the HF tokenizer and the HF processor) that the processor can take images, and other modalities too? If yes, do you think we should just use the HF processor everywhere (and remove the HF tokeniser)? Because if `processor(text)` works, we can just use the processor everywhere.

Reply: I don't think every tokenizer has an associated processor definition, so it probably makes sense to have both.
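Regarding using the processor everywhere: a hedged sketch of the text-only call being discussed. The checkpoint name is illustrative, and whether every model ships a processor whose text-only output matches its tokenizer is exactly the open question above:

```python
from transformers import AutoProcessor

# Illustrative multimodal checkpoint; assumes a processor is published for it.
processor = AutoProcessor.from_pretrained('google/gemma-3-4b-it')

out = processor(text='hello world')  # text-only call, expected to delegate to the tokenizer
print(out['input_ids'])
```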