[Tunix] Support scanned to unscanned weight transfer in transfer_state_directly #1008
base: main
2efc8f5 to 93b4730
93b4730 to bbf5b3b
bbf5b3b to 03cc987
03cc987 to 4df97fc
…e_directly Dynamically detecting scan dim + path caching. adding explicit cleanup.
4df97fc to feb485d
def _slice_scanned_param(
    src_val: Any, tgt_val: Any, slice_idx: int, key_path: str
Can you use more detailed types instead of Any?
def _slice_scanned_param(
    src_val: Any, tgt_val: Any, slice_idx: int, key_path: str
) -> Any:
  """Slices a scanned parameter dynamically detecting the scan axis."""
Can you write a more detailed docstring? Maybe also include descriptions of the inputs and outputs?
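For context on what "dynamically detecting the scan axis" could involve, here is a minimal sketch with the kind of input/output documentation requested above. It is not the PR's implementation: the helper name find_scan_axis and the shape-comparison rule are assumptions made for illustration. The idea is that a scanned source has exactly one more dimension than the unscanned target, and the scan axis is the one whose removal yields the target shape.

```python
from typing import Optional, Sequence


def find_scan_axis(
    src_shape: Sequence[int], tgt_shape: Sequence[int]
) -> Optional[int]:
  """Finds the axis of a scanned source that stacks the layers.

  Hypothetical helper for illustration only.

  Args:
    src_shape: Shape of the scanned (stacked) source parameter.
    tgt_shape: Shape of one unscanned (per-layer) target parameter.

  Returns:
    The first axis of src_shape whose removal gives tgt_shape, or None if
    the source does not have exactly one extra dimension or no axis matches.
  """
  src = tuple(src_shape)
  tgt = tuple(tgt_shape)
  if len(src) != len(tgt) + 1:
    return None
  for axis in range(len(src)):
    # Drop one axis and compare the remainder with the target shape.
    if src[:axis] + src[axis + 1:] == tgt:
      return axis
  return None
```

Note that a purely shape-based rule is ambiguous when the layer count equals a neighboring dimension (for example source (4, 4, 8) and target (4, 8)); this sketch returns the first matching axis in that case.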
    src: Any, tgt_spec: Any, path: str = ''
) -> Tuple[Any, Any]:
# Stop recursion if we hit a leaf (non-dict)
# Helper: Intersect Trees (Handle KVCache/RNG mismatches and Scanned Layers)
Can you fold it into the docstring?
) -> Tuple[Any, Any]:
# Stop recursion if we hit a leaf (non-dict)
# Helper: Intersect Trees (Handle KVCache/RNG mismatches and Scanned Layers)
def intersect_trees(src: Any, tgt_spec: Any) -> Tuple[Any, Any]:
I understand it was there before your PR, but can you still add the detailed types?
# Stop recursion if we hit a leaf (non-dict)
# Helper: Intersect Trees (Handle KVCache/RNG mismatches and Scanned Layers)
def intersect_trees(src: Any, tgt_spec: Any) -> Tuple[Any, Any]:
  """Optimized intersection using flat dictionary traversal."""
Ditto
try:
  return src_val[slice_idx]

except (IndexError, TypeError):
Can you add more debugging information for this case?
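One way to act on this, sketched under assumptions: the wrapper name slice_with_context is hypothetical, and whether the real code should re-raise or fall back to another strategy in the except branch is not visible in the diff. The sketch re-raises with the key path, index, type, and shape attached so a failure can be traced to a specific parameter.

```python
from typing import Any


def slice_with_context(src_val: Any, slice_idx: int, key_path: str) -> Any:
  """Indexes src_val at slice_idx, adding context to any failure.

  Hypothetical wrapper for illustration; not the PR's code.
  """
  try:
    return src_val[slice_idx]
  except (IndexError, TypeError) as e:
    # Arrays expose .shape; plain Python containers do not.
    shape = getattr(src_val, "shape", None)
    raise ValueError(
        f"Failed to slice parameter {key_path!r} at index {slice_idx}: "
        f"type={type(src_val).__name__}, shape={shape}"
    ) from e
```

Chaining with "from e" keeps the original exception in the traceback, so the low-level cause is not lost.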
candidate_b.pop(match_index)
candidate_b = tuple(candidate_b)

if candidate_b in src_flat:
This duplicates the code for candidate a; consider simplifying it?
This PR extends the transfer_state_directly utility to support weight synchronization from a scanned MaxText model (where layers are stacked in a single tensor) to unscanned MaxText + vLLM models (where layers are separate parameters).

Previously, transfer_state_directly only supported 1-to-1 mapping (Unscanned -> Unscanned). This change adds logic to detect and unroll scanned layers during the transfer process.
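To make the scanned versus unscanned distinction concrete, here is a minimal sketch of unrolling a stacked parameter into per-layer parameters. It uses NumPy arrays in place of JAX arrays, and the function name unroll_scanned_layers, the flat tuple key paths, and the layers_{i} naming are all assumptions for illustration; the actual key layout used by MaxText and vLLM is not shown in this PR page.

```python
from typing import Dict, Tuple

import numpy as np

FlatParams = Dict[Tuple[str, ...], np.ndarray]


def unroll_scanned_layers(
    scanned: FlatParams, num_layers: int, scan_axis: int = 0
) -> FlatParams:
  """Splits stacked per-layer tensors into one entry per layer.

  Hypothetical sketch: a key such as ('layers', 'kernel') holding a tensor
  with num_layers entries along scan_axis becomes ('layers_0', 'kernel'),
  ('layers_1', 'kernel'), and so on.
  """
  unrolled: FlatParams = {}
  for path, stacked in scanned.items():
    if stacked.shape[scan_axis] != num_layers:
      raise ValueError(
          f"{path}: expected {num_layers} layers on axis {scan_axis}, "
          f"got shape {stacked.shape}"
      )
    for i in range(num_layers):
      layer_path = (f"{path[0]}_{i}",) + path[1:]
      # np.take with a scalar index removes the scan axis from the result.
      unrolled[layer_path] = np.take(stacked, i, axis=scan_axis)
  return unrolled


# Usage: three layers stacked on axis 0, each with a (2, 4) kernel.
scanned_params = {("layers", "kernel"): np.arange(24).reshape(3, 2, 4)}
per_layer = unroll_scanned_layers(scanned_params, num_layers=3)
```

The 1-to-1 case described above needs no such step, since source and target keys and shapes already line up.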