Skip to content

Conversation

@hanming-lu
Copy link
Collaborator

@hanming-lu hanming-lu commented Dec 10, 2025

Motivation

  • Currently, qwen3-next doesn't support overlap scheduler or branching point caching

Modifications

  1. Support overlap scheduler for qwen3-next
  2. Support branching point caching for qwen3-next
  3. tested for (ps = 1, ps > 1) x (non-spec dec, sd topk1, sd topk>1). All work except for ps > 1 + sd topk > 1, which is not supported on main yet
  4. enable 1) and 2) by --enable-mamba-radix-cache-v2
  5. Better memory allocation for ssm spec dec - instead of coupling spec dec intermediate state size with total ssm states, couple it with max running requests.

Accuracy Tests

  1. Added mamba radix cache KL tests for prefill and decode
  2. cover both --enable-mamba-radix-cache-v2 on and off

Benchmarking and Profiling

Checklist

@gemini-code-assist
Copy link
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@github-actions github-actions bot added the npu label Dec 10, 2025
@hanming-lu hanming-lu changed the title [Qwen3-next] Prefix cache for qwen3-next [Qwen3-next] radix cache v2 for qwen3-next Dec 11, 2025
@hanming-lu
Copy link
Collaborator Author

/tag-and-rerun-ci

mask = (req.extend_input_len // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE > 0
mamba_track_mask_cpu.append(mask)
mamba_track_indices_cpu.append(
req.mamba_ping_pong_track_buffer[req.mamba_next_track_idx].item()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it will cause device/host sync here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will it break overlap schedule?

Copy link
Collaborator Author

@hanming-lu hanming-lu Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checked the profile, before it, some functions also triggers sync such as alloc_for_extend(). This new one (the one with arrow) is pretty small compared to existing ones, so no impact.
Screenshot 2025-12-10 at 9 24 08 PM

swa_full_tokens_ratio: float = 0.8
disable_hybrid_swa_memory: bool = False
radix_eviction_policy: str = "lru"
mamba_track_interval: int = 256
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate define

cached_tokens = result["meta_info"]["cached_tokens"]
if cache_hit:
assert (
cached_tokens > 0
Copy link
Collaborator

@yizhang2077 yizhang2077 Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the shape of tree in test can be more complex, and cached_tokens we can directly predict

self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

if get_global_server_args().enable_mamba_radix_cache_v2:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can also be wrapped in _mamba_radix_cache_v2_prepare_for_extend


# copy mamba state to req local space if cow is true
if cow_mamba and last_node.mamba_value is not None:
assert req.req_pool_idx is None # req_pool_idx is uninitialed
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this assertion

# does not have a mamba value.
if len(value) > best_value_len:
fla_chunk_aligned_seqlen = (
sum(len(v) for v in value) // FLA_CHUNK_SIZE
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we do not use (len(value) // FLA_CHUN_SIZE) * FLA_CHUN_SIZE directly here?

# to retrieve its state from h. Adding 1 will give us the correct index in h,
# otherwise the calculation will retrieve the state from the last_recurrent_state,
# which is not correct.
mamba_track_seqlen = req.mamba_branching_seqlen + 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need +1 in branching point while in non-branching point we do not need?

self.last_node,
self.last_host_node,
self.host_hit_length,
self.mamba_branching_seqlen,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think when disable radix cache, it will cause error for extra input/output

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants