From a05150edb87c49ac83dd283d32918d5ca61088b0 Mon Sep 17 00:00:00 2001 From: xiangming Date: Fri, 23 May 2025 10:54:46 +0000 Subject: [PATCH 01/15] This PR extends the codebase of penzai to support gemma3 models. The key changes are as follows: - Add parameters `use_qk_norm`, `local_scale_factor`, `global_scale_factor`, `local_rope_wavelength`, `global_rope_wavelength`, to `llamalike_common.py`. - Add function `_query_norm` and `_key_norm` in `llamalike_common.py` - Add extra arguments `scale_factor` to `pz.nn.ApplyRoPE` in `nn/embeddings.py` - Add parameters for the gemma3 models to `gemma.py`. PiperOrigin-RevId: 762356347 --- penzai/models/transformer/variants/gemma.py | 142 ++++++++++++++--- .../transformer/variants/llamalike_common.py | 143 +++++++++++++----- penzai/nn/embeddings.py | 19 +++ 3 files changed, 244 insertions(+), 60 deletions(-) diff --git a/penzai/models/transformer/variants/gemma.py b/penzai/models/transformer/variants/gemma.py index 1b8289f..1a215b7 100644 --- a/penzai/models/transformer/variants/gemma.py +++ b/penzai/models/transformer/variants/gemma.py @@ -14,13 +14,14 @@ """The Gemma architecture transformer variant. -Supports both the Gemma 1 and Gemma 2 architectures. Based on the Flax -reference implementation at https://github.com/google-deepmind/gemma. +Supports all the Gemma 1, Gemma 2 and Gemma 3 architectures. Based on the +Flax reference implementation at https://github.com/google-deepmind/gemma. See the Gemma technical reports for more information: * Gemma 1: https://arxiv.org/abs/2403.08295 * Gemma 2: https://arxiv.org/abs/2408.00118 +* Gemma 3: https://arxiv.org/abs/2503.19786 """ from __future__ import annotations @@ -105,6 +106,102 @@ final_logit_softcap=30.0, attn_logits_soft_cap=50.0, ), + "gemma3_1b": dict( + num_decoder_blocks=26, + vocab_size=262_144, + num_kv_heads=1, + query_head_multiplier=4, + embedding_dim=1152, + projection_dim=256, + mlp_hidden_dim=6*1152, + attention_type=( + llamalike_common.AttentionTypeSlidingWindowCausal(512), + llamalike_common.AttentionTypeSlidingWindowCausal(512), + llamalike_common.AttentionTypeSlidingWindowCausal(512), + llamalike_common.AttentionTypeSlidingWindowCausal(512), + llamalike_common.AttentionTypeSlidingWindowCausal(512), + llamalike_common.AttentionTypeGlobalCausal(), + ), + use_qk_norm=True, + use_post_attn_norm=True, + use_post_ffw_norm=True, + local_rope_wavelength=10_000, + global_rope_wavelength=1_000_000, + ), + "gemma3_4b": dict( + num_decoder_blocks=34, + vocab_size=262_144, + num_kv_heads=4, + query_head_multiplier=2, + embedding_dim=2560, + projection_dim=256, + mlp_hidden_dim=2560 * 8 // 2, + attention_type=( + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeGlobalCausal(), + ), + use_qk_norm=True, + use_post_attn_norm=True, + use_post_ffw_norm=True, + local_scale_factor=1.0, + global_scale_factor=8.0, + local_rope_wavelength=10_000, + global_rope_wavelength=1_000_000, + ), + "gemma3_12b": dict( + num_decoder_blocks=48, + vocab_size=262_144, + num_kv_heads=8, + query_head_multiplier=2, + embedding_dim=30 * 128, + projection_dim=256, + mlp_hidden_dim=8 * 30 * 128 // 2, + attention_type=( + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + 
llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeGlobalCausal(), + ), + use_qk_norm=True, + use_post_attn_norm=True, + use_post_ffw_norm=True, + local_scale_factor=1.0, + global_scale_factor=8.0, + local_rope_wavelength=10_000, + global_rope_wavelength=1_000_000, + ), + "gemma3_27b": dict( + num_decoder_blocks=62, + vocab_size=262_144, + num_kv_heads=16, + query_head_multiplier=2, + embedding_dim=5376, + projection_dim=128, + mlp_hidden_dim=5376 * 8 // 2, + # query scaling factor: 1/sqrt(embedding_dim / num_query_heads) + query_scaling_factor=(5376 // 32) ** -0.5, + attention_type=( + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeSlidingWindowCausal(1024), + llamalike_common.AttentionTypeGlobalCausal(), + ), + use_qk_norm=True, + use_post_attn_norm=True, + use_post_ffw_norm=True, + local_scale_factor=1.0, + global_scale_factor=8.0, + local_rope_wavelength=10_000, + global_rope_wavelength=1_000_000, + ), } _NEEDS_GATING_TRANSPOSE = { "gemma_2b": False, @@ -112,16 +209,21 @@ "gemma2_2b": False, "gemma2_9b": True, "gemma2_27b": True, + "gemma3_1b": True, + "gemma3_4b": True, + "gemma3_12b": True, + "gemma3_27b": True, } def gemma_from_pretrained_checkpoint( ckpt_params: dict[str, Any], + preset_name: Literal[ + "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", + "gemma3_1b", "gemma3_4b", "gemma3_12b", "gemma3_27b", + ], upcast_activations_to_float32: bool = False, use_layer_stack: bool = False, - preset_name: Literal[ - "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", "auto" - ] = "auto", ) -> model_parts.TransformerLM: """Builds a Gemma model from a pretrained checkpoint. @@ -139,32 +241,17 @@ def gemma_from_pretrained_checkpoint( Args: ckpt_params: Nested dictionary of weights from the Gemma checkpoint. + preset_name: The name of the Gemma preset to use. upcast_activations_to_float32: Whether to cast activations to float32 when the model runs. This allows analyzing activations at higher precision without consuming additional memory for parameters. use_layer_stack: Whether to use a layer stack for the decoder blocks. - preset_name: Preset name, used to determine model config. If "auto", uses - the number of layers in the checkpoint to determine the configuration. Returns: A Transformer model containing the loaded parameters. """ params = {k.removeprefix("transformer/"): v for k, v in ckpt_params.items()} - if preset_name == "auto": - num_layers = 0 - while f"layer_{num_layers}/mlp/linear" in params: - num_layers += 1 - preset_by_num_layers = { - kwargs["num_decoder_blocks"]: preset_name - for preset_name, kwargs in _GEMMA_PRESETS.items() - } - if num_layers not in preset_by_num_layers: - raise ValueError( - f"Could not determine preset for model with {num_layers} layers." 
- ) - preset_name = preset_by_num_layers[num_layers] - preset_kwargs = _GEMMA_PRESETS[preset_name] preset_needs_gating_transpose = _NEEDS_GATING_TRANSPOSE[preset_name] @@ -207,6 +294,19 @@ def gemma_from_pretrained_checkpoint( 1 + params[f"layer_{i}/pre_attention_norm"]["scale"] ).tag("embedding") ) + # Add qk norm if needed + if config.use_qk_norm: + cur_block_params["attention/_query_norm/scale.weights"] = ( + pz.nx.NamedArray.wrap( + 1 + params[f"layer_{i}/attn/_query_norm"]["scale"] + ).tag("projection") + ) + cur_block_params["attention/_key_norm/scale.weights"] = ( + pz.nx.NamedArray.wrap( + 1 + params[f"layer_{i}/attn/_key_norm"]["scale"] + ).tag("projection") + ) + if config.use_post_attn_norm: cur_block_params["post_attention_norm/scale.weights"] = ( pz.nx.NamedArray.wrap( diff --git a/penzai/models/transformer/variants/llamalike_common.py b/penzai/models/transformer/variants/llamalike_common.py index 1307b23..a39073a 100644 --- a/penzai/models/transformer/variants/llamalike_common.py +++ b/penzai/models/transformer/variants/llamalike_common.py @@ -34,7 +34,7 @@ import dataclasses import functools from typing import Any, Literal - +from absl import logging import jax import jax.numpy as jnp from penzai import pz @@ -102,6 +102,12 @@ class LlamalikeTransformerConfig: parameter_dtype: Floating dtype to use for all parameters. activation_dtype: Floating dtype to use for activations and KV cache tables. use_layer_stack: Whether to stack the blocks together using a LayerStack. + # NOTE: Gemma3 specific parameters + use_qk_norm: Whether to use QK normalization. + local_scale_factor: Scale factor for the localRoPE layers. + global_scale_factor: Scale factor for the gloabl RoPE layers. + local_rope_wavelength: Wavelength for the local RoPE layers. + global_rope_wavelength: Wavelength for the globalRoPE layers. """ num_kv_heads: int @@ -126,6 +132,12 @@ class LlamalikeTransformerConfig: parameter_dtype: jax.typing.DTypeLike = jnp.float32 activation_dtype: jax.typing.DTypeLike = jnp.float32 use_layer_stack: bool = False + # NOTE: Gemma3 specific parameters + use_qk_norm: bool = False + local_scale_factor: float | None = None + global_scale_factor: float | None = None + local_rope_wavelength: float | None = None + global_rope_wavelength: float | None = None def build_llamalike_feedforward( @@ -261,10 +273,30 @@ def build_llamalike_attention( sliding_window_size=attention_type.window_size, masked_out_value=masked_out_value, ) + # Decide which wavelength to use for local RoPE. + if config.local_rope_wavelength is not None: + wavelength = config.local_rope_wavelength + else: + wavelength = config.rope_wavelength + # Decide which scale factor to use for local RoPE. + if config.local_scale_factor is not None: + scale_factor = config.local_scale_factor + else: + scale_factor = 1.0 elif isinstance(attention_type, AttentionTypeGlobalCausal): attn_masker = pz.nn.ApplyCausalAttentionMask( masked_out_value=masked_out_value, ) + # Decide which wavelength to use for global RoPE. + if config.global_rope_wavelength is not None: + wavelength = config.global_rope_wavelength + else: + wavelength = config.rope_wavelength + # Decide which scale factor to use for global RoPE. 
+ if config.global_scale_factor is not None: + scale_factor = config.global_scale_factor + else: + scale_factor = 1.0 else: raise ValueError(f"Unsupported attention type {attention_type}") @@ -290,42 +322,74 @@ def build_llamalike_attention( pz.nn.Softmax("kv_seq"), ]) + # add qk norm if needed in the module of input_to_query sublayers + input_to_query_sublayers = [ + pz.nn.Linear.from_config( + name=f"{name}/query", + init_base_rng=init_base_rng, + input_axes={"embedding": embedding_dim}, + output_axes={ + **common_head_axes, + **query_only_head_axes, + "projection": projection_dim, + }, + dtype=config.parameter_dtype, + ), + ] + if config.use_qk_norm: + input_to_query_sublayers.append( + pz.nn.RMSLayerNorm.from_config( + name=f"{name}/_query_norm", + init_base_rng=init_base_rng, + across_axes={"projection": config.projection_dim}, + dtype=config.parameter_dtype, + epsilon=config.rms_norm_eps, + ), + ) + input_to_query_sublayers.extend([ + pz.nn.ApplyRoPE( + positions_input_name="token_positions", + embedding_axis="projection", + max_wavelength=wavelength, + scale_factor=scale_factor, + ), + pz.nn.ConstantRescale( + by=jnp.array(query_scaling_factor, dtype=config.activation_dtype) + ), + ]) + + # add qk norm if needed in the module of input_to_key sublayers + input_to_key_sublayers = [ + pz.nn.Linear.from_config( + name=f"{name}/key", + init_base_rng=init_base_rng, + input_axes={"embedding": embedding_dim}, + output_axes={**common_head_axes, "projection": projection_dim}, + dtype=config.parameter_dtype, + ), + ] + if config.use_qk_norm: + input_to_key_sublayers.append( + pz.nn.RMSLayerNorm.from_config( + name=f"{name}/_key_norm", + init_base_rng=init_base_rng, + across_axes={"projection": config.projection_dim}, + dtype=config.parameter_dtype, + epsilon=config.rms_norm_eps, + ), + ) + input_to_key_sublayers.append( + pz.nn.ApplyRoPE( + positions_input_name="token_positions", + embedding_axis="projection", + max_wavelength=wavelength, + scale_factor=scale_factor, + ), + ) + return pz.nn.Attention( - input_to_query=pz.nn.Sequential([ - pz.nn.Linear.from_config( - name=f"{name}/query", - init_base_rng=init_base_rng, - input_axes={"embedding": embedding_dim}, - output_axes={ - **common_head_axes, - **query_only_head_axes, - "projection": projection_dim, - }, - dtype=config.parameter_dtype, - ), - pz.nn.ApplyRoPE( - positions_input_name="token_positions", - embedding_axis="projection", - max_wavelength=config.rope_wavelength, - ), - pz.nn.ConstantRescale( - by=jnp.array(query_scaling_factor, dtype=config.activation_dtype) - ), - ]), - input_to_key=pz.nn.Sequential([ - pz.nn.Linear.from_config( - name=f"{name}/key", - init_base_rng=init_base_rng, - input_axes={"embedding": embedding_dim}, - output_axes={**common_head_axes, "projection": projection_dim}, - dtype=config.parameter_dtype, - ), - pz.nn.ApplyRoPE( - positions_input_name="token_positions", - embedding_axis="projection", - max_wavelength=config.rope_wavelength, - ), - ]), + input_to_query=pz.nn.Sequential(input_to_query_sublayers), + input_to_key=pz.nn.Sequential(input_to_key_sublayers), input_to_value=pz.nn.Sequential([ pz.nn.Linear.from_config( name=f"{name}/value", @@ -483,9 +547,10 @@ def build_llamalike_transformer( else: if not isinstance(config.attention_type, AttentionType): if config.num_decoder_blocks % len(config.attention_type) != 0: - raise ValueError( - "Per-layer attention types must have a length that divides the" - " number of blocks." + logging.warning( + "Please ensure that you are using Gemma3 models." 
+ "For other models, per-layer attention types must have a length " + "that divides the number of blocks." ) for block_index in range(config.num_decoder_blocks): sublayers.append( diff --git a/penzai/nn/embeddings.py b/penzai/nn/embeddings.py index 0905d20..c8ee3dd 100644 --- a/penzai/nn/embeddings.py +++ b/penzai/nn/embeddings.py @@ -225,11 +225,18 @@ class ApplyRoPE(layer_base.Layer): each token in the sequence. This side input should be provided as an integer array that is broadcastable with the input, and which does NOT include the embedding axis. + # NOTE: add extra arguments to support Gemma3 models. + scale_factor: The scale factor to use for the positional embeddings. """ embedding_axis: str = dataclasses.field(metadata={"pytree_node": False}) max_wavelength: float = dataclasses.field(metadata={"pytree_node": False}) positions_input_name: str = dataclasses.field(metadata={"pytree_node": False}) + # NOTE: add extra arguments to support Gemma3 models. + scale_factor: float = dataclasses.field( + default=1.0, + metadata={"pytree_node": False}, + ) def _apply_1d(self, input_slice: jax.Array, position: jax.Array) -> jax.Array: """Apply RoPE to a one-dimensional JAX array.""" @@ -242,6 +249,10 @@ def _apply_1d(self, input_slice: jax.Array, position: jax.Array) -> jax.Array: # Since we're assuming `timescale` is a vector and `position` is a scalar, # we don't need any axis alignment. sinusoid_inp = position / timescale + # NOTE: add extra arguments to support Gemma3 models. + if self.scale_factor < 1.0: + raise ValueError("scale_factor must be >= 1.0, got {scale_factor") + sinusoid_inp = sinusoid_inp / self.scale_factor sin = jnp.sin(sinusoid_inp) cos = jnp.cos(sinusoid_inp) first_half, second_half = jnp.split(input_slice, 2) @@ -298,12 +309,19 @@ class ApplyRoPEToSubset(layer_base.Layer): each token in the sequence. This side input should be provided as an integer array that is broadcastable with the input, and which does NOT include the embedding axis. + # NOTE: add extra arguments to support Gemma3 models. + scale_factor: The scale factor to use for the positional embeddings. """ embedding_axis: str = dataclasses.field(metadata={"pytree_node": False}) max_wavelength: float = dataclasses.field(metadata={"pytree_node": False}) rope_subset_size: int = dataclasses.field(metadata={"pytree_node": False}) positions_input_name: str = dataclasses.field(metadata={"pytree_node": False}) + # NOTE: add extra arguments to support Gemma3 models. + scale_factor: float = dataclasses.field( + default=1.0, + metadata={"pytree_node": False}, + ) def __call__( self, inputs: named_axes.NamedArray, **side_inputs @@ -319,6 +337,7 @@ def __call__( embedding_axis=self.embedding_axis, max_wavelength=self.max_wavelength, positions_input_name=self.positions_input_name, + scale_factor=self.scale_factor, ) rotated_result = rotator(rotary_input, **side_inputs) return named_axes.concatenate( From eb8512bc771fa77964ef275fd96f4c6a5f618071 Mon Sep 17 00:00:00 2001 From: "Xiangming (Brian) Gu" <72553776+guxm2021@users.noreply.github.com> Date: Thu, 12 Jun 2025 10:30:56 +0100 Subject: [PATCH 02/15] Update embeddings.py resolve the comments from Daniel by removing "# NOTE: add extra arguments to support Gemma3 models." 
and fixing a typo in format string syntax --- penzai/nn/embeddings.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/penzai/nn/embeddings.py b/penzai/nn/embeddings.py index c8ee3dd..0098511 100644 --- a/penzai/nn/embeddings.py +++ b/penzai/nn/embeddings.py @@ -225,14 +225,13 @@ class ApplyRoPE(layer_base.Layer): each token in the sequence. This side input should be provided as an integer array that is broadcastable with the input, and which does NOT include the embedding axis. - # NOTE: add extra arguments to support Gemma3 models. - scale_factor: The scale factor to use for the positional embeddings. + scale_factor: The scale factor to use for the positional embeddings (used by + Gemma3 models). """ embedding_axis: str = dataclasses.field(metadata={"pytree_node": False}) max_wavelength: float = dataclasses.field(metadata={"pytree_node": False}) positions_input_name: str = dataclasses.field(metadata={"pytree_node": False}) - # NOTE: add extra arguments to support Gemma3 models. scale_factor: float = dataclasses.field( default=1.0, metadata={"pytree_node": False}, @@ -249,9 +248,8 @@ def _apply_1d(self, input_slice: jax.Array, position: jax.Array) -> jax.Array: # Since we're assuming `timescale` is a vector and `position` is a scalar, # we don't need any axis alignment. sinusoid_inp = position / timescale - # NOTE: add extra arguments to support Gemma3 models. if self.scale_factor < 1.0: - raise ValueError("scale_factor must be >= 1.0, got {scale_factor") + raise ValueError("scale_factor must be >= 1.0, got {scale_factor}") sinusoid_inp = sinusoid_inp / self.scale_factor sin = jnp.sin(sinusoid_inp) cos = jnp.cos(sinusoid_inp) @@ -309,15 +307,14 @@ class ApplyRoPEToSubset(layer_base.Layer): each token in the sequence. This side input should be provided as an integer array that is broadcastable with the input, and which does NOT include the embedding axis. - # NOTE: add extra arguments to support Gemma3 models. - scale_factor: The scale factor to use for the positional embeddings. + scale_factor: The scale factor to use for the positional embeddings (used by + Gemma 3 models). """ embedding_axis: str = dataclasses.field(metadata={"pytree_node": False}) max_wavelength: float = dataclasses.field(metadata={"pytree_node": False}) rope_subset_size: int = dataclasses.field(metadata={"pytree_node": False}) positions_input_name: str = dataclasses.field(metadata={"pytree_node": False}) - # NOTE: add extra arguments to support Gemma3 models. 
scale_factor: float = dataclasses.field( default=1.0, metadata={"pytree_node": False}, From f15251ebb00544a98319ba2846acc8c64c2e794b Mon Sep 17 00:00:00 2001 From: "Xiangming (Brian) Gu" <72553776+guxm2021@users.noreply.github.com> Date: Thu, 12 Jun 2025 18:25:14 +0100 Subject: [PATCH 03/15] Update gemma.py resolve the comments from Daniel by enabling "auto" loading gemma 3 models, deleting the leading underscore in qk norm --- penzai/models/transformer/variants/gemma.py | 104 +++++++++++++------- 1 file changed, 68 insertions(+), 36 deletions(-) diff --git a/penzai/models/transformer/variants/gemma.py b/penzai/models/transformer/variants/gemma.py index 1a215b7..6abe4ad 100644 --- a/penzai/models/transformer/variants/gemma.py +++ b/penzai/models/transformer/variants/gemma.py @@ -34,6 +34,20 @@ from penzai.models.transformer.variants import llamalike_common +def make_attention_layers_types( + pattern: tuple[llamalike_common.AttentionType, ...], + *, + num_layers: int, +) -> tuple[llamalike_common.AttentionType, ...]: + """Returns the list of attention types for every layers.""" + + pattern_size = len(pattern) + out = pattern * (num_layers // pattern_size) + if num_layers % pattern_size != 0: + out += pattern[: num_layers % pattern_size] + return tuple(out) + + _GEMMA_PRESETS = { "gemma_2b": dict( num_decoder_blocks=18, @@ -113,14 +127,11 @@ query_head_multiplier=4, embedding_dim=1152, projection_dim=256, - mlp_hidden_dim=6*1152, - attention_type=( - llamalike_common.AttentionTypeSlidingWindowCausal(512), - llamalike_common.AttentionTypeSlidingWindowCausal(512), - llamalike_common.AttentionTypeSlidingWindowCausal(512), - llamalike_common.AttentionTypeSlidingWindowCausal(512), - llamalike_common.AttentionTypeSlidingWindowCausal(512), - llamalike_common.AttentionTypeGlobalCausal(), + mlp_hidden_dim=6 * 1152, + attention_type=make_attention_layers_types( + pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(512),) + * 5 + (llamalike_common.AttentionTypeGlobalCausal(),), + num_layers=26, ), use_qk_norm=True, use_post_attn_norm=True, @@ -136,13 +147,10 @@ embedding_dim=2560, projection_dim=256, mlp_hidden_dim=2560 * 8 // 2, - attention_type=( - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeGlobalCausal(), + attention_type=make_attention_layers_types( + pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) + * 5 + (llamalike_common.AttentionTypeGlobalCausal(),), + num_layers=34, ), use_qk_norm=True, use_post_attn_norm=True, @@ -160,13 +168,10 @@ embedding_dim=30 * 128, projection_dim=256, mlp_hidden_dim=8 * 30 * 128 // 2, - attention_type=( - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeGlobalCausal(), + attention_type=make_attention_layers_types( + pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) + * 5 + (llamalike_common.AttentionTypeGlobalCausal(),), + num_layers=48, ), use_qk_norm=True, use_post_attn_norm=True, @@ -186,13 +191,10 @@ mlp_hidden_dim=5376 * 8 // 2, # query scaling factor: 1/sqrt(embedding_dim / 
num_query_heads) query_scaling_factor=(5376 // 32) ** -0.5, - attention_type=( - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeSlidingWindowCausal(1024), - llamalike_common.AttentionTypeGlobalCausal(), + attention_type=make_attention_layers_types( + pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) + * 5 + (llamalike_common.AttentionTypeGlobalCausal(),), + num_layers=34, ), use_qk_norm=True, use_post_attn_norm=True, @@ -218,12 +220,12 @@ def gemma_from_pretrained_checkpoint( ckpt_params: dict[str, Any], - preset_name: Literal[ - "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", - "gemma3_1b", "gemma3_4b", "gemma3_12b", "gemma3_27b", - ], upcast_activations_to_float32: bool = False, use_layer_stack: bool = False, + preset_name: Literal[ + "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", + "gemma3_1b", "gemma3_4b", "gemma3_12b", "gemma3_27b", "auto" + ] = "auto", ) -> model_parts.TransformerLM: """Builds a Gemma model from a pretrained checkpoint. @@ -241,17 +243,47 @@ def gemma_from_pretrained_checkpoint( Args: ckpt_params: Nested dictionary of weights from the Gemma checkpoint. - preset_name: The name of the Gemma preset to use. upcast_activations_to_float32: Whether to cast activations to float32 when the model runs. This allows analyzing activations at higher precision without consuming additional memory for parameters. use_layer_stack: Whether to use a layer stack for the decoder blocks. + preset_name: Preset name, used to determine model config. If "auto", uses + the number of layers and whether the model needs qk norm in the checkpoint + to determine the configuration. Returns: A Transformer model containing the loaded parameters. """ params = {k.removeprefix("transformer/"): v for k, v in ckpt_params.items()} + if preset_name == "auto": + num_layers = 0 + while f"layer_{num_layers}/mlp/linear" in params: + num_layers += 1 + if ( + "layer_0/attn/_query_norm" in params + and "layer_0/attn/_key_norm" in params + ): + qk_norm = True + else: + qk_norm = False + is_match = False + for gemma_preset_name, kwargs in _GEMMA_PRESETS.items(): + if kwargs["num_decoder_blocks"] == num_layers: + if qk_norm and "use_qk_norm" in kwargs: + is_match = True + preset_name = gemma_preset_name + break + if (not qk_norm) and ("use_qk_norm" not in kwargs): + is_match = True + preset_name = gemma_preset_name + break + if not is_match: + raise ValueError( + f"Could not determine preset for model with {num_layers} layers and" + f" qk norm {qk_norm}." 
+ ) + preset_kwargs = _GEMMA_PRESETS[preset_name] preset_needs_gating_transpose = _NEEDS_GATING_TRANSPOSE[preset_name] @@ -296,12 +328,12 @@ def gemma_from_pretrained_checkpoint( ) # Add qk norm if needed if config.use_qk_norm: - cur_block_params["attention/_query_norm/scale.weights"] = ( + cur_block_params["attention/query_norm/scale.weights"] = ( pz.nx.NamedArray.wrap( 1 + params[f"layer_{i}/attn/_query_norm"]["scale"] ).tag("projection") ) - cur_block_params["attention/_key_norm/scale.weights"] = ( + cur_block_params["attention/key_norm/scale.weights"] = ( pz.nx.NamedArray.wrap( 1 + params[f"layer_{i}/attn/_key_norm"]["scale"] ).tag("projection") From cbeb885976ac5a1bf501bc970f3909001439c83a Mon Sep 17 00:00:00 2001 From: "Xiangming (Brian) Gu" <72553776+guxm2021@users.noreply.github.com> Date: Thu, 12 Jun 2025 18:30:49 +0100 Subject: [PATCH 04/15] Update llamalike_common.py resolve the comments from Daniel by deleting leading underscore for qk norm, remaining the check for attention types being divided by number of blocks. --- .../transformer/variants/llamalike_common.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/penzai/models/transformer/variants/llamalike_common.py b/penzai/models/transformer/variants/llamalike_common.py index a39073a..6b32e69 100644 --- a/penzai/models/transformer/variants/llamalike_common.py +++ b/penzai/models/transformer/variants/llamalike_common.py @@ -34,7 +34,6 @@ import dataclasses import functools from typing import Any, Literal -from absl import logging import jax import jax.numpy as jnp from penzai import pz @@ -102,12 +101,11 @@ class LlamalikeTransformerConfig: parameter_dtype: Floating dtype to use for all parameters. activation_dtype: Floating dtype to use for activations and KV cache tables. use_layer_stack: Whether to stack the blocks together using a LayerStack. - # NOTE: Gemma3 specific parameters use_qk_norm: Whether to use QK normalization. - local_scale_factor: Scale factor for the localRoPE layers. + local_scale_factor: Scale factor for the local RoPE layers. global_scale_factor: Scale factor for the gloabl RoPE layers. local_rope_wavelength: Wavelength for the local RoPE layers. - global_rope_wavelength: Wavelength for the globalRoPE layers. + global_rope_wavelength: Wavelength for the global RoPE layers. """ num_kv_heads: int @@ -132,7 +130,6 @@ class LlamalikeTransformerConfig: parameter_dtype: jax.typing.DTypeLike = jnp.float32 activation_dtype: jax.typing.DTypeLike = jnp.float32 use_layer_stack: bool = False - # NOTE: Gemma3 specific parameters use_qk_norm: bool = False local_scale_factor: float | None = None global_scale_factor: float | None = None @@ -339,7 +336,7 @@ def build_llamalike_attention( if config.use_qk_norm: input_to_query_sublayers.append( pz.nn.RMSLayerNorm.from_config( - name=f"{name}/_query_norm", + name=f"{name}/query_norm", init_base_rng=init_base_rng, across_axes={"projection": config.projection_dim}, dtype=config.parameter_dtype, @@ -371,7 +368,7 @@ def build_llamalike_attention( if config.use_qk_norm: input_to_key_sublayers.append( pz.nn.RMSLayerNorm.from_config( - name=f"{name}/_key_norm", + name=f"{name}/key_norm", init_base_rng=init_base_rng, across_axes={"projection": config.projection_dim}, dtype=config.parameter_dtype, @@ -547,10 +544,9 @@ def build_llamalike_transformer( else: if not isinstance(config.attention_type, AttentionType): if config.num_decoder_blocks % len(config.attention_type) != 0: - logging.warning( - "Please ensure that you are using Gemma3 models." 
- "For other models, per-layer attention types must have a length " - "that divides the number of blocks." + raise ValueError( + "Per-layer attention types must have a length that divides the" + " number of blocks." ) for block_index in range(config.num_decoder_blocks): sublayers.append( From f2421fe9c82bf65483b7706ffef73f3f149897ad Mon Sep 17 00:00:00 2001 From: "Xiangming (Brian) Gu" <72553776+guxm2021@users.noreply.github.com> Date: Thu, 12 Jun 2025 18:43:50 +0100 Subject: [PATCH 05/15] Update howto_reference.md add instructions to load gemma3 models --- docs/guides/howto_reference.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/docs/guides/howto_reference.md b/docs/guides/howto_reference.md index cddf53c..0d74004 100644 --- a/docs/guides/howto_reference.md +++ b/docs/guides/howto_reference.md @@ -217,9 +217,9 @@ You can read more about Penzai's conventions for layers in ["How to Think in Pen ## Loading Pretrained Models -### Loading Gemma or Gemma 2 +### Loading Gemma or Gemma 2 or Gemma 3 -Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2)) into the correct form. You can load it using: +Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2), [Gemma 3](https://www.kaggle.com/models/google/gemma-3)) into the correct form. You can load it using: ```python import kagglehub @@ -236,13 +236,20 @@ flax_params_dict = checkpointer.restore(ckpt_path) model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict) ``` -To load Gemma 2, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use: +To load Gemma 2/3, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use: ```python weights_dir = kagglehub.model_download('google/gemma-2/flax/gemma2-9b') ckpt_path = os.path.join(weights_dir, 'gemma2_9b_pt') ``` +For instance, to load the Gemma 3 4B model, you can use: + +```python +weights_dir = kagglehub.model_download('google/gemma-3/flax/gemma3-4b') +ckpt_path = os.path.join(weights_dir, 'gemma3_4b_pt') +``` + See the "Model Variations" section on the Kaggle model pages for details about the names and paths for each checkpoint. (You may also need to create a Kaggle account and request access to each model before you can download the checkpoints.) If you are using multiple accelerator devices (e.g. for a TPU v2 Colab kernel), you may want to shard the parameters over the devices while loading them. To do so, you can pass a sharding specification to `orbax.checkpoint`. 
For instance, to shard over the last axis of every parameter, you can use From 97c9e95f0eaa9cdb4ec358a27c6bc22e9edb7f1c Mon Sep 17 00:00:00 2001 From: "Xiangming (Brian) Gu" <72553776+guxm2021@users.noreply.github.com> Date: Sat, 14 Jun 2025 13:17:57 +0100 Subject: [PATCH 06/15] Update gemma.py --- penzai/models/transformer/variants/gemma.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/penzai/models/transformer/variants/gemma.py b/penzai/models/transformer/variants/gemma.py index 6abe4ad..38fe0d1 100644 --- a/penzai/models/transformer/variants/gemma.py +++ b/penzai/models/transformer/variants/gemma.py @@ -271,9 +271,10 @@ def gemma_from_pretrained_checkpoint( for gemma_preset_name, kwargs in _GEMMA_PRESETS.items(): if kwargs["num_decoder_blocks"] == num_layers: if qk_norm and "use_qk_norm" in kwargs: - is_match = True - preset_name = gemma_preset_name - break + if kwargs["use_qk_norm"]: + is_match = True + preset_name = gemma_preset_name + break if (not qk_norm) and ("use_qk_norm" not in kwargs): is_match = True preset_name = gemma_preset_name From 3dfc906dfdba6e8279a94443c370b094880f38bc Mon Sep 17 00:00:00 2001 From: "Xiangming (Brian) Gu" <72553776+guxm2021@users.noreply.github.com> Date: Sun, 15 Jun 2025 23:05:27 +0100 Subject: [PATCH 07/15] Update howto_reference.md --- docs/guides/howto_reference.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/guides/howto_reference.md b/docs/guides/howto_reference.md index 0d74004..e2c1c21 100644 --- a/docs/guides/howto_reference.md +++ b/docs/guides/howto_reference.md @@ -217,7 +217,7 @@ You can read more about Penzai's conventions for layers in ["How to Think in Pen ## Loading Pretrained Models -### Loading Gemma or Gemma 2 or Gemma 3 +### Loading Gemma (1, 2, or 3) Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2), [Gemma 3](https://www.kaggle.com/models/google/gemma-3)) into the correct form. 
You can load it using: @@ -243,7 +243,7 @@ weights_dir = kagglehub.model_download('google/gemma-2/flax/gemma2-9b') ckpt_path = os.path.join(weights_dir, 'gemma2_9b_pt') ``` -For instance, to load the Gemma 3 4B model, you can use: +To load the Gemma 3 4B model, you can use: ```python weights_dir = kagglehub.model_download('google/gemma-3/flax/gemma3-4b') From 0be6eec1da40f5fe146a0b82c7c9047edffd4b6f Mon Sep 17 00:00:00 2001 From: "Xiangming (Brian) Gu" <72553776+guxm2021@users.noreply.github.com> Date: Sun, 15 Jun 2025 23:57:03 +0100 Subject: [PATCH 08/15] Update gemma.py --- penzai/models/transformer/variants/gemma.py | 24 +++++++++------------ 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/penzai/models/transformer/variants/gemma.py b/penzai/models/transformer/variants/gemma.py index 38fe0d1..b99c532 100644 --- a/penzai/models/transformer/variants/gemma.py +++ b/penzai/models/transformer/variants/gemma.py @@ -34,7 +34,7 @@ from penzai.models.transformer.variants import llamalike_common -def make_attention_layers_types( +def _make_attention_layers_types( pattern: tuple[llamalike_common.AttentionType, ...], *, num_layers: int, @@ -128,7 +128,7 @@ def make_attention_layers_types( embedding_dim=1152, projection_dim=256, mlp_hidden_dim=6 * 1152, - attention_type=make_attention_layers_types( + attention_type=_make_attention_layers_types( pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(512),) * 5 + (llamalike_common.AttentionTypeGlobalCausal(),), num_layers=26, @@ -136,8 +136,8 @@ def make_attention_layers_types( use_qk_norm=True, use_post_attn_norm=True, use_post_ffw_norm=True, + rope_wavelength=1_000_000, local_rope_wavelength=10_000, - global_rope_wavelength=1_000_000, ), "gemma3_4b": dict( num_decoder_blocks=34, @@ -147,7 +147,7 @@ def make_attention_layers_types( embedding_dim=2560, projection_dim=256, mlp_hidden_dim=2560 * 8 // 2, - attention_type=make_attention_layers_types( + attention_type=_make_attention_layers_types( pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) * 5 + (llamalike_common.AttentionTypeGlobalCausal(),), num_layers=34, @@ -155,10 +155,9 @@ def make_attention_layers_types( use_qk_norm=True, use_post_attn_norm=True, use_post_ffw_norm=True, - local_scale_factor=1.0, global_scale_factor=8.0, + rope_wavelength=1_000_000, local_rope_wavelength=10_000, - global_rope_wavelength=1_000_000, ), "gemma3_12b": dict( num_decoder_blocks=48, @@ -168,7 +167,7 @@ def make_attention_layers_types( embedding_dim=30 * 128, projection_dim=256, mlp_hidden_dim=8 * 30 * 128 // 2, - attention_type=make_attention_layers_types( + attention_type=_make_attention_layers_types( pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) * 5 + (llamalike_common.AttentionTypeGlobalCausal(),), num_layers=48, @@ -176,10 +175,9 @@ def make_attention_layers_types( use_qk_norm=True, use_post_attn_norm=True, use_post_ffw_norm=True, - local_scale_factor=1.0, global_scale_factor=8.0, + rope_wavelength=1_000_000, local_rope_wavelength=10_000, - global_rope_wavelength=1_000_000, ), "gemma3_27b": dict( num_decoder_blocks=62, @@ -191,18 +189,17 @@ def make_attention_layers_types( mlp_hidden_dim=5376 * 8 // 2, # query scaling factor: 1/sqrt(embedding_dim / num_query_heads) query_scaling_factor=(5376 // 32) ** -0.5, - attention_type=make_attention_layers_types( + attention_type=_make_attention_layers_types( pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) * 5 + (llamalike_common.AttentionTypeGlobalCausal(),), - num_layers=34, + num_layers=62, ), 
use_qk_norm=True, use_post_attn_norm=True, use_post_ffw_norm=True, - local_scale_factor=1.0, global_scale_factor=8.0, + rope_wavelength=1_000_000, local_rope_wavelength=10_000, - global_rope_wavelength=1_000_000, ), } _NEEDS_GATING_TRANSPOSE = { @@ -299,7 +296,6 @@ def gemma_from_pretrained_checkpoint( **preset_kwargs, parameter_dtype=parameter_dtype, mlp_variant="geglu_approx", - rope_wavelength=10_000, tie_embedder_and_logits=True, activation_dtype=activation_dtype, use_layer_stack=use_layer_stack, From 0e65e855844d96c1eaed3cdf81b278599259b6c1 Mon Sep 17 00:00:00 2001 From: "Xiangming (Brian) Gu" <72553776+guxm2021@users.noreply.github.com> Date: Sun, 15 Jun 2025 23:57:26 +0100 Subject: [PATCH 09/15] Update llamalike_common.py --- .../transformer/variants/llamalike_common.py | 26 +++++++------------ 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/penzai/models/transformer/variants/llamalike_common.py b/penzai/models/transformer/variants/llamalike_common.py index 6b32e69..9a35774 100644 --- a/penzai/models/transformer/variants/llamalike_common.py +++ b/penzai/models/transformer/variants/llamalike_common.py @@ -83,7 +83,8 @@ class LlamalikeTransformerConfig: tie_embedder_and_logits: Whether to tie the weights of the input token embedding and output logit layers. If True, also scales down input token embeddings by sqrt(embedding_dim). (This is used by Gemma.) - rope_wavelength: Wavelength for RoPE layers. + rope_wavelength: Wavelength for global RoPE layers (and for local RoPE + layers if local_rope_wavelength is not set). rms_norm_eps: Epsilon for RMSNorm layers. attention_type: A single attention type or sequence of per-layer attention types. If a sequence, its length should evenly divide the number of @@ -102,10 +103,11 @@ class LlamalikeTransformerConfig: activation_dtype: Floating dtype to use for activations and KV cache tables. use_layer_stack: Whether to stack the blocks together using a LayerStack. use_qk_norm: Whether to use QK normalization. - local_scale_factor: Scale factor for the local RoPE layers. - global_scale_factor: Scale factor for the gloabl RoPE layers. - local_rope_wavelength: Wavelength for the local RoPE layers. - global_rope_wavelength: Wavelength for the global RoPE layers. + global_scale_factor: Scale factor for the gloabl RoPE layers (scale factor + for the local RoPE layers is set as 1.0 by default). + local_rope_wavelength: Wavelength for the local RoPE layers. If None, local + RoPE layers will use the same wavelength as global RoPE layers + (config.rope_wavelength). """ num_kv_heads: int @@ -131,10 +133,8 @@ class LlamalikeTransformerConfig: activation_dtype: jax.typing.DTypeLike = jnp.float32 use_layer_stack: bool = False use_qk_norm: bool = False - local_scale_factor: float | None = None global_scale_factor: float | None = None local_rope_wavelength: float | None = None - global_rope_wavelength: float | None = None def build_llamalike_feedforward( @@ -275,20 +275,12 @@ def build_llamalike_attention( wavelength = config.local_rope_wavelength else: wavelength = config.rope_wavelength - # Decide which scale factor to use for local RoPE. - if config.local_scale_factor is not None: - scale_factor = config.local_scale_factor - else: - scale_factor = 1.0 + scale_factor = 1.0 elif isinstance(attention_type, AttentionTypeGlobalCausal): attn_masker = pz.nn.ApplyCausalAttentionMask( masked_out_value=masked_out_value, ) - # Decide which wavelength to use for global RoPE. 
- if config.global_rope_wavelength is not None: - wavelength = config.global_rope_wavelength - else: - wavelength = config.rope_wavelength + wavelength = config.rope_wavelength # Decide which scale factor to use for global RoPE. if config.global_scale_factor is not None: scale_factor = config.global_scale_factor From 15b116330ffac6926d08097e8f28f69145b76e09 Mon Sep 17 00:00:00 2001 From: "Xiangming (Brian) Gu" <72553776+guxm2021@users.noreply.github.com> Date: Mon, 16 Jun 2025 00:23:36 +0100 Subject: [PATCH 10/15] Update gemma.py --- penzai/models/transformer/variants/gemma.py | 36 +++++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/penzai/models/transformer/variants/gemma.py b/penzai/models/transformer/variants/gemma.py index b99c532..c1eae1d 100644 --- a/penzai/models/transformer/variants/gemma.py +++ b/penzai/models/transformer/variants/gemma.py @@ -130,8 +130,12 @@ def _make_attention_layers_types( mlp_hidden_dim=6 * 1152, attention_type=_make_attention_layers_types( pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(512),) - * 5 + (llamalike_common.AttentionTypeGlobalCausal(),), + * 5 + + ( + llamalike_common.AttentionTypeGlobalCausal(), + ), num_layers=26, + ), use_qk_norm=True, use_post_attn_norm=True, @@ -149,8 +153,12 @@ def _make_attention_layers_types( mlp_hidden_dim=2560 * 8 // 2, attention_type=_make_attention_layers_types( pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) - * 5 + (llamalike_common.AttentionTypeGlobalCausal(),), + * 5 + + ( + llamalike_common.AttentionTypeGlobalCausal(), + ), num_layers=34, + ), use_qk_norm=True, use_post_attn_norm=True, @@ -169,8 +177,12 @@ def _make_attention_layers_types( mlp_hidden_dim=8 * 30 * 128 // 2, attention_type=_make_attention_layers_types( pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) - * 5 + (llamalike_common.AttentionTypeGlobalCausal(),), + * 5 + + ( + llamalike_common.AttentionTypeGlobalCausal(), + ), num_layers=48, + ), use_qk_norm=True, use_post_attn_norm=True, @@ -191,8 +203,12 @@ def _make_attention_layers_types( query_scaling_factor=(5376 // 32) ** -0.5, attention_type=_make_attention_layers_types( pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) - * 5 + (llamalike_common.AttentionTypeGlobalCausal(),), + * 5 + + ( + llamalike_common.AttentionTypeGlobalCausal(), + ), num_layers=62, + ), use_qk_norm=True, use_post_attn_norm=True, @@ -220,8 +236,16 @@ def gemma_from_pretrained_checkpoint( upcast_activations_to_float32: bool = False, use_layer_stack: bool = False, preset_name: Literal[ - "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", - "gemma3_1b", "gemma3_4b", "gemma3_12b", "gemma3_27b", "auto" + "gemma_2b", + "gemma_7b", + "gemma2_2b", + "gemma2_9b", + "gemma2_27b", + "gemma3_1b", + "gemma3_4b", + "gemma3_12b", + "gemma3_27b", + "auto", ] = "auto", ) -> model_parts.TransformerLM: """Builds a Gemma model from a pretrained checkpoint. 
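For reference, the Gemma 3 presets in this series build their per-layer attention schedule by repeating a short pattern (five sliding-window layers followed by one global layer) until it covers the full depth of the model. The standalone sketch below mirrors that expansion logic without depending on penzai; the function name and the string labels are illustrative only, not part of the library API.

```python
# Standalone sketch of how the Gemma 3 presets expand their attention-type
# pattern to the full model depth. The function name and string labels are
# illustrative; the real presets use llamalike_common attention-type objects.
def expand_attention_pattern(pattern, num_layers):
  """Repeats `pattern` until it covers `num_layers` layers."""
  out = pattern * (num_layers // len(pattern))
  if num_layers % len(pattern) != 0:
    out += pattern[: num_layers % len(pattern)]
  return tuple(out)


# Gemma 3 1B: five 512-token sliding-window layers, then one global layer.
pattern = ("sliding_window_512",) * 5 + ("global",)
layer_types = expand_attention_pattern(pattern, num_layers=26)
assert len(layer_types) == 26
# 22 sliding-window layers and 4 global layers for the 26-block model.
print(layer_types.count("sliding_window_512"), layer_types.count("global"))
```

Because the expanded tuple always has exactly `num_decoder_blocks` entries, it also satisfies the "length must divide the number of blocks" check that patch 04 restores in `llamalike_common.py`.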
From d0253447b4d4fd2e23dad4b119789c69a2127d44 Mon Sep 17 00:00:00 2001 From: "Xiangming (Brian) Gu" <72553776+guxm2021@users.noreply.github.com> Date: Mon, 16 Jun 2025 00:23:55 +0100 Subject: [PATCH 11/15] Update llamalike_common.py --- penzai/models/transformer/variants/llamalike_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/penzai/models/transformer/variants/llamalike_common.py b/penzai/models/transformer/variants/llamalike_common.py index 9a35774..321c500 100644 --- a/penzai/models/transformer/variants/llamalike_common.py +++ b/penzai/models/transformer/variants/llamalike_common.py @@ -103,8 +103,8 @@ class LlamalikeTransformerConfig: activation_dtype: Floating dtype to use for activations and KV cache tables. use_layer_stack: Whether to stack the blocks together using a LayerStack. use_qk_norm: Whether to use QK normalization. - global_scale_factor: Scale factor for the gloabl RoPE layers (scale factor - for the local RoPE layers is set as 1.0 by default). + global_scale_factor: Scale factor for the global RoPE layers (scale factor + for the local RoPE layers is set as 1.0 by default). local_rope_wavelength: Wavelength for the local RoPE layers. If None, local RoPE layers will use the same wavelength as global RoPE layers (config.rope_wavelength). From 4659231e0efd412377279ea4f8bd7e22dfd96736 Mon Sep 17 00:00:00 2001 From: "Xiangming (Brian) Gu" <72553776+guxm2021@users.noreply.github.com> Date: Mon, 16 Jun 2025 10:06:05 +0100 Subject: [PATCH 12/15] Update llamalike_common.py --- .../transformer/variants/llamalike_common.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/penzai/models/transformer/variants/llamalike_common.py b/penzai/models/transformer/variants/llamalike_common.py index 321c500..02870e0 100644 --- a/penzai/models/transformer/variants/llamalike_common.py +++ b/penzai/models/transformer/variants/llamalike_common.py @@ -34,6 +34,7 @@ import dataclasses import functools from typing import Any, Literal + import jax import jax.numpy as jnp from penzai import pz @@ -104,7 +105,7 @@ class LlamalikeTransformerConfig: use_layer_stack: Whether to stack the blocks together using a LayerStack. use_qk_norm: Whether to use QK normalization. global_scale_factor: Scale factor for the global RoPE layers (scale factor - for the local RoPE layers is set as 1.0 by default). + for the local RoPE layers is set as 1.0 by default). local_rope_wavelength: Wavelength for the local RoPE layers. If None, local RoPE layers will use the same wavelength as global RoPE layers (config.rope_wavelength). 
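The scale-factor and wavelength parameters documented in the patch above feed into `pz.nn.ApplyRoPE`, which divides the rotary angles by `scale_factor` (patch 01). The sketch below reproduces that arithmetic in plain NumPy using the standard RoPE timescale formula; it is a simplified illustration under that assumption, not penzai's implementation. The constants are the Gemma 3 values from the presets: wavelength 10,000 with no scaling for the local sliding-window layers, and wavelength 1,000,000 with a scale factor of 8.0 for the global layers.

```python
# Simplified sketch of the RoPE position scaling used in this series.
# Assumes the standard RoPE timescale formula; this is not penzai's code,
# it only reproduces the core arithmetic of ApplyRoPE with a scale factor:
#   sinusoid_inp = position / timescale, then divided by scale_factor.
import numpy as np


def rope_angles(position, head_dim, max_wavelength, scale_factor=1.0):
  """Rotation angles for one token position (illustrative helper)."""
  fraction = 2.0 * np.arange(head_dim // 2) / head_dim
  timescale = max_wavelength ** fraction
  sinusoid_inp = position / timescale
  # Gemma 3 global-attention layers divide by a scale factor of 8.0, so a
  # token at position p is rotated as if it sat at position p / 8.
  return sinusoid_inp / scale_factor


# Local (sliding-window) layers: wavelength 10_000, scale factor 1.0.
local_angles = rope_angles(position=4096, head_dim=256, max_wavelength=10_000)
# Global layers: wavelength 1_000_000, scale factor 8.0.
global_angles = rope_angles(
    position=4096, head_dim=256, max_wavelength=1_000_000, scale_factor=8.0
)
print(local_angles[:3])
print(global_angles[:3])
```

Dividing the angles by 8 stretches the effective rotary period, which is how the global-attention layers accommodate Gemma 3's longer context while the local layers keep the shorter-wavelength RoPE unchanged.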
@@ -665,14 +666,18 @@ def llamalike_from_huggingface_model( converted = {k: jax.dlpack.from_dlpack(v) for k, v in state_dict.items()} parameter_mapping = { - "embedder.embeddings": pz.nx.NamedArray.wrap( - converted["model.embed_tokens.weight"] - ).tag("vocabulary", "embedding"), - "final_norm/scale.weights": pz.nx.NamedArray.wrap( - converted["model.norm.weight"] - ).tag("embedding"), - "lm_head.weights": pz.nx.NamedArray.wrap(converted["lm_head.weight"]).tag( - "vocabulary", "embedding" + "embedder.embeddings": ( + pz.nx.NamedArray.wrap(converted["model.embed_tokens.weight"]).tag( + "vocabulary", "embedding" + ) + ), + "final_norm/scale.weights": ( + pz.nx.NamedArray.wrap(converted["model.norm.weight"]).tag("embedding") + ), + "lm_head.weights": ( + pz.nx.NamedArray.wrap(converted["lm_head.weight"]).tag( + "vocabulary", "embedding" + ) ), } From 6b192f82952cf229ff8140d0df72dfc7a27c9b78 Mon Sep 17 00:00:00 2001 From: "Xiangming (Brian) Gu" <72553776+guxm2021@users.noreply.github.com> Date: Mon, 16 Jun 2025 10:07:04 +0100 Subject: [PATCH 13/15] Update gemma.py --- penzai/models/transformer/variants/gemma.py | 36 ++++++++------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/penzai/models/transformer/variants/gemma.py b/penzai/models/transformer/variants/gemma.py index c1eae1d..c0d0bf1 100644 --- a/penzai/models/transformer/variants/gemma.py +++ b/penzai/models/transformer/variants/gemma.py @@ -131,11 +131,8 @@ def _make_attention_layers_types( attention_type=_make_attention_layers_types( pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(512),) * 5 - + ( - llamalike_common.AttentionTypeGlobalCausal(), - ), + + (llamalike_common.AttentionTypeGlobalCausal(),), num_layers=26, - ), use_qk_norm=True, use_post_attn_norm=True, @@ -154,11 +151,8 @@ def _make_attention_layers_types( attention_type=_make_attention_layers_types( pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) * 5 - + ( - llamalike_common.AttentionTypeGlobalCausal(), - ), + + (llamalike_common.AttentionTypeGlobalCausal(),), num_layers=34, - ), use_qk_norm=True, use_post_attn_norm=True, @@ -178,11 +172,8 @@ def _make_attention_layers_types( attention_type=_make_attention_layers_types( pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) * 5 - + ( - llamalike_common.AttentionTypeGlobalCausal(), - ), + + (llamalike_common.AttentionTypeGlobalCausal(),), num_layers=48, - ), use_qk_norm=True, use_post_attn_norm=True, @@ -204,11 +195,8 @@ def _make_attention_layers_types( attention_type=_make_attention_layers_types( pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) * 5 - + ( - llamalike_common.AttentionTypeGlobalCausal(), - ), + + (llamalike_common.AttentionTypeGlobalCausal(),), num_layers=62, - ), use_qk_norm=True, use_post_attn_norm=True, @@ -328,12 +316,16 @@ def gemma_from_pretrained_checkpoint( config, init_base_rng=None, name="transformer" ) parameter_mapping = { - "embedder.embeddings": pz.nx.NamedArray.wrap( - params["embedder"]["input_embedding"] - ).tag("vocabulary", "embedding"), - "final_norm/scale.weights": pz.nx.NamedArray.wrap( - 1 + params["final_norm"]["scale"] - ).tag("embedding"), + "embedder.embeddings": ( + pz.nx.NamedArray.wrap(params["embedder"]["input_embedding"]).tag( + "vocabulary", "embedding" + ) + ), + "final_norm/scale.weights": ( + pz.nx.NamedArray.wrap(1 + params["final_norm"]["scale"]).tag( + "embedding" + ) + ), } all_block_params = [] From 41f43df27e6635366646544eb32e686372aca860 Mon Sep 17 00:00:00 2001 From: Xiangming Gu Date: 
Mon, 16 Jun 2025 12:43:14 +0000 Subject: [PATCH 14/15] pyink fix --- penzai/models/transformer/variants/gemma.py | 16 ++++++--------- .../transformer/variants/llamalike_common.py | 20 ++++++++----------- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/penzai/models/transformer/variants/gemma.py b/penzai/models/transformer/variants/gemma.py index c0d0bf1..0b57a32 100644 --- a/penzai/models/transformer/variants/gemma.py +++ b/penzai/models/transformer/variants/gemma.py @@ -316,16 +316,12 @@ def gemma_from_pretrained_checkpoint( config, init_base_rng=None, name="transformer" ) parameter_mapping = { - "embedder.embeddings": ( - pz.nx.NamedArray.wrap(params["embedder"]["input_embedding"]).tag( - "vocabulary", "embedding" - ) - ), - "final_norm/scale.weights": ( - pz.nx.NamedArray.wrap(1 + params["final_norm"]["scale"]).tag( - "embedding" - ) - ), + "embedder.embeddings": pz.nx.NamedArray.wrap( + params["embedder"]["input_embedding"] + ).tag("vocabulary", "embedding"), + "final_norm/scale.weights": pz.nx.NamedArray.wrap( + 1 + params["final_norm"]["scale"] + ).tag("embedding"), } all_block_params = [] diff --git a/penzai/models/transformer/variants/llamalike_common.py b/penzai/models/transformer/variants/llamalike_common.py index 02870e0..b1247bc 100644 --- a/penzai/models/transformer/variants/llamalike_common.py +++ b/penzai/models/transformer/variants/llamalike_common.py @@ -666,18 +666,14 @@ def llamalike_from_huggingface_model( converted = {k: jax.dlpack.from_dlpack(v) for k, v in state_dict.items()} parameter_mapping = { - "embedder.embeddings": ( - pz.nx.NamedArray.wrap(converted["model.embed_tokens.weight"]).tag( - "vocabulary", "embedding" - ) - ), - "final_norm/scale.weights": ( - pz.nx.NamedArray.wrap(converted["model.norm.weight"]).tag("embedding") - ), - "lm_head.weights": ( - pz.nx.NamedArray.wrap(converted["lm_head.weight"]).tag( - "vocabulary", "embedding" - ) + "embedder.embeddings": pz.nx.NamedArray.wrap( + converted["model.embed_tokens.weight"] + ).tag("vocabulary", "embedding"), + "final_norm/scale.weights": pz.nx.NamedArray.wrap( + converted["model.norm.weight"] + ).tag("embedding"), + "lm_head.weights": pz.nx.NamedArray.wrap(converted["lm_head.weight"]).tag( + "vocabulary", "embedding" ), } From b3e31cac6c58e9ade6bbe7eba89fe4b49a3170f9 Mon Sep 17 00:00:00 2001 From: Xiangming Gu Date: Mon, 16 Jun 2025 12:55:05 +0000 Subject: [PATCH 15/15] fix pylint --- penzai/models/transformer/variants/gemma.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/penzai/models/transformer/variants/gemma.py b/penzai/models/transformer/variants/gemma.py index 0b57a32..f5023ce 100644 --- a/penzai/models/transformer/variants/gemma.py +++ b/penzai/models/transformer/variants/gemma.py @@ -269,13 +269,10 @@ def gemma_from_pretrained_checkpoint( num_layers = 0 while f"layer_{num_layers}/mlp/linear" in params: num_layers += 1 - if ( + qk_norm = ( "layer_0/attn/_query_norm" in params and "layer_0/attn/_key_norm" in params - ): - qk_norm = True - else: - qk_norm = False + ) is_match = False for gemma_preset_name, kwargs in _GEMMA_PRESETS.items(): if kwargs["num_decoder_blocks"] == num_layers: