Pass group_size to RMSNormGated in Mamba2Simple for ngroups > 1 #910

Open
Chessing234 wants to merge 1 commit into state-spaces:main from Chessing234:fix/mamba2-simple-ngroups-rmsnorm

Conversation

@Chessing234
Contributor

Bug

Mamba2Simple threads ngroups through the rest of the block — d_in_proj includes 2 * ngroups * d_state, the conv dim is d_inner + 2 * ngroups * d_state, and B/C are rearranged with g=self.ngroups — but the final RMSNormGated is built without group_size:

https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2_simple.py#L120

```python
self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)
```

RMSNormGated treats group_size=None as "one group covering the whole hidden_size" (see mamba_ssm/ops/triton/layernorm_gated.py, lines 418-419), so with ngroups > 1 the gated RMSNorm normalizes across all of d_inner as a single group rather than normalizing each group independently.
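To make the difference concrete, here is a minimal plain-Python sketch of grouped RMS normalization. It is not the Triton kernel: `rms_norm_grouped` is a hypothetical helper, and it leaves out the learned weight and the gating by `z`, so it only illustrates how `group_size=None` and `group_size=d_inner // ngroups` differ on a toy vector.

```python
import math

def rms_norm_grouped(x, group_size=None, eps=1e-5):
    """Hypothetical reference: RMS-normalize contiguous groups of a 1-D list.

    group_size=None is treated as one group spanning the whole vector.
    The learned weight and the gate are deliberately omitted.
    """
    n = len(x)
    if group_size is None:
        group_size = n
    if n % group_size != 0:
        raise ValueError("len(x) must be divisible by group_size")
    out = []
    for start in range(0, n, group_size):
        group = x[start:start + group_size]
        # Root mean square of this group only, with eps for stability.
        rms = math.sqrt(sum(v * v for v in group) / group_size + eps)
        out.extend(v / rms for v in group)
    return out

# Toy setup: d_inner = 4, ngroups = 2, groups at very different scales.
x = [1.0, 1.0, 100.0, 100.0]
single = rms_norm_grouped(x)                    # one group over all of d_inner
per_group = rms_norm_grouped(x, group_size=2)   # d_inner // ngroups
```

With a single group, the large-magnitude second group dominates the RMS and the first group is squashed toward zero. With per-group normalization, both groups come out at roughly unit scale.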

The non-simple Mamba2 handles this correctly in mamba2.py:

```python
self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
                         group_size=self.d_ssm // ngroups, **factory_kwargs)
```

Fix

Mirror the mamba2.py pattern and pass group_size=self.d_inner // ngroups. With the default ngroups=1 this evaluates to d_inner, which is equivalent to group_size=None, so the default configuration is unchanged. Users who construct Mamba2Simple(..., ngroups=g) with g > 1 now get group-wise normalization consistent with Mamba2.
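As a quick sanity check of the "default unchanged" claim, the sketch below uses a simplified stand-in for the norm. `rms_norm` is a hypothetical helper (no weight, no gate), not the library function, and `d_inner = 8` is an arbitrary illustrative size. It compares `group_size=d_inner // ngroups` against `group_size=None` for `ngroups=1` and `ngroups=2`.

```python
import math

def rms_norm(x, group_size=None, eps=1e-5):
    # Simplified stand-in: RMS-normalize contiguous groups of a 1-D list.
    # group_size=None means a single group over the whole vector.
    size = len(x) if group_size is None else group_size
    out = []
    for i in range(0, len(x), size):
        g = x[i:i + size]
        rms = math.sqrt(sum(v * v for v in g) / len(g) + eps)
        out += [v / rms for v in g]
    return out

d_inner = 8  # illustrative size, not taken from the PR
x = [float(i + 1) for i in range(d_inner)]

# ngroups = 1: d_inner // 1 == d_inner, same as leaving group_size unset.
default_unchanged = rms_norm(x, group_size=d_inner // 1) == rms_norm(x, group_size=None)

# ngroups = 2: the two halves are now normalized independently.
changed_for_two_groups = rms_norm(x, group_size=d_inner // 2) != rms_norm(x, group_size=None)
```

Under these assumptions both flags are `True`: the default path produces identical values, and only `ngroups > 1` changes the output.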

Mamba2Simple forwards ngroups throughout the block — d_in_proj, conv,
and the B/C reshape all scale with ngroups — but the final RMSNormGated
is constructed without group_size, which defaults to a single group
spanning d_inner.

The non-simple Mamba2 in mamba2.py handles this correctly by passing
group_size=d_ssm // ngroups. With ngroups=1 (the default) the two paths
are equivalent, but ngroups > 1 silently normalizes across groups that
should be independent.

Mirror the mamba2.py behavior so Mamba2Simple is correct for any
ngroups; the default ngroups=1 case is unchanged since
d_inner // 1 == d_inner is equivalent to group_size=None.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>