diff --git a/tests/generate/utils_test.py b/tests/generate/utils_test.py index 94bfcb572..82576b307 100644 --- a/tests/generate/utils_test.py +++ b/tests/generate/utils_test.py @@ -1403,19 +1403,175 @@ def test_sglang_jax_1d_kv_bias_alignment(self): expected = jnp.tile(src_k_bias, 8) self.assertTrue(jnp.allclose(result.params[src_key], expected)) + def test_is_fused_path(self): + self.assertTrue( + utils.is_fused_path( + "vllm_model.language_model.model.layers.0.self_attn.qkv_proj.weight" + ) + ) + self.assertTrue( + utils.is_fused_path( + "vllm_model.language_model.model.layers.10.mlp.gate_up_proj.weight" + ) + ) + self.assertFalse( + utils.is_fused_path("model.layers.0.self_attn.q_proj.weight") + ) + + def test_fuse_src_to_same_tgt_params_qkv(self): + tgt_path = ( + "vllm_model.language_model.model.layers.0.self_attn.qkv_proj.weight" + ) + q_val = jnp.ones((8, 16, 64)) * 1.0 # (num_heads, d_model, head_dim) + kv_val = ( + jnp.ones((2, 2, 16, 64)) * 2.0 + ) # (2, num_kv_heads, d_model, head_dim) + + fuse_sources = {} + # First call with Q + utils.fuse_src_to_same_tgt_params( + q_val, + "layers.0.attn.q_einsum.w", + fuse_sources, + tgt_path, + None, + tp_size=1, + ) + self.assertLen(fuse_sources[tgt_path], 1) + + # Second call with KV + utils.fuse_src_to_same_tgt_params( + kv_val, + "layers.0.attn.kv_einsum.w", + fuse_sources, + tgt_path, + None, + tp_size=1, + ) + + # Should be fused now + self.assertLen(fuse_sources[tgt_path], 1) + fused_key = "layers.0.attn.qkv_fused" + self.assertIn(fused_key, fuse_sources[tgt_path]) + + fused_val = fuse_sources[tgt_path][fused_key][0] + # Expected shape: (q_per_tp + 2*kv_per_tp * head_dim, d_model) + # -> (d_model, (num_heads + 2*num_kv) * head_dim) + # transposed to ((num_heads+2*kv)*head_dim, d_model) + # q: (8, 16, 64) -> (16, 8, 64) + # kv: (2, 2, 16, 64) -> (16, 2, 2, 64) -> (16, 4, 64) + # concat(q, k, v) -> (16, 12, 64) -> (16, 768) + # transpose -> (768, 16) + self.assertEqual(fused_val.shape, (768, 16)) + self.assertTrue(jnp.allclose(fused_val[:512, :], 1.0)) # Q + self.assertTrue(jnp.allclose(fused_val[512:, :], 2.0)) # KV + + def test_fuse_src_to_same_tgt_params_gate_up(self): + tgt_path = ( + "vllm_model.language_model.model.layers.0.mlp.gate_up_proj.weight" + ) + gate_val = jnp.ones((16, 32)) * 3.0 # (d_model, hidden) + up_val = jnp.ones((16, 32)) * 4.0 # (d_model, hidden) + + fuse_sources = {} + utils.fuse_src_to_same_tgt_params( + gate_val, + "layers.0.mlp.gate_proj.kernel", + fuse_sources, + tgt_path, + None, + tp_size=1, + ) + utils.fuse_src_to_same_tgt_params( + up_val, + "layers.0.mlp.up_proj.kernel", + fuse_sources, + tgt_path, + None, + tp_size=1, + ) + + fused_key = "layers.0.mlp.gate_up_fused" + fused_val = fuse_sources[tgt_path][fused_key][0] + + # Hidden=32. tp_size=1. 
Gate and Up are stacked: + # (2*Hidden, d_model) = (64, 16) + self.assertEqual(fused_val.shape, (64, 16)) + self.assertTrue(jnp.allclose(fused_val[0:32, :], 3.0)) # gate + self.assertTrue(jnp.allclose(fused_val[32:64, :], 4.0)) # up + + def test_fuse_src_to_same_tgt_params_gate_up_tp2(self): + tgt_path = ( + "vllm_model.language_model.model.layers.0.mlp.gate_up_proj.weight" + ) + gate_val = jnp.ones((16, 32)) * 3.0 + up_val = jnp.ones((16, 32)) * 4.0 + + fuse_sources = {} + utils.fuse_src_to_same_tgt_params( + gate_val, + "layers.0.mlp.gate_proj.kernel", + fuse_sources, + tgt_path, + None, + tp_size=2, + ) + utils.fuse_src_to_same_tgt_params( + up_val, + "layers.0.mlp.up_proj.kernel", + fuse_sources, + tgt_path, + None, + tp_size=2, + ) + + fused_val = fuse_sources[tgt_path]["layers.0.mlp.gate_up_fused"][0] + + # Hidden=32, tp_size=2. chunk_size=16. + # [gate[0:16], up[0:16], gate[16:32], up[16:32]] interleaved + self.assertEqual(fused_val.shape, (64, 16)) + self.assertTrue(jnp.allclose(fused_val[0:16, :], 3.0)) + self.assertTrue(jnp.allclose(fused_val[16:32, :], 4.0)) + self.assertTrue(jnp.allclose(fused_val[32:48, :], 3.0)) + self.assertTrue(jnp.allclose(fused_val[48:64, :], 4.0)) + + def test_align_shape_moe_gating_einsum(self): + val = jnp.ones((2, 2, 128, 16)) + src_key = "layers.0.moe.gating_einsum" + tgt_shape = (2, 16, 256) + + result = utils._align_shape(val, tgt_shape, src_key, tp_size=1) + self.assertEqual(result.shape, (2, 16, 256)) + self.assertTrue(jnp.allclose(result, 1.0)) + + # Test with padding + val_small = jnp.ones((2, 2, 100, 16)) + # chunk_size = 100. padded = 128. pad_amount = 28. + # result shape should be (2, 16, 256) + result_padded = utils._align_shape( + val_small, (2, 16, 256), src_key, tp_size=1 + ) + self.assertEqual(result_padded.shape, (2, 16, 256)) + # Check that padded area is 0 + # gate_chunks (2, 1, 100, 16) -> pad -> (2, 1, 128, 16) stack -> + # (2, 1, 2, 128, 16) -> reshape (2, 256, 16) -> + # transpose (2, 16,256) The first 100 of first 128 should be 1. + # The last 28 of first 128 should be 0. + np.testing.assert_array_equal(result_padded[0, 0, 100:128], 0.0) + np.testing.assert_array_equal(result_padded[0, 0, 0:100], 1.0) def test_transfer_state_directly_fuses_moe_weights(self): """Tests that wi_0 and wi_1 are fused into wi when target expects it.""" wi_0_val = jnp.array([[1.0, 2.0], [5.0, 6.0]], dtype=jnp.float32) wi_1_val = jnp.array([[3.0, 4.0], [7.0, 8.0]], dtype=jnp.float32) - + src_state = nnx.Dict( layers=nnx.Dict( wi_0=nnx.Param(wi_0_val), wi_1=nnx.Param(wi_1_val), ) ) - + dst_state = nnx.Dict( layers=nnx.Dict( wi=nnx.Param(jnp.zeros((2, 4), dtype=jnp.float32)) diff --git a/tunix/generate/mappings.py b/tunix/generate/mappings.py index d96932bdf..d62b7758c 100644 --- a/tunix/generate/mappings.py +++ b/tunix/generate/mappings.py @@ -43,6 +43,15 @@ def to_hf_mappings(cls, backend: str | None = None): ) return mapping + @classmethod + def key_reference_mappings(cls, backend: str | None = None): + mapping = cls.mapping_for(backend).get('key_reference_mappings') + if mapping is None: + raise RuntimeError( + f'{backend} key_reference_mappings missing for {cls.__name__}.' 
+ ) + return mapping + @classmethod def lora_to_hf_mappings(cls, backend: str | None = None): return cls.mapping_for(backend).get('lora_to_hf_mappings') @@ -73,6 +82,7 @@ class MappingConfig: """ to_hf_mappings: Optional[Dict[str, Any]] = None + key_reference_mappings: Optional[Dict[str, Any]] = None lora_to_hf_mappings: Optional[Dict[str, Any]] = None to_hf_hook_fns: Optional[Dict[str, Any]] = None to_hf_transpose_keys: Optional[Dict[str, Tuple[int, ...]]] = None diff --git a/tunix/generate/utils.py b/tunix/generate/utils.py index 14555aa39..b8fc9ba6e 100644 --- a/tunix/generate/utils.py +++ b/tunix/generate/utils.py @@ -30,6 +30,8 @@ import jax.numpy as jnp import numpy as np +QKV_PROJ_PATTERN = r'vllm_model.language_model.model.layers\.\d+\.self_attn\.qkv_proj\.weight' +GATE_UP_PROJ_PATTERN = r'vllm_model.language_model.model.layers\.\d+\.mlp\.gate_up_proj.weight' def compute_attention_masks( time_step: int, seq_len: int, input_mask: jax.Array @@ -326,6 +328,11 @@ def get_logprobs_from_vllm_output( return extracted +def is_fused_path(path): + if re.compile(QKV_PROJ_PATTERN).match(path) or re.compile(GATE_UP_PROJ_PATTERN).match(path): + return True + + def build_flat_dict( flat_state: Iterator[tuple[tuple[str, ...], nnx.State]], mappings: Dict[str, tuple[str, tuple[int, ...]]], @@ -344,6 +351,7 @@ def build_flat_dict( """ new_flat_dict = {} compiled_mappings = [] + fused_tgt_map: Dict[str, Any] = {} # PRE-COMPILE MAPPINGS # Convert target string patterns into Python Regex objects for fast matching. @@ -363,9 +371,10 @@ def build_flat_dict( compiled_mappings.append((src, re.compile(pattern), sharding)) # ITERATE THROUGH ACTUAL PARAMETERS + unmapped_target_keys = [] for keys, v in flat_state: # Convert key tuple ('model', 'layers', '0') to string 'model.layers.0' - path = '.'.join(str(key) for key in keys) + path = keys if isinstance(keys, str) else '.'.join(str(key) for key in keys) mapped = False for src, regex, sharding in compiled_mappings: matched = regex.match(path) @@ -399,12 +408,17 @@ def build_flat_dict( else: # Regular (non-scanned) parameter new_flat_dict[actual_src] = v, path, sharding + if is_fused_path(path): + fused_tgt_map.setdefault(path, []).append(actual_src) mapped = True - break + # Fused path needs to loop over all keys that map to the same target. + if not is_fused_path(path): + break # There are no mappings for rng related params. if not mapped: - logging.warning('!!! No mapping for flat state: %s', path) + unmapped_target_keys.append(path) + logging.warning('!!! No mapping for flat state: %s', unmapped_target_keys) # Sort layers based on layer index to ensure correct order. for key, (layers, paths, sharding) in new_flat_dict.items(): @@ -415,7 +429,7 @@ def build_flat_dict( paths = [p for _, p in paths] new_flat_dict[key] = (values, paths, sharding) - return new_flat_dict + return new_flat_dict, fused_tgt_map class ShapeMismatchError(ValueError): @@ -439,9 +453,71 @@ def _get_layer_axis_from_sharding_spec(sharding_spec) -> Optional[int]: return None +def fuse_src_to_same_tgt_params(src_val, src_key, fuse_sources, tgt_path, tgt_param,tp_size): + """Fuses source parameters for the same target path, performing necessary transpositions and reshaping. 
This only works for VLLM torchax models.""" + fuse_sources.setdefault(tgt_path, {})[src_key] = (src_val, tgt_param) + if re.compile(QKV_PROJ_PATTERN).match(tgt_path) and len(fuse_sources[tgt_path].items()) == 2: + k, v, q = None, None, None + for sk, (sv, _) in fuse_sources[tgt_path].items(): + if 'kv_einsum' in sk: + k = sv[0] + v = sv[1] + elif 'q_einsum' in sk: + q = sv + elif 'k_einsum' in sk: + k = v = sv + else: + raise MappingError(f"Unexpected source key '{sk}' for target '{tgt_path}'. Expected 'kv_einsum', 'q_einsum', or 'k_einsum'.") + assert k is not None and v is not None and q is not None, f"Failed to find Q, K, V for target '{tgt_path}'." + tp = min(tp_size, k.shape[0]) + kv_per_tp = k.shape[0] // tp + q_per_tp = q.shape[0] // tp + # (num_heads, d_model, head_dim) -> (d_model, num_heads, head_dim) + q, k, v = q.transpose(1, 0, 2), k.transpose(1, 0, 2), v.transpose(1, 0, 2) + head_dim = q.shape[2] + d_model = q.shape[0] + q_by_tp = q.reshape(d_model, tp, q_per_tp, head_dim) + k_by_tp = k.reshape(d_model, tp, kv_per_tp, head_dim) + v_by_tp = v.reshape(d_model, tp, kv_per_tp, head_dim) + qkv_by_tp = jnp.concatenate([q_by_tp, k_by_tp, v_by_tp], axis=2) + qkv = qkv_by_tp.reshape(d_model, -1) + qkv = qkv.transpose(1, 0) + match = re.search(r"layers\.(\d+)\.attn\.(q|k|kv)_einsum\.w", list(fuse_sources[tgt_path].keys())[0]) + assert match, f"Source key '{list(fuse_sources[tgt_path].keys())[0]}' does not match expected pattern for QKV fusion." + layer_idx = match.group(1) + fused_src_key = f"layers.{layer_idx}.attn.qkv_fused" + fuse_sources[tgt_path] = {fused_src_key:(qkv, tgt_param)} + elif re.compile(GATE_UP_PROJ_PATTERN).match(tgt_path) and len(fuse_sources[tgt_path].items()) == 2: + gate, up = None, None + for sk, (sv, _) in fuse_sources[tgt_path].items(): + if 'gate_proj' in sk: + gate = sv + elif 'up_proj' in sk: + up = sv + else: + raise MappingError(f"Unexpected source key '{sk}' for target '{tgt_path}'. Expected 'gate_proj' or 'up_proj'.") + assert gate is not None and up is not None, f"Failed to find gate and up for target '{tgt_path}'." + gate, up = gate.T, up.T + hidden_dim = gate.shape[0] + chunk_size = hidden_dim // tp_size + gate_chunks = gate.reshape(tp_size, chunk_size, gate.shape[1]) + up_chunks = up.reshape(tp_size, chunk_size, up.shape[1]) + gate_up = jnp.stack([gate_chunks, up_chunks], axis=1) + gate_up = gate_up.reshape(2 * hidden_dim, gate.shape[1]) + match = re.search(r"layers\.(\d+)\.mlp\.gate_proj\.kernel", list(fuse_sources[tgt_path].keys())[0]) + assert match, f"Source key '{list(fuse_sources[tgt_path].keys())[0]}' does not match expected pattern for QKV fusion." + layer_idx = match.group(1) + fused_src_key = f"layers.{layer_idx}.mlp.gate_up_fused" + fuse_sources[tgt_path] = {fused_src_key: (gate_up, tgt_param)} + + return fuse_sources + + def _unroll_scanned_layers( src_state: Any, src_to_tgt_map: Dict, + fused_tgt_map: Dict, + tp_size: int, ) -> Dict[Tuple[str, str], Tuple[Any, Any]]: """Unroll scanned layers from source state and map to target keys. 
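For illustration, a minimal sketch of the QKV layout that fuse_src_to_same_tgt_params produces for the fused qkv_proj target, using the same shapes as test_fuse_src_to_same_tgt_params_qkv above (q: (num_heads=8, d_model=16, head_dim=64), kv: (2, num_kv_heads=2, 16, 64)); per tensor-parallel shard the rows are laid out as [Q | K | V] and the result is transposed to (out_features, d_model):

import jax.numpy as jnp
from tunix.generate import utils

tgt = "vllm_model.language_model.model.layers.0.self_attn.qkv_proj.weight"
q = jnp.ones((8, 16, 64))            # (num_heads, d_model, head_dim)
kv = 2.0 * jnp.ones((2, 2, 16, 64))  # (k/v, num_kv_heads, d_model, head_dim)

sources = {}
utils.fuse_src_to_same_tgt_params(
    q, "layers.0.attn.q_einsum.w", sources, tgt, None, tp_size=1
)
utils.fuse_src_to_same_tgt_params(
    kv, "layers.0.attn.kv_einsum.w", sources, tgt, None, tp_size=1
)

fused, _ = sources[tgt]["layers.0.attn.qkv_fused"]
# (num_heads + 2 * num_kv_heads) * head_dim = (8 + 4) * 64 = 768 output rows,
# d_model = 16 columns; rows 0:512 hold Q, rows 512:768 hold K then V.
assert fused.shape == (768, 16)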
@@ -455,6 +531,7 @@ def _unroll_scanned_layers( """ unscanned_flat = {} + tgt_path_to_src_values_fused = {} for src_keys, src_val in src_state.flat_state(): src_key = '.'.join(str(k) for k in src_keys) @@ -485,8 +562,27 @@ def _unroll_scanned_layers( unscanned_flat[(src_key, layer_key)] = (layer_val, tgt_param[i]) else: # No unrolling needed - unscanned_flat[(src_key, tgt_path)] = (src_val.value, tgt_param) - + if tgt_path in fused_tgt_map: + assert src_key in fused_tgt_map[tgt_path], ( + f"Source key '{src_key}' should be part of the fused mapping for" + f" target '{tgt_path}' but it's not. Fused mapping keys:" + f' {fused_tgt_map[tgt_path]}' + ) + tgt_path_to_src_values_fused = fuse_src_to_same_tgt_params( + src_val, + src_key, + tgt_path_to_src_values_fused, + tgt_path, + tgt_param, + tp_size, + ) + else: + unscanned_flat[(src_key, tgt_path)] = (src_val.value, tgt_param) + for tgt_path, src_tgt in tgt_path_to_src_values_fused.items(): + unscanned_flat[(list(src_tgt.keys())[0], tgt_path)] = ( + list(src_tgt.values())[0][0], + list(src_tgt.values())[0][1], + ) return unscanned_flat @@ -501,12 +597,18 @@ def _apply_transpose( return val last_key = src_key.split('.')[-1] + last_three_keys = '.'.join(src_key.split('.')[-3:]) + last_two_keys = '.'.join(src_key.split('.')[-2:]) all_key = src_key target_key = '' if last_key in transpose_keys and 'lora' not in last_key: target_key = last_key elif all_key in transpose_keys and 'lora' not in all_key: target_key = all_key + elif last_three_keys in transpose_keys and 'lora' not in last_three_keys: + target_key = last_three_keys + elif last_two_keys in transpose_keys and 'lora' not in last_two_keys: + target_key = last_two_keys if target_key != '': logging.debug('Applying transpose on %s', src_key) return jnp.transpose(val, transpose_keys[target_key]) @@ -588,6 +690,11 @@ def _align_shape( val = jnp.reshape(val, (kwargs['num_kv_heads'], kwargs['head_dim'])) new_tgt_shape = tgt_shape + elif src_key == 'embedder.per_layer_input_embedding': + return jnp.reshape(val, (val.shape[0], -1)) + elif src_key == 'embedder.per_layer_model_projection.w': + val = jnp.reshape(val, (val.shape[0], -1)) + return val.T elif re.compile(r'layers\..*\.attn\.(q|k|v|o)_proj').match(src_key): if math.prod(tgt_shape) == math.prod(val.shape): logging.debug( @@ -617,6 +724,25 @@ def _align_shape( padded_dim = (val.shape[-1] + 127) // 128 * 128 repeated_dim = tgt_shape[-1] // padded_dim new_tgt_shape = tgt_shape[:-1] + (repeated_dim, padded_dim) + elif re.compile(r'layers\..*\.attn_vec_einsum\.w').match(src_key): + # reshape from (num_head, head_dim, model_dim) to (model_dim, num_head * head_dim) for vec_einsum. 
+ return val.reshape((val.shape[0] * val.shape[1], val.shape[2])).T + elif re.compile(r'layers\..*\.moe\.gating_einsum').match(src_key): + tp_size = kwargs["tp_size"] + num_experts, expert_dim, embed_dim = val.shape[0], val.shape[2], val.shape[3] + gate_chunks, up_chunks = val[:, 0, :, :], val[:, 1, :, :] + chunk_size = expert_dim // tp_size + padded_expert_chunk_dim = ((chunk_size + 127)//128)*128 + pad_amount = padded_expert_chunk_dim - chunk_size + gate_chunks = gate_chunks.reshape(num_experts, tp_size, -1, embed_dim) + up_chunks = up_chunks.reshape(num_experts, tp_size, -1, embed_dim) + if pad_amount > 0: + gate_chunks = jnp.pad(gate_chunks, ((0, 0), (0, 0), (0, pad_amount), (0, 0))) + up_chunks = jnp.pad(up_chunks, ((0, 0), (0, 0), (0, pad_amount), (0, 0))) + val_chunks = jnp.stack([gate_chunks, up_chunks], axis=2) + val_chunks = val_chunks.reshape(num_experts, -1, embed_dim) + val_chunks = val_chunks.transpose(0, 2, 1) + return val_chunks else: raise ShapeMismatchError( f'Rank mismatch for {src_key}: {val.shape} vs {tgt_shape}' @@ -756,10 +882,28 @@ def _sync_tied_lm_head_if_needed( lm_head_param.value = embed_param.value +def flatten_to_tuples(d): + items = [] + key_idx_mapping = {} + i = 0 + for k, v in d.items(): + items.append((k, v)) + key_idx_mapping[k] = i + i += 1 + return items, key_idx_mapping + + +def unflatten_from_tuples(flat_list, dst_state): + for path, value in flat_list: + dst_state[path] = value + return dst_state + + def transfer_state_with_mappings( src_state, dst_state, key_mappings, + key_reference_mappings=None, key_mapping_hook_fns=None, transpose_keys=None, reshard_fn=None, @@ -787,7 +931,11 @@ def transfer_state_with_mappings( The target state with the transferred values. """ # Get flat target state - tgt_flat_list = dst_state.flat_state() + if isinstance(dst_state, dict): + tgt_flat_list, tgt_key_idx_mapping = flatten_to_tuples(dst_state) + else: + tgt_flat_list = dst_state.flat_state() + tgt_key_idx_mapping = None # Build sharding dictionary if resharding is needed sharding_dict = None @@ -803,10 +951,12 @@ def transfer_state_with_mappings( } # Build source-to-target mapping - src_to_tgt_map = build_flat_dict(tgt_flat_list, key_mappings) + src_to_tgt_map, fused_tgt_map = build_flat_dict(tgt_flat_list, key_mappings) # Unroll scanned layers and flatten source state - unscanned_src_to_tgt_flat = _unroll_scanned_layers(src_state, src_to_tgt_map) + unscanned_src_to_tgt_flat = _unroll_scanned_layers( + src_state, src_to_tgt_map, fused_tgt_map, kwargs.get('tp_size', None), + ) transferred_target_keys = set() # Transfer values with transformations @@ -822,15 +972,28 @@ def transfer_state_with_mappings( val = key_mapping_hook_fns[flat_src_key](val) # Align shapes (padding/repeating as needed) + tgt_shape = ( + tgt_param.value.shape + if hasattr(tgt_param, 'value') + else tgt_param.shape + ) + tgt_dtype = ( + tgt_param.value.dtype + if hasattr(tgt_param, 'value') + else tgt_param.dtype + ) val = _align_shape( - val, tgt_param.value.shape, flat_src_key, rollout_engine, **kwargs + val, tgt_shape, flat_src_key, rollout_engine, **kwargs ) # Cast to target dtype - val = _apply_dtype_cast(val, tgt_param.value.dtype, flat_src_key) + val = _apply_dtype_cast(val, tgt_dtype, flat_src_key) # Assign transformed value - tgt_param.value = val + if hasattr(tgt_param, 'value'): + tgt_param.value = val + else: + tgt_flat_list[tgt_key_idx_mapping[flat_tgt_key]] = (flat_tgt_key, val) transferred_target_keys.add(flat_tgt_key) # Target rollout engine might have different implementation 
and have materialized lm_head @@ -855,8 +1018,21 @@ def transfer_state_with_mappings( if hasattr(tgt_param, 'value'): tgt_param.value = resharded_values_flat_dict[tgt_key] else: - tgt_param = resharded_values_flat_dict[tgt_key] + tgt_flat_list[tgt_key_idx_mapping[tgt_key]] = ( + tgt_key, + resharded_values_flat_dict[tgt_key], + ) + + # handle cases like vllm_model.language_model.lm_head.weight just referencing + # the value of vllm_model.language_model.model.embed_tokens.weight. + if key_reference_mappings: + for tgt_key1, tgt_key2 in key_reference_mappings.items(): + for path, value in tgt_flat_list: + if path == tgt_key1: + tgt_flat_list[tgt_key2] = value + if isinstance(dst_state, dict): + return unflatten_from_tuples(tgt_flat_list, dst_state) return dst_state.from_flat_path(tgt_flat_list) diff --git a/tunix/generate/vllm_sampler.py b/tunix/generate/vllm_sampler.py index 3426b07da..abf836919 100644 --- a/tunix/generate/vllm_sampler.py +++ b/tunix/generate/vllm_sampler.py @@ -156,6 +156,9 @@ def __init__( self.llm = LLM(**self.args) self.to_hf_key_mappings = dict(config.mapping_config.to_hf_mappings or {}) + self.key_reference_mappings = dict( + config.mapping_config.key_reference_mappings or {} + ) self.to_hf_transpose_keys = config.mapping_config.to_hf_transpose_keys self.to_hf_hook_fns = config.mapping_config.to_hf_hook_fns @@ -190,7 +193,7 @@ def update_params( elif self._driver is not None: self._driver.llm_engine.reset_prefix_cache() self._driver.llm_engine.collective_rpc("delete_kv_cache") - + # Synchronization point before weight sync jax.effects_barrier() @@ -200,6 +203,7 @@ def update_params( src_state=updated_weights, dst_state=self.transformer_state, key_mappings=self.to_hf_key_mappings, + key_reference_mappings=self.key_reference_mappings, key_mapping_hook_fns=self.to_hf_hook_fns, transpose_keys=self.to_hf_transpose_keys, reshard_fn=reshard.reshard_pytree, @@ -213,6 +217,7 @@ def update_params( if not self._model_runner else self._model_runner.model_config.get_head_size() ), + tp_size=self.args.get("tensor_parallel_size", 1), ) else: # Direct Weight Sync (e.g. MaxText -> MaxText) @@ -237,7 +242,7 @@ def update_params( delete_dst_buffers=True, # Ensure old weights are deleted to free up HBM memory reshard_chunk_size=self.config.reshard_chunk_size, ) - + if self.llm is not None: self.llm.collective_rpc("reinitialize_kv_cache") elif self._driver is not None: diff --git a/tunix/models/gemma4/__init__.py b/tunix/models/gemma4/__init__.py new file mode 100644 index 000000000..462de1341 --- /dev/null +++ b/tunix/models/gemma4/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Gemma4 API.""" + +from tunix.models.gemma4 import mapping_vllm_jax +from tunix.models.gemma4 import model +from tunix.models.gemma4 import params_safetensors + +BACKEND_MAPPINGS = { + 'vllm_jax': mapping_vllm_jax.VLLM_JAX_MAPPING, +} + + +__all__ = ['BACKEND_MAPPINGS', 'model', 'params_safetensors'] diff --git a/tunix/models/gemma4/mapping_vllm_jax.py b/tunix/models/gemma4/mapping_vllm_jax.py new file mode 100644 index 000000000..5683f7db2 --- /dev/null +++ b/tunix/models/gemma4/mapping_vllm_jax.py @@ -0,0 +1,169 @@ +"""vLLM JAX backend mappings for Gemma4 models.""" + +from __future__ import annotations + +from typing import Any, Dict, Tuple +import re +from flax.traverse_util import flatten_dict, unflatten_dict +import jax.numpy as jnp + +Sharding = Tuple[str | None, ...] +MappingEntry = Tuple[str, Sharding] + +# Following mappings are only for the torchax implementation of Gemma4. +# The jax implementation needs different mappings. +TO_HF_MAPPINGS: Dict[str, MappingEntry] = { + 'embedder.input_embedding': ( + 'vllm_model.language_model.model.embed_tokens.weight', + ('model', None), + ), + 'layers.*.pre_attention_norm.scale': ( + 'vllm_model.language_model.model.layers.*.input_layernorm.weight', + (None,), + ), + # Q, K, V are fused into one matrix in vLLM. + 'layers.*.attn.q_einsum.w': ( + 'vllm_model.language_model.model.layers.*.self_attn.qkv_proj.weight', + (None, 'model'), + ), + 'layers.*.attn.k_einsum.w': ( + 'vllm_model.language_model.model.layers.*.self_attn.qkv_proj.weight', + (None, 'model'), + ), + 'layers.*.attn.kv_einsum.w': ( + 'vllm_model.language_model.model.layers.*.self_attn.qkv_proj.weight', + (None, 'model'), + ), + 'layers.*.attn.attn_vec_einsum.w': ( + 'vllm_model.language_model.model.layers.*.self_attn.o_proj.weight', + ('model', None, None), + ), + 'layers.*.post_attention_norm.scale': ( + 'vllm_model.language_model.model.layers.*.post_attention_layernorm.weight', + (None,), + ), + 'layers.*.pre_ffw_norm.scale': ( + 'vllm_model.language_model.model.layers.*.pre_feedforward_layernorm.weight', + (None,), + ), + # Gate/Up are fused into one matrix in vLLM. 
+ 'layers.*.mlp.gate_proj.kernel': ( + 'vllm_model.language_model.model.layers.*.mlp.gate_up_proj.weight', + (None, None), + ), + 'layers.*.mlp.up_proj.kernel': ( + 'vllm_model.language_model.model.layers.*.mlp.gate_up_proj.weight', + (None, None), + ), + 'layers.*.mlp.down_proj.kernel': ( + 'vllm_model.language_model.model.layers.*.mlp.down_proj.weight', + ('model', None), + ), + 'layers.*.post_ffw_norm.scale': ( + 'vllm_model.language_model.model.layers.*.post_feedforward_layernorm.weight', + (None,), + ), + 'final_norm.scale': ( + 'vllm_model.language_model.model.norm.weight', + (None,), + ), + 'layers.*.attn._query_norm.scale': ( + 'vllm_model.language_model.model.layers.*.self_attn.q_norm.weight', + (None,), + ), + 'layers.*.attn._key_norm.scale': ( + 'vllm_model.language_model.model.layers.*.self_attn.k_norm.weight', + (None,), + ), + 'layers.*.skip_scale': ( + 'vllm_model.language_model.model.layers.*.layer_scalar', + (None,), + ), +} + +# Add per-layer mappings (used in some Gemma4 variants) +TO_HF_MAPPINGS.update({ + 'embedder.per_layer_input_embedding': ( + 'vllm_model.language_model.model.embed_tokens_per_layer.weight', + ('model', None, None), + ), + 'embedder.per_layer_model_projection.w': ( + 'vllm_model.language_model.model.per_layer_model_projection.weight', + (None, None, 'model'), + ), + 'embedder.per_layer_projection_norm.scale': ( + 'vllm_model.language_model.model.per_layer_projection_norm.weight', + (None,), + ), + 'layers.*.per_layer_input_gate.w': ( + 'vllm_model.language_model.model.layers.*.per_layer_input_gate.weight', + (None, 'model'), + ), + 'layers.*.per_layer_projection.w': ( + 'vllm_model.language_model.model.layers.*.per_layer_projection.weight', + ('model', None), + ), + 'layers.*.post_per_layer_input_norm.scale': ( + 'vllm_model.language_model.model.layers.*.post_per_layer_input_norm.weight', + (None,), + ), +}) + +# Add MoE mappings (used in some Gemma4 variants) +TO_HF_MAPPINGS.update({ + 'layers.*.moe.router_logits': ( + 'vllm_model.language_model.model.layers.*.router.proj.weight', + (None, 'model'), + ), + 'layers.*.moe.per_expert_scale': ( + 'vllm_model.language_model.model.layers.*.router.per_expert_scale', + (None,), + ), + 'layers.*.moe.router_scale': ( + 'vllm_model.language_model.model.layers.*.router.scale', + (None,), + ), + 'layers.*.moe.gating_einsum': ( + 'vllm_model.language_model.model.layers.*.moe.experts.w13_weight', + (None, None, None, 'model'), + ), + 'layers.*.moe.linear': ( + 'vllm_model.language_model.model.layers.*.moe.experts.w2_weight', + (None, 'model', None), + ), + 'layers.*.moe_post_ffw_norm.scale': ( + 'vllm_model.language_model.model.layers.*.post_feedforward_layernorm_2.weight', + (None,), + ), + 'layers.*.moe_pre_ffw_norm.scale': ( + 'vllm_model.language_model.model.layers.*.pre_feedforward_layernorm_2.weight', + (None,), + ), + 'layers.*.dense_post_ffw_norm.scale': ( + 'vllm_model.language_model.model.layers.*.post_feedforward_layernorm_1.weight', + (None,), + ), +}) + +HF_KEY_REFERENCE_MAPPINGS: Dict[str, MappingEntry] = { + 'vllm_model.language_model.lm_head.weight': ( + 'vllm_model.language_model.model.embed_tokens.weight', + (None, None), + ), +} + +VLLM_JAX_MAPPING: Dict[str, Any] = { + 'to_hf_mappings': TO_HF_MAPPINGS, + 'lora_to_hf_mappings': {}, + 'to_hf_transpose_keys': { + 'embedding': (1, 0), + 'mlp.down_proj.kernel': (1, 0), + 'per_layer_input_gate.w': (1, 0), + }, + 'to_hf_hook_fns': None, + 'hf_key_reference_mappings': HF_KEY_REFERENCE_MAPPINGS, +} + +__all__ = [ + 'VLLM_JAX_MAPPING', +] diff --git 
a/tunix/models/gemma4/params_safetensors.py b/tunix/models/gemma4/params_safetensors.py index d9ca15744..4f3a371cb 100644 --- a/tunix/models/gemma4/params_safetensors.py +++ b/tunix/models/gemma4/params_safetensors.py @@ -48,7 +48,7 @@ def _get_key_and_transform_mapping(cfg: model_lib.ModelConfig): None, ), r"(?:model\.language_model\.)?embed_tokens_per_layer\.weight": ( - "embedder.per_layer_input_embedding.value", + "embedder.per_layer_input_embedding", (None, (cfg.num_embed, cfg.num_layers, cfg.per_layer_input_dim)), ), r"(?:model\.language_model\.)?per_layer_model_projection\.weight": ( @@ -198,7 +198,7 @@ def _get_key_and_transform_mapping(cfg: model_lib.ModelConfig): None, ), r"(?:model\.language_model\.)?layers\.([0-9]+)\.layer_scalar": ( - r"layers.\1.skip_scale.value", + r"layers.\1.skip_scale", None, ), r"(?:model\.language_model\.)?norm\.weight": ("final_norm.scale", None), @@ -225,7 +225,7 @@ def _get_key_and_transform_mapping(cfg: model_lib.ModelConfig): ), # MoE Router and Experts r"(?:model\.language_model\.)?layers\.([0-9]+)\.router\.proj\.weight": ( - r"layers.\1.moe.router_logits.value", + r"layers.\1.moe.router_logits", ((1, 0), None), ), r"(?:model\.language_model\.)?layers\.([0-9]+)\.router\.per_expert_scale": ( @@ -233,15 +233,15 @@ def _get_key_and_transform_mapping(cfg: model_lib.ModelConfig): None, ), r"(?:model\.language_model\.)?layers\.([0-9]+)\.router\.scale": ( - r"layers.\1.moe.router_scale.value", + r"layers.\1.moe.router_scale", None, ), r"(?:model\.language_model\.)?layers\.([0-9]+)\.experts\.gate_up_proj(?:\.weight)?": ( - r"layers.\1.moe.gating_einsum.value", + r"layers.\1.moe.gating_einsum", (None, (cfg.num_experts, 2, cfg.expert_dim, cfg.embed_dim)), ), r"(?:model\.language_model\.)?layers\.([0-9]+)\.experts\.down_proj(?:\.weight)?": ( - r"layers.\1.moe.linear.value", + r"layers.\1.moe.linear", ((0, 2, 1), None), ), }
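For illustration, a minimal sketch of the tensor-parallel gate/up interleaving expected by the fused mlp.gate_up_proj target, mirroring test_fuse_src_to_same_tgt_params_gate_up_tp2 above: with tp_size=2 each shard keeps its own [gate chunk; up chunk] pair, so the fused (2*hidden, d_model) matrix alternates gate and up row blocks per shard instead of stacking all of gate before all of up:

import jax.numpy as jnp
from tunix.generate import utils

tgt = "vllm_model.language_model.model.layers.0.mlp.gate_up_proj.weight"
gate = 3.0 * jnp.ones((16, 32))  # (d_model, hidden)
up = 4.0 * jnp.ones((16, 32))    # (d_model, hidden)

sources = {}
utils.fuse_src_to_same_tgt_params(
    gate, "layers.0.mlp.gate_proj.kernel", sources, tgt, None, tp_size=2
)
utils.fuse_src_to_same_tgt_params(
    up, "layers.0.mlp.up_proj.kernel", sources, tgt, None, tp_size=2
)

fused, _ = sources[tgt]["layers.0.mlp.gate_up_fused"]
# hidden=32 with tp_size=2 gives a per-shard chunk of 16 rows, so the row
# blocks are [gate 0:16 | up 16:32 | gate 32:48 | up 48:64], shape (64, 16).
assert fused.shape == (64, 16)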