Pass group_size to RMSNormGated in Mamba2Simple for ngroups > 1#910
Open
Chessing234 wants to merge 1 commit intostate-spaces:mainfrom
Open
Pass group_size to RMSNormGated in Mamba2Simple for ngroups > 1#910Chessing234 wants to merge 1 commit intostate-spaces:mainfrom
Chessing234 wants to merge 1 commit intostate-spaces:mainfrom
Conversation
Mamba2Simple forwards ngroups throughout the block — d_in_proj, conv, and the B/C reshape all scale with ngroups — but the final RMSNormGated is constructed without group_size, which defaults to a single group spanning d_inner. The non-simple Mamba2 in mamba2.py handles this correctly by passing group_size=d_ssm // ngroups. With ngroups=1 (the default) the two paths are equivalent, but ngroups > 1 silently normalizes across groups that should be independent. Mirror the mamba2.py behavior so Mamba2Simple is correct for any ngroups; the default ngroups=1 case is unchanged since d_inner // 1 == d_inner is equivalent to group_size=None. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Bug
Mamba2Simplethreadsngroupsthrough the rest of the block —d_in_projincludes2 * ngroups * d_state, the conv dim isd_inner + 2 * ngroups * d_state, andB/Care rearranged withg=self.ngroups— but the finalRMSNormGatedis built withoutgroup_size:https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2_simple.py#L120
```python
self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)
```
RMSNormGatedtreatsgroup_size=Noneas "one group covering the whole hidden_size" (seemamba_ssm/ops/triton/layernorm_gated.pyline 418-419), so withngroups > 1the gated RMSNorm collapses all heads into a single normalization group, rather than normalizing each group independently.The non-simple Mamba2 handles this correctly in
mamba2.py:```python
self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
group_size=self.d_ssm // ngroups, **factory_kwargs)
```
Fix
Mirror the
mamba2.pypattern and passgroup_size=self.d_inner // ngroups. With the defaultngroups=1this evaluates tod_inner, which is equivalent togroup_size=None, so the default configuration is unchanged. Users who constructMamba2Simple(..., ngroups=g)withg > 1now get group-wise normalization consistent withMamba2.