160 changes: 158 additions & 2 deletions tests/generate/utils_test.py
@@ -1403,19 +1403,175 @@ def test_sglang_jax_1d_kv_bias_alignment(self):
expected = jnp.tile(src_k_bias, 8)
self.assertTrue(jnp.allclose(result.params[src_key], expected))

def test_is_fused_path(self):
self.assertTrue(
utils.is_fused_path(
"vllm_model.language_model.model.layers.0.self_attn.qkv_proj.weight"
)
)
self.assertTrue(
utils.is_fused_path(
"vllm_model.language_model.model.layers.10.mlp.gate_up_proj.weight"
)
)
self.assertFalse(
utils.is_fused_path("model.layers.0.self_attn.q_proj.weight")
)

def test_fuse_src_to_same_tgt_params_qkv(self):
tgt_path = (
"vllm_model.language_model.model.layers.0.self_attn.qkv_proj.weight"
)
q_val = jnp.ones((8, 16, 64)) * 1.0 # (num_heads, d_model, head_dim)
kv_val = (
jnp.ones((2, 2, 16, 64)) * 2.0
) # (2, num_kv_heads, d_model, head_dim)

fuse_sources = {}
# First call with Q
utils.fuse_src_to_same_tgt_params(
q_val,
"layers.0.attn.q_einsum.w",
fuse_sources,
tgt_path,
None,
tp_size=1,
)
self.assertLen(fuse_sources[tgt_path], 1)

# Second call with KV
utils.fuse_src_to_same_tgt_params(
kv_val,
"layers.0.attn.kv_einsum.w",
fuse_sources,
tgt_path,
None,
tp_size=1,
)

# Should be fused now
self.assertLen(fuse_sources[tgt_path], 1)
fused_key = "layers.0.attn.qkv_fused"
self.assertIn(fused_key, fuse_sources[tgt_path])

fused_val = fuse_sources[tgt_path][fused_key][0]
# Expected fused shape: ((num_heads + 2 * num_kv_heads) * head_dim, d_model),
# built as follows:
# q: (8, 16, 64) -> (16, 8, 64)
# kv: (2, 2, 16, 64) -> (16, 2, 2, 64) -> (16, 4, 64)
# concat(q, k, v) -> (16, 12, 64) -> (16, 768)
# transpose -> (768, 16)
self.assertEqual(fused_val.shape, (768, 16))
self.assertTrue(jnp.allclose(fused_val[:512, :], 1.0)) # Q
self.assertTrue(jnp.allclose(fused_val[512:, :], 2.0)) # KV
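
For reference, a standalone shape check of the QKV fusion asserted above. It mirrors the shape trace in the comments; the exact transpose/reshape order is an assumption about the documented layout, not the internals of fuse_src_to_same_tgt_params.

import jax.numpy as jnp

q = jnp.ones((8, 16, 64)) * 1.0      # (num_heads, d_model, head_dim)
kv = jnp.ones((2, 2, 16, 64)) * 2.0  # (2, num_kv_heads, d_model, head_dim)

q_t = jnp.transpose(q, (1, 0, 2))                          # (16, 8, 64)
kv_t = jnp.transpose(kv, (2, 0, 1, 3)).reshape(16, 4, 64)  # (16, 4, 64)
fused = jnp.concatenate([q_t, kv_t], axis=1).reshape(16, -1).T
assert fused.shape == (768, 16)  # ((8 + 2 * 2) * 64, d_model)
assert bool(jnp.allclose(fused[:512], 1.0))  # Q rows
assert bool(jnp.allclose(fused[512:], 2.0))  # KV rows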

def test_fuse_src_to_same_tgt_params_gate_up(self):
tgt_path = (
"vllm_model.language_model.model.layers.0.mlp.gate_up_proj.weight"
)
gate_val = jnp.ones((16, 32)) * 3.0 # (d_model, hidden)
up_val = jnp.ones((16, 32)) * 4.0 # (d_model, hidden)

fuse_sources = {}
utils.fuse_src_to_same_tgt_params(
gate_val,
"layers.0.mlp.gate_proj.kernel",
fuse_sources,
tgt_path,
None,
tp_size=1,
)
utils.fuse_src_to_same_tgt_params(
up_val,
"layers.0.mlp.up_proj.kernel",
fuse_sources,
tgt_path,
None,
tp_size=1,
)

fused_key = "layers.0.mlp.gate_up_fused"
fused_val = fuse_sources[tgt_path][fused_key][0]

# Hidden=32. tp_size=1. Gate and Up are stacked:
# (2*Hidden, d_model) = (64, 16)
self.assertEqual(fused_val.shape, (64, 16))
self.assertTrue(jnp.allclose(fused_val[0:32, :], 3.0)) # gate
self.assertTrue(jnp.allclose(fused_val[32:64, :], 4.0)) # up

def test_fuse_src_to_same_tgt_params_gate_up_tp2(self):
tgt_path = (
"vllm_model.language_model.model.layers.0.mlp.gate_up_proj.weight"
)
gate_val = jnp.ones((16, 32)) * 3.0
up_val = jnp.ones((16, 32)) * 4.0

fuse_sources = {}
utils.fuse_src_to_same_tgt_params(
gate_val,
"layers.0.mlp.gate_proj.kernel",
fuse_sources,
tgt_path,
None,
tp_size=2,
)
utils.fuse_src_to_same_tgt_params(
up_val,
"layers.0.mlp.up_proj.kernel",
fuse_sources,
tgt_path,
None,
tp_size=2,
)

fused_val = fuse_sources[tgt_path]["layers.0.mlp.gate_up_fused"][0]

# Hidden=32, tp_size=2. chunk_size=16.
# [gate[0:16], up[0:16], gate[16:32], up[16:32]] interleaved
self.assertEqual(fused_val.shape, (64, 16))
self.assertTrue(jnp.allclose(fused_val[0:16, :], 3.0))
self.assertTrue(jnp.allclose(fused_val[16:32, :], 4.0))
self.assertTrue(jnp.allclose(fused_val[32:48, :], 3.0))
self.assertTrue(jnp.allclose(fused_val[48:64, :], 4.0))
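
A minimal standalone sketch of the tp-aware gate/up interleaving these two tests assert. fuse_gate_up is a hypothetical helper written here for illustration, not the function under test.

import jax.numpy as jnp

def fuse_gate_up(gate, up, tp_size):
  # gate, up: (d_model, hidden). Split the hidden dim into tp_size chunks and
  # place each shard's gate chunk directly before its up chunk, then transpose
  # to (2 * hidden, d_model).
  gate_chunks = jnp.split(gate, tp_size, axis=1)
  up_chunks = jnp.split(up, tp_size, axis=1)
  pairs = [
      jnp.concatenate([g, u], axis=1)
      for g, u in zip(gate_chunks, up_chunks)
  ]
  return jnp.concatenate(pairs, axis=1).T

fused = fuse_gate_up(
    jnp.ones((16, 32)) * 3.0, jnp.ones((16, 32)) * 4.0, tp_size=2
)
assert fused.shape == (64, 16)
assert bool(jnp.allclose(fused[0:16], 3.0))   # gate shard 0
assert bool(jnp.allclose(fused[16:32], 4.0))  # up shard 0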

def test_align_shape_moe_gating_einsum(self):
val = jnp.ones((2, 2, 128, 16))
src_key = "layers.0.moe.gating_einsum"
tgt_shape = (2, 16, 256)

result = utils._align_shape(val, tgt_shape, src_key, tp_size=1)
self.assertEqual(result.shape, (2, 16, 256))
self.assertTrue(jnp.allclose(result, 1.0))

# Test with padding
val_small = jnp.ones((2, 2, 100, 16))
# chunk_size = 100. padded = 128. pad_amount = 28.
# result shape should be (2, 16, 256)
result_padded = utils._align_shape(
val_small, (2, 16, 256), src_key, tp_size=1
)
self.assertEqual(result_padded.shape, (2, 16, 256))
# Check that padded area is 0
# gate_chunks (2, 1, 100, 16) -> pad -> (2, 1, 128, 16) -> stack ->
# (2, 1, 2, 128, 16) -> reshape (2, 256, 16) -> transpose (2, 16, 256).
# The first 100 entries of each 128-wide chunk should be 1;
# the last 28 (the padding) should be 0.
np.testing.assert_array_equal(result_padded[0, 0, 100:128], 0.0)
np.testing.assert_array_equal(result_padded[0, 0, 0:100], 1.0)
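
For reference, the pad-and-stack alignment described in the comment above written out with explicit jnp ops. The axis handling inside _align_shape is an assumption based on that comment, not the library source.

import jax.numpy as jnp

val_small = jnp.ones((2, 2, 100, 16))       # (num_experts, 2, hidden, d_model)
gate, up = jnp.split(val_small, 2, axis=1)  # each (2, 1, 100, 16)
pad = [(0, 0), (0, 0), (0, 28), (0, 0)]     # pad hidden 100 -> 128
gate, up = jnp.pad(gate, pad), jnp.pad(up, pad)
stacked = jnp.stack([gate, up], axis=2)     # (2, 1, 2, 128, 16)
aligned = stacked.reshape(2, 256, 16).transpose(0, 2, 1)  # (2, 16, 256)
assert bool(jnp.allclose(aligned[0, 0, 0:100], 1.0))    # original values
assert bool(jnp.allclose(aligned[0, 0, 100:128], 0.0))  # zero padding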

def test_transfer_state_directly_fuses_moe_weights(self):
"""Tests that wi_0 and wi_1 are fused into wi when target expects it."""
wi_0_val = jnp.array([[1.0, 2.0], [5.0, 6.0]], dtype=jnp.float32)
wi_1_val = jnp.array([[3.0, 4.0], [7.0, 8.0]], dtype=jnp.float32)

src_state = nnx.Dict(
layers=nnx.Dict(
wi_0=nnx.Param(wi_0_val),
wi_1=nnx.Param(wi_1_val),
)
)

dst_state = nnx.Dict(
layers=nnx.Dict(
wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32))
10 changes: 10 additions & 0 deletions tunix/generate/mappings.py
@@ -43,6 +43,15 @@ def to_hf_mappings(cls, backend: str | None = None):
)
return mapping

@classmethod
def key_reference_mappings(cls, backend: str | None = None):
mapping = cls.mapping_for(backend).get('key_reference_mappings')
if mapping is None:
raise RuntimeError(
f'{backend} key_reference_mappings missing for {cls.__name__}.'
)
return mapping

@classmethod
def lora_to_hf_mappings(cls, backend: str | None = None):
return cls.mapping_for(backend).get('lora_to_hf_mappings')
@@ -73,6 +82,7 @@ class MappingConfig:
"""

to_hf_mappings: Optional[Dict[str, Any]] = None
key_reference_mappings: Optional[Dict[str, Any]] = None
lora_to_hf_mappings: Optional[Dict[str, Any]] = None
to_hf_hook_fns: Optional[Dict[str, Any]] = None
to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None
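
A hypothetical usage sketch for the new field and accessor; the dictionary contents and the import path are illustrative and not taken from this PR.

from tunix.generate import mappings

key_refs = {
    'layers.0.attn.q_einsum.w': 'model.layers.0.self_attn.qkv_proj.weight',
}
config = mappings.MappingConfig(key_reference_mappings=key_refs)
assert config.key_reference_mappings is key_refs

# On the model mapping classes, key_reference_mappings(backend) reads the same
# key from mapping_for(backend) and raises RuntimeError when it is absent.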