@gagika (Collaborator) commented Jan 28, 2026

This PR extends the transfer_state_directly utility to support weight synchronization from a scanned MaxText model (where the layers are stacked in a single tensor) to unscanned MaxText + vLLM models (where each layer is a separate parameter).

Previously, transfer_state_directly supported only a 1-to-1 mapping (Unscanned -> Unscanned). This change adds logic to detect scanned layers and unroll them during the transfer.
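
For readers unfamiliar with the scanned layout, here is a minimal sketch of what "unrolling" means. It is not the code from this PR: the helper name `unroll_scanned`, the `layers_<i>` naming, and the use of nested Python lists in place of real arrays are all assumptions made for illustration, and it assumes the layer axis is the leading one.

```python
from typing import Dict, List, Sequence

Matrix = List[List[float]]


def unroll_scanned(
    scanned: Dict[str, Sequence[Matrix]], prefix: str = "layers"
) -> Dict[str, Dict[str, Matrix]]:
    """Splits parameters stacked on a leading layer axis into per-layer dicts.

    Hypothetical helper: real checkpoints hold arrays, and the layer axis
    is not necessarily axis 0.
    """
    unrolled: Dict[str, Dict[str, Matrix]] = {}
    for name, stacked in scanned.items():
        # Each entry along the leading axis belongs to one layer.
        for layer_idx, layer_value in enumerate(stacked):
            unrolled.setdefault(f"{prefix}_{layer_idx}", {})[name] = layer_value
    return unrolled


# Two layers, each with a 1x2 kernel, stacked into one "tensor".
scanned_params = {"kernel": [[[1.0, 2.0]], [[3.0, 4.0]]]}
per_layer = unroll_scanned(scanned_params)
```

After the call, `per_layer` holds one entry per layer (`layers_0`, `layers_1`), each with its own `kernel`, which is the shape of state an unscanned target expects.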

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

…e_directly

Dynamically detecting scan dim + path caching.

adding explicit cleanup.


def _slice_scanned_param(
    src_val: Any, tgt_val: Any, slice_idx: int, key_path: str
Collaborator

Can you use more detailed types instead of Any?

def _slice_scanned_param(
    src_val: Any, tgt_val: Any, slice_idx: int, key_path: str
) -> Any:
  """Slices a scanned parameter dynamically detecting the scan axis."""
Collaborator

Can you write a more detailed docstring, and maybe also include descriptions of the inputs and outputs?
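
As an illustration of the kind of docstring being requested, here is a sketch of a shape-based scan-axis lookup with Args/Returns/Raises sections. The function `find_scan_axis` is hypothetical and is not taken from the PR; it only shows one plausible way to detect the scan axis by comparing the stacked source shape against the per-layer target shape.

```python
from typing import Tuple


def find_scan_axis(src_shape: Tuple[int, ...], tgt_shape: Tuple[int, ...]) -> int:
    """Finds which axis of a stacked source shape indexes the layers.

    Args:
      src_shape: Shape of the scanned (stacked) source parameter. It is
        expected to have exactly one more dimension than `tgt_shape`.
      tgt_shape: Shape of the corresponding per-layer target parameter.

    Returns:
      The lowest axis whose removal from `src_shape` yields `tgt_shape`.
      If several axes qualify (e.g. equal sizes), the first one wins.

    Raises:
      ValueError: If the ranks do not differ by one, or no axis matches.
    """
    src_shape, tgt_shape = tuple(src_shape), tuple(tgt_shape)
    if len(src_shape) != len(tgt_shape) + 1:
        raise ValueError(
            f"Expected source rank {len(tgt_shape) + 1}, got shape {src_shape}"
        )
    for axis in range(len(src_shape)):
        # Drop one axis and compare the remainder with the target shape.
        if src_shape[:axis] + src_shape[axis + 1:] == tgt_shape:
            return axis
    raise ValueError(f"No axis of {src_shape} can be dropped to get {tgt_shape}")
```

For example, `find_scan_axis((32, 4096, 128), (4096, 128))` picks axis 0, while `find_scan_axis((4096, 32, 128), (4096, 128))` picks axis 1.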

src: Any, tgt_spec: Any, path: str = ''
) -> Tuple[Any, Any]:
# Stop recursion if we hit a leaf (non-dict)
# Helper: Intersect Trees (Handle KVCache/RNG mismatches and Scanned Layers)
Collaborator

Can you fold it into the docstring?

) -> Tuple[Any, Any]:
# Stop recursion if we hit a leaf (non-dict)
# Helper: Intersect Trees (Handle KVCache/RNG mismatches and Scanned Layers)
def intersect_trees(src: Any, tgt_spec: Any) -> Tuple[Any, Any]:
Collaborator

I understand this was there before your PR, but can you still add the detailed types?

# Stop recursion if we hit a leaf (non-dict)
# Helper: Intersect Trees (Handle KVCache/RNG mismatches and Scanned Layers)
def intersect_trees(src: Any, tgt_spec: Any) -> Tuple[Any, Any]:
"""Optimized intersection using flat dictionary traversal."""
Collaborator

Ditto
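
To make the typing and docstring requests on intersect_trees concrete, here is a rough sketch of a flat-dictionary intersection with named type aliases. It is an assumption-laden illustration, not the PR's implementation: `KeyPath`, `FlatTree`, `flatten_tree`, and `intersect_flat` are hypothetical names, leaves are typed as `object` because the real array type is not shown in this thread, and scanned-layer handling is left out.

```python
from typing import Dict, Mapping, Tuple

KeyPath = Tuple[str, ...]          # e.g. ("decoder", "layers_0", "kernel")
FlatTree = Dict[KeyPath, object]   # leaf type is left generic in this sketch


def flatten_tree(tree: Mapping[str, object], prefix: KeyPath = ()) -> FlatTree:
    """Flattens nested dicts into a single dict keyed by tuple paths."""
    flat: FlatTree = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten_tree(value, path))
        else:
            flat[path] = value
    return flat


def intersect_flat(
    src: Mapping[str, object], tgt_spec: Mapping[str, object]
) -> Tuple[FlatTree, FlatTree]:
    """Keeps only the leaves whose key path exists in both trees.

    Args:
      src: Nested source state.
      tgt_spec: Nested target state or spec.

    Returns:
      A `(src_flat, tgt_flat)` pair restricted to the shared key paths, so
      entries present on one side only (such as a KV cache or RNG state)
      are dropped.
    """
    src_flat = flatten_tree(src)
    tgt_flat = flatten_tree(tgt_spec)
    shared = [path for path in tgt_flat if path in src_flat]
    return (
        {path: src_flat[path] for path in shared},
        {path: tgt_flat[path] for path in shared},
    )
```

Working on flat path-keyed dicts turns each lookup into a single hash probe, which is presumably what the "optimized" wording in the docstring under review refers to.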

try:
  return src_val[slice_idx]

except (IndexError, TypeError):
Collaborator

Can you add more debugging information for this failure case?
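
One possible shape for that extra debugging information is sketched below. The wrapper `slice_layer` is hypothetical, and re-raising as `ValueError` is an assumption: the PR's handler may instead fall back to another code path, which this excerpt does not show.

```python
from typing import Any


def slice_layer(src_val: Any, slice_idx: int, key_path: str) -> Any:
    """Indexes `src_val` and reports context if the lookup fails."""
    try:
        return src_val[slice_idx]
    except (IndexError, TypeError) as err:
        # `shape` exists on array types; plain Python objects report None.
        shape = getattr(src_val, "shape", None)
        raise ValueError(
            f"Failed to slice '{key_path}' at index {slice_idx}: "
            f"source type={type(src_val).__name__}, shape={shape}"
        ) from err
```

Including the key path, the requested index, and the source type and shape in the message makes it possible to tell an out-of-range layer index apart from a leaf that was never stacked in the first place.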

candidate_b.pop(match_index)
candidate_b = tuple(candidate_b)

if candidate_b in src_flat:
Collaborator

This duplicates the candidate_a code; consider simplifying it?
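
A sketch of one way to remove the duplication: generate both candidate paths in a helper and test them in a single loop. Only candidate B (dropping the path segment at `match_index`) is visible in the diff excerpt above; candidate A is assumed here to replace that segment with the scanned container name, and the helper names are hypothetical.

```python
from typing import Dict, Iterator, Optional, Tuple

KeyPath = Tuple[str, ...]


def scanned_candidates(
    tgt_path: KeyPath, match_index: int, scanned_name: str
) -> Iterator[KeyPath]:
    """Yields the source paths a per-layer target path could map to."""
    # Candidate A (assumed): rename the per-layer segment, e.g.
    # "layers_3" -> "layers".
    renamed = list(tgt_path)
    renamed[match_index] = scanned_name
    yield tuple(renamed)
    # Candidate B (as in the diff): drop the per-layer segment entirely.
    dropped = list(tgt_path)
    dropped.pop(match_index)
    yield tuple(dropped)


def find_scanned_source(
    tgt_path: KeyPath,
    match_index: int,
    scanned_name: str,
    src_flat: Dict[KeyPath, object],
) -> Optional[KeyPath]:
    """Returns the first candidate path present in `src_flat`, else None."""
    for candidate in scanned_candidates(tgt_path, match_index, scanned_name):
        if candidate in src_flat:
            return candidate
    return None
```

With this structure the membership check and any follow-up handling are written once, and adding a third naming convention later means adding one more `yield`.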
