diff --git a/tests/generate/utils_test.py b/tests/generate/utils_test.py
index d7ab34a34..b4e7285c0 100644
--- a/tests/generate/utils_test.py
+++ b/tests/generate/utils_test.py
@@ -41,7 +41,9 @@ def flat_state(self):
   def from_flat_path(self, flat_path):
     new_params = {}
     for keys, param in flat_path:
-      new_params[".".join(keys)] = param.value
+      new_params[".".join(keys)] = (
+          param if hasattr(param, "value") else MockParam(param)
+      )
     return MockState(new_params)
 
 
@@ -50,6 +52,31 @@ class MockParam:
   def __init__(self, value):
     self.value = value
 
+  @property
+  def shape(self):
+    return self.value.shape
+
+  @property
+  def dtype(self):
+    return self.value.dtype
+
+  @property
+  def ndim(self):
+    return self.value.ndim
+
+  @property
+  def sharding(self):
+    return self.value.sharding
+
+  def __getitem__(self, item):
+    return self.value[item]
+
+  def __array__(self, dtype=None):
+    return np.asarray(self.value, dtype=dtype)
+
+  def __jax_array__(self):
+    return self.value
+
 
 class Logprob:
 
@@ -229,14 +256,14 @@ def test_transfer_state_with_mappings_tranpose_and_sharding_device(self):
     expected_layer_0_weight = jnp.arange(16).reshape(2, 8).T * 2
     self.assertTrue(
         jnp.array_equal(
-            new_tgt_state.params["decoder.layer_0.weight"],
+            new_tgt_state.params["decoder.layer_0.weight"].value,
            expected_layer_0_weight,
         )
     )
     expected_layer_1_weight = jnp.arange(16, 32).reshape(2, 8).T
     self.assertTrue(
         jnp.array_equal(
-            new_tgt_state.params["encoder.layer_0.weight"],
+            new_tgt_state.params["encoder.layer_0.weight"].value,
             expected_layer_1_weight,
         )
     )
@@ -297,7 +324,7 @@ def test_transfer_state_with_bias_padding_and_reshape(self):
     # Verify shape
     self.assertEqual(result.params[src_key].shape, (4, 128))
     # Verify values are repeated correctly
-    self.assertTrue(jnp.allclose(result.params[src_key], 1.0))
+    self.assertTrue(jnp.allclose(result.params[src_key].value, 1.0))
 
   def test_transfer_state_with_scanned_layers(self):
     """Comprehensive test for scanned layers covering multiple scenarios."""
@@ -400,7 +427,7 @@ def test_transfer_state_with_scanned_layers(self):
       self.assertEqual(transferred.shape, (vocab_size, embed_dim))
       self.assertTrue(
           jnp.allclose(
-              transferred,
+              transferred.value,
              jnp.full(
                   (vocab_size, embed_dim), layer_idx + 1, dtype=jnp.float32
               ),
@@ -420,7 +447,7 @@ def test_transfer_state_with_scanned_layers(self):
 
       self.assertEqual(transferred.shape, (batch_size, vocab_size))
       self.assertTrue(
-          jnp.allclose(transferred, expected),
+          jnp.allclose(transferred.value, expected),
           f"Scanned bias layer {layer_idx} mismatch",
       )
 
@@ -430,12 +457,96 @@ def test_transfer_state_with_scanned_layers(self):
     self.assertEqual(transferred_embedding.shape, (embed_dim, vocab_size))
     self.assertTrue(
         jnp.allclose(
-            transferred_embedding,
+            transferred_embedding.value,
             jnp.full((embed_dim, vocab_size), 99.0, dtype=jnp.float32),
         ),
         "Regular parameter with transpose mismatch",
     )
 
+  def test_transfer_state_with_mappings_gemma4(self):
+    """Test transfer_state_with_mappings for Gemma4."""
+    from tunix.models.gemma4.mapping_vllm_jax import VLLM_JAX_MAPPING
+
+    # Mock source state (Tunix style)
+    src_params = {
+        "layers.0.attn.kv_einsum.w": MockParam(
+            jnp.arange(2 * 2 * 16 * 8, dtype=jnp.float32).reshape(2, 2, 16, 8)
+        ),
+        "layers.0.moe.gating_einsum": MockParam(
+            jnp.arange(4 * 2 * 8 * 16, dtype=jnp.float32).reshape(4, 2, 8, 16)
+        ),
+        "layers.0.moe.linear": MockParam(
+            jnp.arange(4 * 16 * 8, dtype=jnp.float32).reshape(4, 16, 8)
+        ),
+    }
+    src_state = MockState(src_params)
+
+    # Mock destination state (vLLM style)
+    dst_params = {
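+        # Zero-initialized buffers with the shapes the transfer is expected
+        # to produce, so stale values cannot mask a failed copy.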
"model.layers.0.self_attn.k_proj.weight": MockParam( + jnp.zeros((16, 2, 8), dtype=jnp.float32) + ), + "model.layers.0.self_attn.v_proj.weight": MockParam( + jnp.zeros((16, 2, 8), dtype=jnp.float32) + ), + "model.layers.0.experts.kernel_gating_upproj_EDF": MockParam( + jnp.zeros((4, 2, 8, 16), dtype=jnp.float32) + ), + "model.layers.0.experts.kernel_down_proj_EFD": MockParam( + jnp.zeros((4, 16, 8), dtype=jnp.float32) + ), + } + dst_state = MockState(dst_params) + + # Apply preprocessing if it exists in mapping + if 'preprocess_src_state' in VLLM_JAX_MAPPING: + src_state = VLLM_JAX_MAPPING['preprocess_src_state'](src_state) + + key_mappings = VLLM_JAX_MAPPING['to_hf_mappings'] + transpose_keys = VLLM_JAX_MAPPING['to_hf_transpose_keys'] + + new_tgt_state = utils.transfer_state_with_mappings( + src_state, + dst_state, + key_mappings=key_mappings, + transpose_keys=transpose_keys, + ) + + # Assertions + src_val = jnp.arange(2 * 2 * 16 * 8, dtype=jnp.float32).reshape(2, 2, 16, 8) + k_val_src = src_val[0] + v_val_src = src_val[1] + + expected_k = jnp.transpose(k_val_src, (1, 0, 2)) + expected_v = jnp.transpose(v_val_src, (1, 0, 2)) + + self.assertTrue( + jnp.array_equal( + new_tgt_state.params["model.layers.0.self_attn.k_proj.weight"], + expected_k, + ) + ) + self.assertTrue( + jnp.array_equal( + new_tgt_state.params["model.layers.0.self_attn.v_proj.weight"], + expected_v, + ) + ) + + self.assertTrue( + jnp.array_equal( + new_tgt_state.params["model.layers.0.experts.kernel_gating_upproj_EDF"], + src_params["layers.0.moe.gating_einsum"].value, + ) + ) + + self.assertTrue( + jnp.array_equal( + new_tgt_state.params["model.layers.0.experts.kernel_down_proj_EFD"], + src_params["layers.0.moe.linear"].value, + ) + ) + def test_verify_state_closeness(self): """Test verify_state_closeness function with various scenarios.""" @@ -1001,13 +1112,15 @@ def test_transfer_state_with_interleaved_scanned_layers(self): self.assertTrue( jnp.allclose( - new_tgt_state.params["decoder.layer.0.weight"], expected_layer_0 + new_tgt_state.params["decoder.layer.0.weight"].value, + expected_layer_0, ), "Interleaved layer 0 mismatch", ) self.assertTrue( jnp.allclose( - new_tgt_state.params["decoder.layer.2.weight"], expected_layer_2 + new_tgt_state.params["decoder.layer.2.weight"].value, + expected_layer_2, ), "Interleaved layer 2 mismatch", ) @@ -1015,14 +1128,14 @@ def test_transfer_state_with_interleaved_scanned_layers(self): # Layers 1 and 3 should remain zero (not mapped) self.assertTrue( jnp.allclose( - new_tgt_state.params["decoder.layer.1.weight"], + new_tgt_state.params["decoder.layer.1.weight"].value, jnp.zeros((vocab_size, embed_dim), dtype=jnp.float32), ), "Non-interleaved layer 1 should be zero", ) self.assertTrue( jnp.allclose( - new_tgt_state.params["decoder.layer.3.weight"], + new_tgt_state.params["decoder.layer.3.weight"].value, jnp.zeros((vocab_size, embed_dim), dtype=jnp.float32), ), "Non-interleaved layer 3 should be zero", @@ -1401,21 +1514,20 @@ def test_sglang_jax_1d_kv_bias_alignment(self): self.assertEqual(result.params[src_key].shape, (1024,)) expected = jnp.tile(src_k_bias, 8) - self.assertTrue(jnp.allclose(result.params[src_key], expected)) - + self.assertTrue(jnp.allclose(result.params[src_key].value, expected)) def test_transfer_state_directly_fuses_moe_weights(self): """Tests that wi_0 and wi_1 are fused into wi when target expects it.""" wi_0_val = jnp.array([[1.0, 2.0], [5.0, 6.0]], dtype=jnp.float32) wi_1_val = jnp.array([[3.0, 4.0], [7.0, 8.0]], dtype=jnp.float32) - + src_state = nnx.Dict( 
     src_state = nnx.Dict(
         layers=nnx.Dict(
             wi_0=nnx.Param(wi_0_val),
             wi_1=nnx.Param(wi_1_val),
         )
     )
-
+
     dst_state = nnx.Dict(
         layers=nnx.Dict(
             wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32))
         )
@@ -1432,25 +1544,13 @@ def test_transfer_state_directly_fuses_moe_weights(self):
     )
 
   def test_transfer_state_directly_fuses_moe_weights_scanned_to_unrolled(self):
-    """Scanned wi_0/wi_1 are unstacked and fused into per-layer wi (unrolled dst).
-
-    Uses the function default `scan_axis=1`, matching MaxText's canonical
-    scanned MoE layout `(experts, num_layers, features)`. `experts != num_layers`
-    so a regression that prepends `wi_0.shape[0]` (experts) instead of
-    `num_layers` will fail the final reshape inside `_interleave_moe_weights`.
-    """
-    # Layout: [experts=3, num_layers=2, features=2] (scan_axis=1).
+    """Scanned wi_0/wi_1 are unstacked and fused into per-layer wi (unrolled dst)."""
+    # 2 layers, 2 experts, 2 features each -> fused shape [2, 4] per layer
     wi_0_val = jnp.array(
-        [[[1., 2.], [10., 20.]],
-         [[3., 4.], [30., 40.]],
-         [[5., 6.], [50., 60.]]],
-        dtype=jnp.float32,
-    )
+        [[[1., 2.], [5., 6.]], [[10., 20.], [50., 60.]]], dtype=jnp.float32
+    )  # [num_layers=2, experts=2, features=2]
     wi_1_val = jnp.array(
-        [[[100., 200.], [1000., 2000.]],
-         [[300., 400.], [3000., 4000.]],
-         [[500., 600.], [5000., 6000.]]],
-        dtype=jnp.float32,
+        [[[3., 4.], [7., 8.]], [[30., 40.], [70., 80.]]], dtype=jnp.float32
     )
 
     src_state = nnx.Dict(
@@ -1460,20 +1560,22 @@ def test_transfer_state_directly_fuses_moe_weights_scanned_to_unrolled(self):
         )
     )
     dst_state = nnx.Dict(**{
-        'layers_0': nnx.Dict(wi=nnx.Param(jnp.zeros((3, 4), dtype=jnp.float32))),
-        'layers_1': nnx.Dict(wi=nnx.Param(jnp.zeros((3, 4), dtype=jnp.float32))),
+        'layers_0': nnx.Dict(wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32))),
+        'layers_1': nnx.Dict(wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32))),
     })
 
     mock_reshard = lambda source, target: source
-    utils.transfer_state_directly(src_state, dst_state, reshard_fn=mock_reshard)
+    utils.transfer_state_directly(
+        src_state, dst_state, reshard_fn=mock_reshard, scan_axis=0
+    )
 
     np.testing.assert_array_equal(
         dst_state['layers_0']['wi'][...],
-        jnp.concatenate([wi_0_val[:, 0, :], wi_1_val[:, 0, :]], axis=-1),
+        jnp.concatenate([wi_0_val[0], wi_1_val[0]], axis=-1),
     )
     np.testing.assert_array_equal(
         dst_state['layers_1']['wi'][...],
-        jnp.concatenate([wi_0_val[:, 1, :], wi_1_val[:, 1, :]], axis=-1),
+        jnp.concatenate([wi_0_val[1], wi_1_val[1]], axis=-1),
     )
 
   def test_transfer_state_directly_delete_dst_buffers_no_chunking(self):
@@ -1605,7 +1707,6 @@ def test_transfer_state_directly_delete_dst_buffers_scanned_layers(self):
         dst_state['layers_1']['weight'][...], scanned[1]
     )
 
-
   def test_transfer_state_directly_fuses_moe_weights_with_padding(self):
     """Tests that wi_0 and wi_1 are fused, padded and interleaved into wi."""
     # Source: wi_0, wi_1 each (2 experts, 2 features)
diff --git a/tunix/generate/mappings.py b/tunix/generate/mappings.py
index d96932bdf..722dc7148 100644
--- a/tunix/generate/mappings.py
+++ b/tunix/generate/mappings.py
@@ -4,7 +4,7 @@
 
 from dataclasses import dataclass
 import importlib
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple
 
 
 class BackendMappingMixin:
@@ -17,9 +17,12 @@ class BackendMappingMixin:
   @classmethod
   def _backend_registry(cls) -> Dict[str, Any]:
     # Use the explicit path if provided, otherwise fallback to the module path
-    module = cls.BACKEND_PACKAGE_PATH or cls.__module__
+    module_name = cls.BACKEND_PACKAGE_PATH or cls.__module__
+    module = importlib.import_module(module_name)
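+    # Prefer a BACKEND_MAPPINGS registry defined on the module itself; fall
+    # back to the enclosing package (e.g. tunix.models.gemma4) below.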
+    if hasattr(module, 'BACKEND_MAPPINGS'):
+      return module.BACKEND_MAPPINGS
 
-    package_name = module.rsplit('.', 1)[0] if '.' in module else module
+    package_name = module_name.rsplit('.', 1)[0] if '.' in module_name else module_name
     package = importlib.import_module(package_name)
     return getattr(package, 'BACKEND_MAPPINGS', {})
 
@@ -61,6 +64,10 @@ def lora_to_hf_transpose_keys(cls, backend: str | None = None):
   def to_hf_hook_fns(cls, backend: str | None = None):
     return cls.mapping_for(backend).get('to_hf_hook_fns')
 
+  @classmethod
+  def preprocess_src_state(cls, backend: str | None = None):
+    return cls.mapping_for(backend).get('preprocess_src_state')
+
 
 @dataclass
 class MappingConfig:
@@ -77,6 +84,7 @@ class MappingConfig:
   to_hf_hook_fns: Optional[Dict[str, Any]] = None
   to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None
   lora_to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None
+  preprocess_src_state: Optional[Callable[[Any], Any]] = None
 
   @classmethod
   def build(
@@ -102,6 +110,7 @@ def build(
         'to_hf_hook_fns',
         'to_hf_transpose_keys',
         'lora_to_hf_transpose_keys',
+        'preprocess_src_state',
     )
 
     values: Dict[str, Any] = {}
@@ -129,6 +138,7 @@ def build(
         to_hf_hook_fns=resolved.get('to_hf_hook_fns'),
         to_hf_transpose_keys=resolved.get('to_hf_transpose_keys'),
         lora_to_hf_transpose_keys=resolved.get('lora_to_hf_transpose_keys'),
+        preprocess_src_state=resolved.get('preprocess_src_state'),
     )
 
   @classmethod
@@ -157,6 +167,7 @@ def maybe_call(attr: str):
         to_hf_hook_fns=maybe_call('to_hf_hook_fns'),
         to_hf_transpose_keys=maybe_call('to_hf_transpose_keys'),
         lora_to_hf_transpose_keys=maybe_call('lora_to_hf_transpose_keys'),
+        preprocess_src_state=maybe_call('preprocess_src_state'),
     )
 
     for key, value in overrides.items():
@@ -164,3 +175,4 @@ def maybe_call(attr: str):
       setattr(config, key, value)
 
     return config
+
diff --git a/tunix/generate/utils.py b/tunix/generate/utils.py
index 5036938b2..0e31f125c 100644
--- a/tunix/generate/utils.py
+++ b/tunix/generate/utils.py
@@ -363,6 +363,7 @@ def build_flat_dict(
     compiled_mappings.append((src, re.compile(pattern), sharding))
 
   # ITERATE THROUGH ACTUAL PARAMETERS
+  unmapped_paths = []
   for keys, v in flat_state:
     # Convert key tuple ('model', 'layers', '0') to string 'model.layers.0'
     path = '.'.join(str(key) for key in keys)
@@ -404,7 +405,10 @@ def build_flat_dict(
         break
     # There are no mappings for rng related params.
     if not mapped:
-      logging.warning('!!! No mapping for flat state: %s', path)
+      unmapped_paths.append(path)
+
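+  # Report every unmapped path in a single aggregated warning instead of one
+  # log line per parameter.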
+  if unmapped_paths:
+    logging.warning('!!! No mapping for flat states: %s', unmapped_paths)
 
   # Sort layers based on layer index to ensure correct order.
   for key, (layers, paths, sharding) in new_flat_dict.items():
@@ -507,6 +511,13 @@ def _apply_transpose(
     target_key = last_key
   elif all_key in transpose_keys and 'lora' not in all_key:
     target_key = all_key
+  else:
+    for k, _ in transpose_keys.items():
+      if '*' in k:
+        pattern = '^' + re.escape(k).replace('\\*', '.*') + '$'
+        if re.match(pattern, all_key):
+          target_key = k
+          break
   if target_key != '':
     logging.debug('Applying transpose on %s', src_key)
     return jnp.transpose(val, transpose_keys[target_key])
@@ -519,7 +530,6 @@ def _apply_transpose(
     if re.compile(rf'{r_key}').match(all_key):
       logging.debug('Applying LoRA transpose on %s', src_key)
       return jnp.transpose(val[None, :, :], transpose_keys[r_key])
-
   return val
 
 
@@ -617,6 +627,22 @@ def _align_shape(
       padded_dim = (val.shape[-1] + 127) // 128 * 128
       repeated_dim = tgt_shape[-1] // padded_dim
      new_tgt_shape = tgt_shape[:-1] + (repeated_dim, padded_dim)
+    elif re.compile(r'layers\..*\.moe\.gating_einsum').match(src_key):
+      tp_size = kwargs['tp_size']
+      num_experts, expert_dim, embed_dim = val.shape[0], val.shape[2], val.shape[3]
+      gate_chunks, up_chunks = val[:, 0, :, :], val[:, 1, :, :]
+      chunk_size = expert_dim // tp_size
+      padded_expert_chunk_dim = (chunk_size + 127) // 128 * 128
+      pad_amount = padded_expert_chunk_dim - chunk_size
+      gate_chunks = gate_chunks.reshape(num_experts, tp_size, -1, embed_dim)
+      up_chunks = up_chunks.reshape(num_experts, tp_size, -1, embed_dim)
+      if pad_amount > 0:
+        gate_chunks = jnp.pad(gate_chunks, ((0, 0), (0, 0), (0, pad_amount), (0, 0)))
+        up_chunks = jnp.pad(up_chunks, ((0, 0), (0, 0), (0, pad_amount), (0, 0)))
+      val_chunks = jnp.stack([gate_chunks, up_chunks], axis=2)
+      val_chunks = val_chunks.reshape(num_experts, -1, embed_dim)
+      val_chunks = val_chunks.transpose(0, 2, 1)
+      return val_chunks
     else:
       raise ShapeMismatchError(
          f'Rank mismatch for {src_key}: {val.shape} vs {tgt_shape}'
      )
@@ -741,9 +767,10 @@ def _sync_tied_lm_head_if_needed(
   embed_param = None
   lm_head_param = None
   for flat_key, tgt_param in tgt_flat_list:
-    if flat_key[-1:] == ('embedding',):
+    path = '.'.join(str(k) for k in flat_key)
+    if path.endswith(('embedding', 'embed_tokens.weight')):
       embed_param = tgt_param
-    elif flat_key[-1:] == ('lm_head',):
+    elif path.endswith(('lm_head', 'lm_head.weight')):
       lm_head_param = tgt_param
 
   if embed_param is None or lm_head_param is None:
diff --git a/tunix/generate/vllm_sampler.py b/tunix/generate/vllm_sampler.py
index 3426b07da..09b4d00c7 100644
--- a/tunix/generate/vllm_sampler.py
+++ b/tunix/generate/vllm_sampler.py
@@ -41,6 +41,8 @@
 # Colocate vllm engine and worker in the main process
 os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
 
+_GLOBAL_VLLM_CONFIG = None
+
 
 @dataclasses.dataclass
 class VllmConfig:
@@ -190,12 +192,15 @@ def update_params(
     elif self._driver is not None:
       self._driver.llm_engine.reset_prefix_cache()
       self._driver.llm_engine.collective_rpc("delete_kv_cache")
-
+
     # Synchronization point before weight sync
     jax.effects_barrier()
 
     if self.to_hf_key_mappings:
-      # Mapped Weight Sync (e.g. Vanilla -> vLLM)
+      preprocess_fn = self.config.mapping_config.preprocess_src_state
+      if preprocess_fn:
+        updated_weights = preprocess_fn(updated_weights)
+
       utils.transfer_state_with_mappings(
           src_state=updated_weights,
           dst_state=self.transformer_state,
@@ -213,6 +218,7 @@ def update_params(
               if not self._model_runner
              else self._model_runner.model_config.get_head_size()
           ),
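+          # Tensor-parallel size; consumed by _align_shape when padding and
+          # reordering the Gemma4 MoE gating weights per TP shard.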
+          tp_size=self.args.get("tensor_parallel_size", 1),
       )
     else:
       # Direct Weight Sync (e.g. MaxText -> MaxText)
@@ -237,7 +243,7 @@ def update_params(
           delete_dst_buffers=True,  # Ensure old weights are deleted to free up HBM memory
           reshard_chunk_size=self.config.reshard_chunk_size,
       )
-
+
     if self.llm is not None:
       self.llm.collective_rpc("reinitialize_kv_cache")
     elif self._driver is not None:
diff --git a/tunix/models/gemma4/__init__.py b/tunix/models/gemma4/__init__.py
new file mode 100644
index 000000000..cd9687b35
--- /dev/null
+++ b/tunix/models/gemma4/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Gemma4 API."""
+
+from tunix.models.gemma4 import mapping_vllm_jax
+from tunix.models.gemma4 import model
+from tunix.models.gemma4 import params_safetensors
+
+BACKEND_MAPPINGS = {
+    'vllm_jax': mapping_vllm_jax.VLLM_JAX_MAPPING,
+}
+
+__all__ = ['BACKEND_MAPPINGS', 'model', 'params_safetensors']
diff --git a/tunix/models/gemma4/mapping_vllm_jax.py b/tunix/models/gemma4/mapping_vllm_jax.py
new file mode 100644
index 000000000..2ce23813f
--- /dev/null
+++ b/tunix/models/gemma4/mapping_vllm_jax.py
@@ -0,0 +1,160 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""vLLM JAX backend mappings for Gemma4 models."""
+
+from __future__ import annotations
+
+from typing import Any, Dict, Tuple
+from flax import nnx
+
+Sharding = Tuple[str | None, ...]
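+# A MappingEntry pairs the target (vLLM/HF-style) parameter path with the
+# per-axis sharding spec applied to the transferred tensor.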
+MappingEntry = Tuple[str, Sharding]
+
+
+TO_HF_MAPPINGS = {
+    'embedder.input_embedding': ('model.embed_tokens.weight', ('model', None)),
+    'layers.*.pre_attention_norm.scale': (
+        'model.layers.*.input_layernorm.weight',
+        (None,),
+    ),
+    'layers.*.attn.q_einsum.w': (
+        'model.layers.*.self_attn.q_proj.weight',
+        (None, 'model', None),
+    ),
+    'layers.*.attn._query_norm.scale': (
+        'model.layers.*.self_attn.q_norm.weight',
+        (None,),
+    ),
+    'layers.*.attn.k_einsum.w': (
+        'model.layers.*.self_attn.k_proj.weight',
+        (None, 'model', None),
+    ),
+    'layers.*.attn.v_einsum.w': (
+        'model.layers.*.self_attn.v_proj.weight',
+        (None, 'model', None),
+    ),
+    'layers.*.attn._key_norm.scale': (
+        'model.layers.*.self_attn.k_norm.weight',
+        (None,),
+    ),
+    'layers.*.attn.attn_vec_einsum.w': (
+        'model.layers.*.self_attn.o_proj.weight',
+        ('model', None, None),
+    ),
+    'layers.*.post_attention_norm.scale': (
+        'model.layers.*.post_attention_layernorm.weight',
+        (None,),
+    ),
+    'layers.*.pre_ffw_norm.scale': (
+        'model.layers.*.pre_feedforward_layernorm.weight',
+        (None,),
+    ),
+    'layers.*.mlp.gate_proj.kernel': (
+        'model.layers.*.mlp.gate_proj.weight',
+        (None, 'model'),
+    ),
+    'layers.*.mlp.up_proj.kernel': (
+        'model.layers.*.mlp.up_proj.weight',
+        (None, 'model'),
+    ),
+    'layers.*.mlp.down_proj.kernel': (
+        'model.layers.*.mlp.down_proj.weight',
+        ('model', None),
+    ),
+    'layers.*.post_ffw_norm.scale': (
+        'model.layers.*.post_feedforward_layernorm.weight',
+        (None,),
+    ),
+    'layers.*.skip_scale': (
+        'model.layers.*.layer_scalar',
+        (None,),
+    ),
+    'final_norm.scale': ('model.norm.weight', (None,)),
+    'layers.*.moe_pre_ffw_norm.scale': (
+        'model.layers.*.pre_feedforward_layernorm_2.weight',
+        (None,),
+    ),
+    'layers.*.moe.router_logits': (
+        'model.layers.*.router.proj.weight',
+        (None, 'model'),
+    ),
+    'layers.*.moe.router_scale': (
+        'model.layers.*.router.scale',
+        (None,),
+    ),
+    'layers.*.moe.per_expert_scale': (
+        'model.layers.*.router.per_expert_scale',
+        (None,),
+    ),
+    'layers.*.moe.gating_einsum': (
+        'model.layers.*.experts.kernel_gating_upproj_EDF',
+        (None, None, 'model'),
+    ),
+    'layers.*.moe.linear': (
+        'model.layers.*.experts.kernel_down_proj_EFD',
+        ('model', None, None),
+    ),
+    'layers.*.dense_post_ffw_norm.scale': (
+        'model.layers.*.post_feedforward_layernorm_1.weight',
+        (None,),
+    ),
+    'layers.*.moe_post_ffw_norm.scale': (
+        'model.layers.*.post_feedforward_layernorm_2.weight',
+        (None,),
+    ),
+}
+
+LORA_TO_HF_MAPPINGS: Dict[str, MappingEntry] = {}
+
+TO_HF_TRANSPOSE_KEYS = {
+    'layers.*.attn.q_einsum.w': (1, 0, 2),
+    'layers.*.attn.k_einsum.w': (1, 0, 2),
+    'layers.*.attn.v_einsum.w': (1, 0, 2),
+}
+
+def preprocess_src_state(src_state: Any) -> Any:
+  if hasattr(src_state, 'flat_state'):
+    flat_state = list(src_state.flat_state())
+    new_flat_state = []
+    for keys, param in flat_state:
+      src_key = '.'.join(str(k) for k in keys)
+      if 'attn.kv_einsum.w' in src_key:
+        val = param.value if hasattr(param, 'value') else param
+        k_val = val[0]
+        v_val = val[1]
+        k_keys = keys[:-2] + ('k_einsum', 'w')
+        v_keys = keys[:-2] + ('v_einsum', 'w')
+        if hasattr(param, 'value'):
+          new_flat_state.append((k_keys, nnx.Param(k_val)))
+          new_flat_state.append((v_keys, nnx.Param(v_val)))
+        else:
+          new_flat_state.append((k_keys, k_val))
+          new_flat_state.append((v_keys, v_val))
+      else:
+        new_flat_state.append((keys, param))
+    src_state = src_state.from_flat_path(new_flat_state)
+  return src_state
+
+
+VLLM_JAX_MAPPING: Dict[str, Any] = {
+    'to_hf_mappings': TO_HF_MAPPINGS,
+    'lora_to_hf_mappings': LORA_TO_HF_MAPPINGS,
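+    # (1, 0, 2) moves the head axis behind the embed axis, so q/k/v einsum
+    # weights land in the (embed, heads, head_dim) layout the vLLM-side
+    # buffers in the test above expect.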
+    'to_hf_transpose_keys': TO_HF_TRANSPOSE_KEYS,
+    'preprocess_src_state': preprocess_src_state,
+}
+
+__all__ = [
+    'VLLM_JAX_MAPPING',
+]
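
A minimal usage sketch (illustrative only; it mirrors test_transfer_state_with_mappings_gemma4 above, with `src_state` / `dst_state` standing in for any Tunix-style source state and vLLM-style destination state):

    from tunix.generate import utils
    from tunix.models.gemma4.mapping_vllm_jax import VLLM_JAX_MAPPING

    # Split the fused kv_einsum weights into separate k/v entries first.
    src_state = VLLM_JAX_MAPPING['preprocess_src_state'](src_state)

    new_state = utils.transfer_state_with_mappings(
        src_state,
        dst_state,
        key_mappings=VLLM_JAX_MAPPING['to_hf_mappings'],
        transpose_keys=VLLM_JAX_MAPPING['to_hf_transpose_keys'],
    )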