Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 139 additions & 38 deletions tests/generate/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def flat_state(self):
def from_flat_path(self, flat_path):
new_params = {}
for keys, param in flat_path:
new_params[".".join(keys)] = param.value
new_params[".".join(keys)] = (
param if hasattr(param, "value") else MockParam(param)
)
return MockState(new_params)


Expand All @@ -50,6 +52,31 @@ class MockParam:
def __init__(self, value):
self.value = value

@property
def shape(self):
# Delegate to the wrapped array so MockParam mimics a real param's interface.
return self.value.shape

@property
def dtype(self):
# Delegate dtype to the wrapped array.
return self.value.dtype

@property
def ndim(self):
# Delegate rank (number of dimensions) to the wrapped array.
return self.value.ndim

@property
def sharding(self):
# Delegate sharding to the wrapped value — assumes it is a jax.Array
# (plain numpy arrays have no .sharding); TODO confirm in callers.
return self.value.sharding

def __getitem__(self, item):
# Forward indexing/slicing to the wrapped value.
return self.value[item]

def __array__(self, dtype=None):
# NumPy array protocol: lets np.asarray()/np.array() unwrap the param.
return np.asarray(self.value, dtype=dtype)

def __jax_array__(self):
# JAX array protocol: lets jnp operations consume MockParam directly.
return self.value


class Logprob:

Expand Down Expand Up @@ -229,14 +256,14 @@ def test_transfer_state_with_mappings_tranpose_and_sharding_device(self):
expected_layer_0_weight = jnp.arange(16).reshape(2, 8).T * 2
self.assertTrue(
jnp.array_equal(
new_tgt_state.params["decoder.layer_0.weight"],
new_tgt_state.params["decoder.layer_0.weight"].value,
expected_layer_0_weight,
)
)
expected_layer_1_weight = jnp.arange(16, 32).reshape(2, 8).T
self.assertTrue(
jnp.array_equal(
new_tgt_state.params["encoder.layer_0.weight"],
new_tgt_state.params["encoder.layer_0.weight"].value,
expected_layer_1_weight,
)
)
Expand Down Expand Up @@ -297,7 +324,7 @@ def test_transfer_state_with_bias_padding_and_reshape(self):
# Verify shape
self.assertEqual(result.params[src_key].shape, (4, 128))
# Verify values are repeated correctly
self.assertTrue(jnp.allclose(result.params[src_key], 1.0))
self.assertTrue(jnp.allclose(result.params[src_key].value, 1.0))

def test_transfer_state_with_scanned_layers(self):
"""Comprehensive test for scanned layers covering multiple scenarios."""
Expand Down Expand Up @@ -400,7 +427,7 @@ def test_transfer_state_with_scanned_layers(self):
self.assertEqual(transferred.shape, (vocab_size, embed_dim))
self.assertTrue(
jnp.allclose(
transferred,
transferred.value,
jnp.full(
(vocab_size, embed_dim), layer_idx + 1, dtype=jnp.float32
),
Expand All @@ -420,7 +447,7 @@ def test_transfer_state_with_scanned_layers(self):

self.assertEqual(transferred.shape, (batch_size, vocab_size))
self.assertTrue(
jnp.allclose(transferred, expected),
jnp.allclose(transferred.value, expected),
f"Scanned bias layer {layer_idx} mismatch",
)

Expand All @@ -430,12 +457,96 @@ def test_transfer_state_with_scanned_layers(self):
self.assertEqual(transferred_embedding.shape, (embed_dim, vocab_size))
self.assertTrue(
jnp.allclose(
transferred_embedding,
transferred_embedding.value,
jnp.full((embed_dim, vocab_size), 99.0, dtype=jnp.float32),
),
"Regular parameter with transpose mismatch",
)

def test_transfer_state_with_mappings_gemma4(self):
"""Test transfer_state_with_mappings for Gemma4.

Builds a Tunix-style source state (fused kv_einsum plus MoE weights) and a
vLLM-style destination state, runs the Gemma4 VLLM_JAX_MAPPING through
utils.transfer_state_with_mappings, and checks:
  * kv_einsum.w is split along axis 0 into separate k/v projections and
    each is transposed with permutation (1, 0, 2);
  * the MoE gating and down-projection weights are copied through unchanged.
"""
from tunix.models.gemma4.mapping_vllm_jax import VLLM_JAX_MAPPING

# Mock source state (Tunix style)
src_params = {
# Fused K/V einsum weight: axis 0 indexes {k, v}.
"layers.0.attn.kv_einsum.w": MockParam(
jnp.arange(2 * 2 * 16 * 8, dtype=jnp.float32).reshape(2, 2, 16, 8)
),
"layers.0.moe.gating_einsum": MockParam(
jnp.arange(4 * 2 * 8 * 16, dtype=jnp.float32).reshape(4, 2, 8, 16)
),
"layers.0.moe.linear": MockParam(
jnp.arange(4 * 16 * 8, dtype=jnp.float32).reshape(4, 16, 8)
),
}
src_state = MockState(src_params)

# Mock destination state (vLLM style) — zero-filled buffers with the
# shapes the mapping is expected to produce.
dst_params = {
"model.layers.0.self_attn.k_proj.weight": MockParam(
jnp.zeros((16, 2, 8), dtype=jnp.float32)
),
"model.layers.0.self_attn.v_proj.weight": MockParam(
jnp.zeros((16, 2, 8), dtype=jnp.float32)
),
"model.layers.0.experts.kernel_gating_upproj_EDF": MockParam(
jnp.zeros((4, 2, 8, 16), dtype=jnp.float32)
),
"model.layers.0.experts.kernel_down_proj_EFD": MockParam(
jnp.zeros((4, 16, 8), dtype=jnp.float32)
),
}
dst_state = MockState(dst_params)

# Apply preprocessing if it exists in mapping
# (e.g. splitting fused weights before the key-by-key transfer).
if 'preprocess_src_state' in VLLM_JAX_MAPPING:
src_state = VLLM_JAX_MAPPING['preprocess_src_state'](src_state)

key_mappings = VLLM_JAX_MAPPING['to_hf_mappings']
transpose_keys = VLLM_JAX_MAPPING['to_hf_transpose_keys']

new_tgt_state = utils.transfer_state_with_mappings(
src_state,
dst_state,
key_mappings=key_mappings,
transpose_keys=transpose_keys,
)

# Assertions
# Rebuild the source tensor and slice out the k (index 0) and v (index 1)
# halves of the fused kv_einsum weight.
src_val = jnp.arange(2 * 2 * 16 * 8, dtype=jnp.float32).reshape(2, 2, 16, 8)
k_val_src = src_val[0]
v_val_src = src_val[1]

# The mapping is expected to transpose (kv_heads, D, H) -> (D, kv_heads, H).
expected_k = jnp.transpose(k_val_src, (1, 0, 2))
expected_v = jnp.transpose(v_val_src, (1, 0, 2))

# NOTE(review): these compare a params entry against a raw array —
# presumably MockParam coerces via its array protocol; other tests in this
# file compare `.value` explicitly. Verify this is intentional.
self.assertTrue(
jnp.array_equal(
new_tgt_state.params["model.layers.0.self_attn.k_proj.weight"],
expected_k,
)
)
self.assertTrue(
jnp.array_equal(
new_tgt_state.params["model.layers.0.self_attn.v_proj.weight"],
expected_v,
)
)

# MoE weights should transfer value-for-value with no transpose.
self.assertTrue(
jnp.array_equal(
new_tgt_state.params["model.layers.0.experts.kernel_gating_upproj_EDF"],
src_params["layers.0.moe.gating_einsum"].value,
)
)

self.assertTrue(
jnp.array_equal(
new_tgt_state.params["model.layers.0.experts.kernel_down_proj_EFD"],
src_params["layers.0.moe.linear"].value,
)
)

def test_verify_state_closeness(self):
"""Test verify_state_closeness function with various scenarios."""

Expand Down Expand Up @@ -1001,28 +1112,30 @@ def test_transfer_state_with_interleaved_scanned_layers(self):

self.assertTrue(
jnp.allclose(
new_tgt_state.params["decoder.layer.0.weight"], expected_layer_0
new_tgt_state.params["decoder.layer.0.weight"].value,
expected_layer_0,
),
"Interleaved layer 0 mismatch",
)
self.assertTrue(
jnp.allclose(
new_tgt_state.params["decoder.layer.2.weight"], expected_layer_2
new_tgt_state.params["decoder.layer.2.weight"].value,
expected_layer_2,
),
"Interleaved layer 2 mismatch",
)

# Layers 1 and 3 should remain zero (not mapped)
self.assertTrue(
jnp.allclose(
new_tgt_state.params["decoder.layer.1.weight"],
new_tgt_state.params["decoder.layer.1.weight"].value,
jnp.zeros((vocab_size, embed_dim), dtype=jnp.float32),
),
"Non-interleaved layer 1 should be zero",
)
self.assertTrue(
jnp.allclose(
new_tgt_state.params["decoder.layer.3.weight"],
new_tgt_state.params["decoder.layer.3.weight"].value,
jnp.zeros((vocab_size, embed_dim), dtype=jnp.float32),
),
"Non-interleaved layer 3 should be zero",
Expand Down Expand Up @@ -1401,21 +1514,20 @@ def test_sglang_jax_1d_kv_bias_alignment(self):

self.assertEqual(result.params[src_key].shape, (1024,))
expected = jnp.tile(src_k_bias, 8)
self.assertTrue(jnp.allclose(result.params[src_key], expected))

self.assertTrue(jnp.allclose(result.params[src_key].value, expected))

def test_transfer_state_directly_fuses_moe_weights(self):
"""Tests that wi_0 and wi_1 are fused into wi when target expects it."""
wi_0_val = jnp.array([[1.0, 2.0], [5.0, 6.0]], dtype=jnp.float32)
wi_1_val = jnp.array([[3.0, 4.0], [7.0, 8.0]], dtype=jnp.float32)

src_state = nnx.Dict(
layers=nnx.Dict(
wi_0=nnx.Param(wi_0_val),
wi_1=nnx.Param(wi_1_val),
)
)

dst_state = nnx.Dict(
layers=nnx.Dict(
wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32))
Expand All @@ -1432,25 +1544,13 @@ def test_transfer_state_directly_fuses_moe_weights(self):
)

def test_transfer_state_directly_fuses_moe_weights_scanned_to_unrolled(self):
"""Scanned wi_0/wi_1 are unstacked and fused into per-layer wi (unrolled dst).

Uses the function default `scan_axis=1`, matching MaxText's canonical
scanned MoE layout `(experts, num_layers, features)`. `experts != num_layers`
so a regression that prepends `wi_0.shape[0]` (experts) instead of
`num_layers` will fail the final reshape inside `_interleave_moe_weights`.
"""
# Layout: [experts=3, num_layers=2, features=2] (scan_axis=1).
"""Scanned wi_0/wi_1 are unstacked and fused into per-layer wi (unrolled dst)."""
# 2 layers, 2 experts, 2 features each -> fused shape [2, 4] per layer
wi_0_val = jnp.array(
[[[1., 2.], [10., 20.]],
[[3., 4.], [30., 40.]],
[[5., 6.], [50., 60.]]],
dtype=jnp.float32,
)
[[[1., 2.], [5., 6.]], [[10., 20.], [50., 60.]]], dtype=jnp.float32
) # [num_layers=2, experts=2, features=2]
wi_1_val = jnp.array(
[[[100., 200.], [1000., 2000.]],
[[300., 400.], [3000., 4000.]],
[[500., 600.], [5000., 6000.]]],
dtype=jnp.float32,
[[[3., 4.], [7., 8.]], [[30., 40.], [70., 80.]]], dtype=jnp.float32
)

src_state = nnx.Dict(
Expand All @@ -1460,20 +1560,22 @@ def test_transfer_state_directly_fuses_moe_weights_scanned_to_unrolled(self):
)
)
dst_state = nnx.Dict(**{
'layers_0': nnx.Dict(wi=nnx.Param(jnp.zeros((3, 4), dtype=jnp.float32))),
'layers_1': nnx.Dict(wi=nnx.Param(jnp.zeros((3, 4), dtype=jnp.float32))),
'layers_0': nnx.Dict(wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32))),
'layers_1': nnx.Dict(wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32))),
})

mock_reshard = lambda source, target: source
utils.transfer_state_directly(src_state, dst_state, reshard_fn=mock_reshard)
utils.transfer_state_directly(
src_state, dst_state, reshard_fn=mock_reshard, scan_axis=0
)

np.testing.assert_array_equal(
dst_state['layers_0']['wi'][...],
jnp.concatenate([wi_0_val[:, 0, :], wi_1_val[:, 0, :]], axis=-1),
jnp.concatenate([wi_0_val[0], wi_1_val[0]], axis=-1),
)
np.testing.assert_array_equal(
dst_state['layers_1']['wi'][...],
jnp.concatenate([wi_0_val[:, 1, :], wi_1_val[:, 1, :]], axis=-1),
jnp.concatenate([wi_0_val[1], wi_1_val[1]], axis=-1),
)

def test_transfer_state_directly_delete_dst_buffers_no_chunking(self):
Expand Down Expand Up @@ -1605,7 +1707,6 @@ def test_transfer_state_directly_delete_dst_buffers_scanned_layers(self):
dst_state['layers_1']['weight'][...], scanned[1]
)


def test_transfer_state_directly_fuses_moe_weights_with_padding(self):
"""Tests that wi_0 and wi_1 are fused, padded and interleaved into wi."""
# Source: wi_0, wi_1 each (2 experts, 2 features)
Expand Down
18 changes: 15 additions & 3 deletions tunix/generate/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dataclasses import dataclass
import importlib
from typing import Any, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple


class BackendMappingMixin:
Expand All @@ -17,9 +17,12 @@ class BackendMappingMixin:
@classmethod
def _backend_registry(cls) -> Dict[str, Any]:
# Use the explicit path if provided, otherwise fallback to the module path
module = cls.BACKEND_PACKAGE_PATH or cls.__module__
module_name = cls.BACKEND_PACKAGE_PATH or cls.__module__
module = importlib.import_module(module_name)
if hasattr(module, 'BACKEND_MAPPINGS'):
return module.BACKEND_MAPPINGS

package_name = module.rsplit('.', 1)[0] if '.' in module else module
package_name = module_name.rsplit('.', 1)[0] if '.' in module_name else module_name
package = importlib.import_module(package_name)

return getattr(package, 'BACKEND_MAPPINGS', {})
Expand Down Expand Up @@ -61,6 +64,10 @@ def lora_to_hf_transpose_keys(cls, backend: str | None = None):
def to_hf_hook_fns(cls, backend: str | None = None):
return cls.mapping_for(backend).get('to_hf_hook_fns')

@classmethod
def preprocess_src_state(cls, backend: str | None = None):
"""Return the backend mapping's 'preprocess_src_state' callable, or None if the mapping does not define one."""
return cls.mapping_for(backend).get('preprocess_src_state')


@dataclass
class MappingConfig:
Expand All @@ -77,6 +84,7 @@ class MappingConfig:
to_hf_hook_fns: Optional[Dict[str, Any]] = None
to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None
lora_to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None
preprocess_src_state: Optional[Callable[[Any], Any]] = None

@classmethod
def build(
Expand All @@ -102,6 +110,7 @@ def build(
'to_hf_hook_fns',
'to_hf_transpose_keys',
'lora_to_hf_transpose_keys',
'preprocess_src_state',
)

values: Dict[str, Any] = {}
Expand Down Expand Up @@ -129,6 +138,7 @@ def build(
to_hf_hook_fns=resolved.get('to_hf_hook_fns'),
to_hf_transpose_keys=resolved.get('to_hf_transpose_keys'),
lora_to_hf_transpose_keys=resolved.get('lora_to_hf_transpose_keys'),
preprocess_src_state=resolved.get('preprocess_src_state'),
)

@classmethod
Expand Down Expand Up @@ -157,10 +167,12 @@ def maybe_call(attr: str):
to_hf_hook_fns=maybe_call('to_hf_hook_fns'),
to_hf_transpose_keys=maybe_call('to_hf_transpose_keys'),
lora_to_hf_transpose_keys=maybe_call('lora_to_hf_transpose_keys'),
preprocess_src_state=maybe_call('preprocess_src_state'),
)

for key, value in overrides.items():
if hasattr(config, key):
setattr(config, key, value)

return config

Loading
Loading