[Tunix] Support scanned to unscanned weight transfer in transfer_state_directly#1008
Draft
…e_directly Dynamically detecting scan dim + path caching. adding explicit cleanup.
wang2yn84
reviewed
Feb 3, 2026
def _slice_scanned_param(
    src_val: Any, tgt_val: Any, slice_idx: int, key_path: str
Collaborator
Can you use more detailed types instead of Any?
def _slice_scanned_param(
    src_val: Any, tgt_val: Any, slice_idx: int, key_path: str
) -> Any:
  """Slices a scanned parameter dynamically detecting the scan axis."""
Collaborator
Can you write a more detailed docstring, and maybe also include descriptions of the inputs and outputs?
    src: Any, tgt_spec: Any, path: str = ''
) -> Tuple[Any, Any]:
  # Stop recursion if we hit a leaf (non-dict)
# Helper: Intersect Trees (Handle KVCache/RNG mismatches and Scanned Layers)
Collaborator
Can you fold it into the docstring?
) -> Tuple[Any, Any]:
  # Stop recursion if we hit a leaf (non-dict)
# Helper: Intersect Trees (Handle KVCache/RNG mismatches and Scanned Layers)
def intersect_trees(src: Any, tgt_spec: Any) -> Tuple[Any, Any]:
Collaborator
I understand it was there before your PR, but can you still add the detailed types?
  # Stop recursion if we hit a leaf (non-dict)
# Helper: Intersect Trees (Handle KVCache/RNG mismatches and Scanned Layers)
def intersect_trees(src: Any, tgt_spec: Any) -> Tuple[Any, Any]:
  """Optimized intersection using flat dictionary traversal."""

try:
  return src_val[slice_idx]
except (IndexError, TypeError):
Collaborator
Can you add more debugging information for this case?
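One way to address the typing, docstring, and debugging comments together is sketched below. It is not the PR's implementation: the function name `slice_scanned_param_sketch`, the use of NumPy arrays as a stand-in for the real array type, and the shape-matching rule for finding the scan axis are all assumptions made for illustration.

```python
from typing import Sequence

import numpy as np


def slice_scanned_param_sketch(
    src_val: np.ndarray,
    tgt_shape: Sequence[int],
    slice_idx: int,
    key_path: str,
) -> np.ndarray:
  """Extracts one layer from a stacked (scanned) array.

  Hypothetical helper: the scan axis is taken to be the first axis of
  `src_val` whose removal yields `tgt_shape`.

  Args:
    src_val: Stacked source array with one extra scan axis.
    tgt_shape: Shape of the per-layer target parameter.
    slice_idx: Index of the layer to extract along the scan axis.
    key_path: Parameter path, used only in error messages.

  Returns:
    The slice of `src_val` whose shape equals `tgt_shape`.

  Raises:
    ValueError: If no axis matches or `slice_idx` is out of range.
  """
  src_shape = tuple(src_val.shape)
  tgt_shape = tuple(tgt_shape)
  for axis in range(len(src_shape)):
    # Removing the scan axis from the source shape must give the target shape.
    if src_shape[:axis] + src_shape[axis + 1:] == tgt_shape:
      if not 0 <= slice_idx < src_shape[axis]:
        raise ValueError(
            f"{key_path}: slice_idx={slice_idx} is out of range for scan "
            f"axis {axis} of size {src_shape[axis]} (src shape {src_shape})."
        )
      return np.take(src_val, slice_idx, axis=axis)
  raise ValueError(
      f"{key_path}: cannot find a scan axis; src shape {src_shape} does not "
      f"reduce to tgt shape {tgt_shape} (slice_idx={slice_idx})."
  )
```

The error messages carry the key path, both shapes, and the slice index, which is usually enough to locate a mismatched parameter. When several axes have the same size, this sketch picks the first match, so an explicit scan axis would be safer in that case.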
candidate_b.pop(match_index)
candidate_b = tuple(candidate_b)
if candidate_b in src_flat:
Collaborator
This duplicates the candidate a code. Consider making it simpler?
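A possible shape for that simplification, as a rough sketch only: build the candidate paths up front and run a single lookup loop over them. The names `KeyPath`, `drop_component`, and `find_source_path`, and the example paths, are hypothetical. How the PR actually constructs candidates a and b is not visible in this snippet.

```python
from typing import Any, Iterable, Mapping, Optional, Tuple

KeyPath = Tuple[str, ...]


def drop_component(path: KeyPath, index: int) -> KeyPath:
  """Returns `path` with the component at `index` removed."""
  return path[:index] + path[index + 1:]


def find_source_path(
    candidates: Iterable[KeyPath], src_flat: Mapping[KeyPath, Any]
) -> Optional[KeyPath]:
  """Returns the first candidate present in `src_flat`, or None."""
  for candidate in candidates:
    if candidate in src_flat:
      return candidate
  return None


# Hypothetical flat source tree with a stacked (scanned) kernel.
src_flat = {("decoder", "layers", "mlp", "kernel"): "stacked"}
tgt_path = ("decoder", "layers", "3", "mlp", "kernel")

# Try the exact path first, then the path without the layer index.
candidates = [tgt_path, drop_component(tgt_path, 2)]
match = find_source_path(candidates, src_flat)
```

With this layout, adding another candidate rule means appending to the list rather than copying the lookup branch.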
This PR extends the transfer_state_directly utility to support weight synchronization from a scanned MaxText model (where layers are stacked in a single tensor) to unscanned MaxText + vLLM models (where layers are separate parameters).

Previously, transfer_state_directly only supported 1-to-1 mapping (Unscanned -> Unscanned). This change adds logic to detect and unroll scanned layers during the transfer process.
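To make the scanned versus unscanned distinction concrete, here is a minimal sketch of unrolling a stacked parameter tree into per-layer parameters. It is not the code in this PR: the function name `unroll_scanned_layers`, the `layers_{i}` naming, the default scan axis of 0, and the use of NumPy arrays in place of the real array type are assumptions chosen for illustration.

```python
from typing import Dict

import numpy as np


def unroll_scanned_layers(
    stacked: Dict[str, np.ndarray],
    num_layers: int,
    scan_axis: int = 0,
    prefix: str = "layers_",
) -> Dict[str, Dict[str, np.ndarray]]:
  """Splits stacked per-layer arrays into one parameter dict per layer."""
  unrolled: Dict[str, Dict[str, np.ndarray]] = {}
  for i in range(num_layers):
    # Take layer i from every stacked array along the scan axis.
    unrolled[f"{prefix}{i}"] = {
        name: np.take(arr, i, axis=scan_axis) for name, arr in stacked.items()
    }
  return unrolled


# Scanned layout: 3 layers stacked on axis 0 of a single (3, 2, 4) tensor.
scanned = {"kernel": np.arange(24).reshape(3, 2, 4)}

# Unscanned layout: layers_0, layers_1, layers_2, each with a (2, 4) kernel.
unscanned = unroll_scanned_layers(scanned, num_layers=3)
```

The sketch takes the scan axis as an argument, whereas the commit message in this PR mentions detecting the scan dimension dynamically.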