diff --git a/docs/guides/howto_reference.md b/docs/guides/howto_reference.md index cddf53c..e2c1c21 100644 --- a/docs/guides/howto_reference.md +++ b/docs/guides/howto_reference.md @@ -217,9 +217,9 @@ You can read more about Penzai's conventions for layers in ["How to Think in Pen ## Loading Pretrained Models -### Loading Gemma or Gemma 2 +### Loading Gemma (1, 2, or 3) -Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2)) into the correct form. You can load it using: +Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2), [Gemma 3](https://www.kaggle.com/models/google/gemma-3)) into the correct form. You can load it using: ```python import kagglehub @@ -236,13 +236,20 @@ flax_params_dict = checkpointer.restore(ckpt_path) model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict) ``` -To load Gemma 2, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use: +To load Gemma 2/3, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use: ```python weights_dir = kagglehub.model_download('google/gemma-2/flax/gemma2-9b') ckpt_path = os.path.join(weights_dir, 'gemma2_9b_pt') ``` +To load the Gemma 3 4B model, you can use: + +```python +weights_dir = kagglehub.model_download('google/gemma-3/flax/gemma3-4b') +ckpt_path = os.path.join(weights_dir, 'gemma3_4b_pt') +``` + See the "Model Variations" section on the Kaggle model pages for details about the names and paths for each checkpoint. (You may also need to create a Kaggle account and request access to each model before you can download the checkpoints.) If you are using multiple accelerator devices (e.g. for a TPU v2 Colab kernel), you may want to shard the parameters over the devices while loading them. To do so, you can pass a sharding specification to `orbax.checkpoint`. For instance, to shard over the last axis of every parameter, you can use diff --git a/penzai/models/transformer/variants/gemma.py b/penzai/models/transformer/variants/gemma.py index 1b8289f..f5023ce 100644 --- a/penzai/models/transformer/variants/gemma.py +++ b/penzai/models/transformer/variants/gemma.py @@ -14,13 +14,14 @@ """The Gemma architecture transformer variant. -Supports both the Gemma 1 and Gemma 2 architectures. Based on the Flax -reference implementation at https://github.com/google-deepmind/gemma. +Supports all the Gemma 1, Gemma 2 and Gemma 3 architectures. Based on the +Flax reference implementation at https://github.com/google-deepmind/gemma. See the Gemma technical reports for more information: * Gemma 1: https://arxiv.org/abs/2403.08295 * Gemma 2: https://arxiv.org/abs/2408.00118 +* Gemma 3: https://arxiv.org/abs/2503.19786 """ from __future__ import annotations @@ -33,6 +34,20 @@ from penzai.models.transformer.variants import llamalike_common +def _make_attention_layers_types( + pattern: tuple[llamalike_common.AttentionType, ...], + *, + num_layers: int, +) -> tuple[llamalike_common.AttentionType, ...]: + """Returns the list of attention types for every layers.""" + + pattern_size = len(pattern) + out = pattern * (num_layers // pattern_size) + if num_layers % pattern_size != 0: + out += pattern[: num_layers % pattern_size] + return tuple(out) + + _GEMMA_PRESETS = { "gemma_2b": dict( num_decoder_blocks=18, @@ -105,6 +120,91 @@ final_logit_softcap=30.0, attn_logits_soft_cap=50.0, ), + "gemma3_1b": dict( + num_decoder_blocks=26, + vocab_size=262_144, + num_kv_heads=1, + query_head_multiplier=4, + embedding_dim=1152, + projection_dim=256, + mlp_hidden_dim=6 * 1152, + attention_type=_make_attention_layers_types( + pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(512),) + * 5 + + (llamalike_common.AttentionTypeGlobalCausal(),), + num_layers=26, + ), + use_qk_norm=True, + use_post_attn_norm=True, + use_post_ffw_norm=True, + rope_wavelength=1_000_000, + local_rope_wavelength=10_000, + ), + "gemma3_4b": dict( + num_decoder_blocks=34, + vocab_size=262_144, + num_kv_heads=4, + query_head_multiplier=2, + embedding_dim=2560, + projection_dim=256, + mlp_hidden_dim=2560 * 8 // 2, + attention_type=_make_attention_layers_types( + pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) + * 5 + + (llamalike_common.AttentionTypeGlobalCausal(),), + num_layers=34, + ), + use_qk_norm=True, + use_post_attn_norm=True, + use_post_ffw_norm=True, + global_scale_factor=8.0, + rope_wavelength=1_000_000, + local_rope_wavelength=10_000, + ), + "gemma3_12b": dict( + num_decoder_blocks=48, + vocab_size=262_144, + num_kv_heads=8, + query_head_multiplier=2, + embedding_dim=30 * 128, + projection_dim=256, + mlp_hidden_dim=8 * 30 * 128 // 2, + attention_type=_make_attention_layers_types( + pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) + * 5 + + (llamalike_common.AttentionTypeGlobalCausal(),), + num_layers=48, + ), + use_qk_norm=True, + use_post_attn_norm=True, + use_post_ffw_norm=True, + global_scale_factor=8.0, + rope_wavelength=1_000_000, + local_rope_wavelength=10_000, + ), + "gemma3_27b": dict( + num_decoder_blocks=62, + vocab_size=262_144, + num_kv_heads=16, + query_head_multiplier=2, + embedding_dim=5376, + projection_dim=128, + mlp_hidden_dim=5376 * 8 // 2, + # query scaling factor: 1/sqrt(embedding_dim / num_query_heads) + query_scaling_factor=(5376 // 32) ** -0.5, + attention_type=_make_attention_layers_types( + pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),) + * 5 + + (llamalike_common.AttentionTypeGlobalCausal(),), + num_layers=62, + ), + use_qk_norm=True, + use_post_attn_norm=True, + use_post_ffw_norm=True, + global_scale_factor=8.0, + rope_wavelength=1_000_000, + local_rope_wavelength=10_000, + ), } _NEEDS_GATING_TRANSPOSE = { "gemma_2b": False, @@ -112,6 +212,10 @@ "gemma2_2b": False, "gemma2_9b": True, "gemma2_27b": True, + "gemma3_1b": True, + "gemma3_4b": True, + "gemma3_12b": True, + "gemma3_27b": True, } @@ -120,7 +224,16 @@ def gemma_from_pretrained_checkpoint( upcast_activations_to_float32: bool = False, use_layer_stack: bool = False, preset_name: Literal[ - "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", "auto" + "gemma_2b", + "gemma_7b", + "gemma2_2b", + "gemma2_9b", + "gemma2_27b", + "gemma3_1b", + "gemma3_4b", + "gemma3_12b", + "gemma3_27b", + "auto", ] = "auto", ) -> model_parts.TransformerLM: """Builds a Gemma model from a pretrained checkpoint. @@ -144,7 +257,8 @@ def gemma_from_pretrained_checkpoint( without consuming additional memory for parameters. use_layer_stack: Whether to use a layer stack for the decoder blocks. preset_name: Preset name, used to determine model config. If "auto", uses - the number of layers in the checkpoint to determine the configuration. + the number of layers and whether the model needs qk norm in the checkpoint + to determine the configuration. Returns: A Transformer model containing the loaded parameters. @@ -155,15 +269,27 @@ def gemma_from_pretrained_checkpoint( num_layers = 0 while f"layer_{num_layers}/mlp/linear" in params: num_layers += 1 - preset_by_num_layers = { - kwargs["num_decoder_blocks"]: preset_name - for preset_name, kwargs in _GEMMA_PRESETS.items() - } - if num_layers not in preset_by_num_layers: + qk_norm = ( + "layer_0/attn/_query_norm" in params + and "layer_0/attn/_key_norm" in params + ) + is_match = False + for gemma_preset_name, kwargs in _GEMMA_PRESETS.items(): + if kwargs["num_decoder_blocks"] == num_layers: + if qk_norm and "use_qk_norm" in kwargs: + if kwargs["use_qk_norm"]: + is_match = True + preset_name = gemma_preset_name + break + if (not qk_norm) and ("use_qk_norm" not in kwargs): + is_match = True + preset_name = gemma_preset_name + break + if not is_match: raise ValueError( - f"Could not determine preset for model with {num_layers} layers." + f"Could not determine preset for model with {num_layers} layers and" + f" qk norm {qk_norm}." ) - preset_name = preset_by_num_layers[num_layers] preset_kwargs = _GEMMA_PRESETS[preset_name] preset_needs_gating_transpose = _NEEDS_GATING_TRANSPOSE[preset_name] @@ -179,7 +305,6 @@ def gemma_from_pretrained_checkpoint( **preset_kwargs, parameter_dtype=parameter_dtype, mlp_variant="geglu_approx", - rope_wavelength=10_000, tie_embedder_and_logits=True, activation_dtype=activation_dtype, use_layer_stack=use_layer_stack, @@ -207,6 +332,19 @@ def gemma_from_pretrained_checkpoint( 1 + params[f"layer_{i}/pre_attention_norm"]["scale"] ).tag("embedding") ) + # Add qk norm if needed + if config.use_qk_norm: + cur_block_params["attention/query_norm/scale.weights"] = ( + pz.nx.NamedArray.wrap( + 1 + params[f"layer_{i}/attn/_query_norm"]["scale"] + ).tag("projection") + ) + cur_block_params["attention/key_norm/scale.weights"] = ( + pz.nx.NamedArray.wrap( + 1 + params[f"layer_{i}/attn/_key_norm"]["scale"] + ).tag("projection") + ) + if config.use_post_attn_norm: cur_block_params["post_attention_norm/scale.weights"] = ( pz.nx.NamedArray.wrap( diff --git a/penzai/models/transformer/variants/llamalike_common.py b/penzai/models/transformer/variants/llamalike_common.py index 1307b23..b1247bc 100644 --- a/penzai/models/transformer/variants/llamalike_common.py +++ b/penzai/models/transformer/variants/llamalike_common.py @@ -84,7 +84,8 @@ class LlamalikeTransformerConfig: tie_embedder_and_logits: Whether to tie the weights of the input token embedding and output logit layers. If True, also scales down input token embeddings by sqrt(embedding_dim). (This is used by Gemma.) - rope_wavelength: Wavelength for RoPE layers. + rope_wavelength: Wavelength for global RoPE layers (and for local RoPE + layers if local_rope_wavelength is not set). rms_norm_eps: Epsilon for RMSNorm layers. attention_type: A single attention type or sequence of per-layer attention types. If a sequence, its length should evenly divide the number of @@ -102,6 +103,12 @@ class LlamalikeTransformerConfig: parameter_dtype: Floating dtype to use for all parameters. activation_dtype: Floating dtype to use for activations and KV cache tables. use_layer_stack: Whether to stack the blocks together using a LayerStack. + use_qk_norm: Whether to use QK normalization. + global_scale_factor: Scale factor for the global RoPE layers (scale factor + for the local RoPE layers is set as 1.0 by default). + local_rope_wavelength: Wavelength for the local RoPE layers. If None, local + RoPE layers will use the same wavelength as global RoPE layers + (config.rope_wavelength). """ num_kv_heads: int @@ -126,6 +133,9 @@ class LlamalikeTransformerConfig: parameter_dtype: jax.typing.DTypeLike = jnp.float32 activation_dtype: jax.typing.DTypeLike = jnp.float32 use_layer_stack: bool = False + use_qk_norm: bool = False + global_scale_factor: float | None = None + local_rope_wavelength: float | None = None def build_llamalike_feedforward( @@ -261,10 +271,22 @@ def build_llamalike_attention( sliding_window_size=attention_type.window_size, masked_out_value=masked_out_value, ) + # Decide which wavelength to use for local RoPE. + if config.local_rope_wavelength is not None: + wavelength = config.local_rope_wavelength + else: + wavelength = config.rope_wavelength + scale_factor = 1.0 elif isinstance(attention_type, AttentionTypeGlobalCausal): attn_masker = pz.nn.ApplyCausalAttentionMask( masked_out_value=masked_out_value, ) + wavelength = config.rope_wavelength + # Decide which scale factor to use for global RoPE. + if config.global_scale_factor is not None: + scale_factor = config.global_scale_factor + else: + scale_factor = 1.0 else: raise ValueError(f"Unsupported attention type {attention_type}") @@ -290,42 +312,74 @@ def build_llamalike_attention( pz.nn.Softmax("kv_seq"), ]) + # add qk norm if needed in the module of input_to_query sublayers + input_to_query_sublayers = [ + pz.nn.Linear.from_config( + name=f"{name}/query", + init_base_rng=init_base_rng, + input_axes={"embedding": embedding_dim}, + output_axes={ + **common_head_axes, + **query_only_head_axes, + "projection": projection_dim, + }, + dtype=config.parameter_dtype, + ), + ] + if config.use_qk_norm: + input_to_query_sublayers.append( + pz.nn.RMSLayerNorm.from_config( + name=f"{name}/query_norm", + init_base_rng=init_base_rng, + across_axes={"projection": config.projection_dim}, + dtype=config.parameter_dtype, + epsilon=config.rms_norm_eps, + ), + ) + input_to_query_sublayers.extend([ + pz.nn.ApplyRoPE( + positions_input_name="token_positions", + embedding_axis="projection", + max_wavelength=wavelength, + scale_factor=scale_factor, + ), + pz.nn.ConstantRescale( + by=jnp.array(query_scaling_factor, dtype=config.activation_dtype) + ), + ]) + + # add qk norm if needed in the module of input_to_key sublayers + input_to_key_sublayers = [ + pz.nn.Linear.from_config( + name=f"{name}/key", + init_base_rng=init_base_rng, + input_axes={"embedding": embedding_dim}, + output_axes={**common_head_axes, "projection": projection_dim}, + dtype=config.parameter_dtype, + ), + ] + if config.use_qk_norm: + input_to_key_sublayers.append( + pz.nn.RMSLayerNorm.from_config( + name=f"{name}/key_norm", + init_base_rng=init_base_rng, + across_axes={"projection": config.projection_dim}, + dtype=config.parameter_dtype, + epsilon=config.rms_norm_eps, + ), + ) + input_to_key_sublayers.append( + pz.nn.ApplyRoPE( + positions_input_name="token_positions", + embedding_axis="projection", + max_wavelength=wavelength, + scale_factor=scale_factor, + ), + ) + return pz.nn.Attention( - input_to_query=pz.nn.Sequential([ - pz.nn.Linear.from_config( - name=f"{name}/query", - init_base_rng=init_base_rng, - input_axes={"embedding": embedding_dim}, - output_axes={ - **common_head_axes, - **query_only_head_axes, - "projection": projection_dim, - }, - dtype=config.parameter_dtype, - ), - pz.nn.ApplyRoPE( - positions_input_name="token_positions", - embedding_axis="projection", - max_wavelength=config.rope_wavelength, - ), - pz.nn.ConstantRescale( - by=jnp.array(query_scaling_factor, dtype=config.activation_dtype) - ), - ]), - input_to_key=pz.nn.Sequential([ - pz.nn.Linear.from_config( - name=f"{name}/key", - init_base_rng=init_base_rng, - input_axes={"embedding": embedding_dim}, - output_axes={**common_head_axes, "projection": projection_dim}, - dtype=config.parameter_dtype, - ), - pz.nn.ApplyRoPE( - positions_input_name="token_positions", - embedding_axis="projection", - max_wavelength=config.rope_wavelength, - ), - ]), + input_to_query=pz.nn.Sequential(input_to_query_sublayers), + input_to_key=pz.nn.Sequential(input_to_key_sublayers), input_to_value=pz.nn.Sequential([ pz.nn.Linear.from_config( name=f"{name}/value", diff --git a/penzai/nn/embeddings.py b/penzai/nn/embeddings.py index 0905d20..0098511 100644 --- a/penzai/nn/embeddings.py +++ b/penzai/nn/embeddings.py @@ -225,11 +225,17 @@ class ApplyRoPE(layer_base.Layer): each token in the sequence. This side input should be provided as an integer array that is broadcastable with the input, and which does NOT include the embedding axis. + scale_factor: The scale factor to use for the positional embeddings (used by + Gemma3 models). """ embedding_axis: str = dataclasses.field(metadata={"pytree_node": False}) max_wavelength: float = dataclasses.field(metadata={"pytree_node": False}) positions_input_name: str = dataclasses.field(metadata={"pytree_node": False}) + scale_factor: float = dataclasses.field( + default=1.0, + metadata={"pytree_node": False}, + ) def _apply_1d(self, input_slice: jax.Array, position: jax.Array) -> jax.Array: """Apply RoPE to a one-dimensional JAX array.""" @@ -242,6 +248,9 @@ def _apply_1d(self, input_slice: jax.Array, position: jax.Array) -> jax.Array: # Since we're assuming `timescale` is a vector and `position` is a scalar, # we don't need any axis alignment. sinusoid_inp = position / timescale + if self.scale_factor < 1.0: + raise ValueError("scale_factor must be >= 1.0, got {scale_factor}") + sinusoid_inp = sinusoid_inp / self.scale_factor sin = jnp.sin(sinusoid_inp) cos = jnp.cos(sinusoid_inp) first_half, second_half = jnp.split(input_slice, 2) @@ -298,12 +307,18 @@ class ApplyRoPEToSubset(layer_base.Layer): each token in the sequence. This side input should be provided as an integer array that is broadcastable with the input, and which does NOT include the embedding axis. + scale_factor: The scale factor to use for the positional embeddings (used by + Gemma 3 models). """ embedding_axis: str = dataclasses.field(metadata={"pytree_node": False}) max_wavelength: float = dataclasses.field(metadata={"pytree_node": False}) rope_subset_size: int = dataclasses.field(metadata={"pytree_node": False}) positions_input_name: str = dataclasses.field(metadata={"pytree_node": False}) + scale_factor: float = dataclasses.field( + default=1.0, + metadata={"pytree_node": False}, + ) def __call__( self, inputs: named_axes.NamedArray, **side_inputs @@ -319,6 +334,7 @@ def __call__( embedding_axis=self.embedding_axis, max_wavelength=self.max_wavelength, positions_input_name=self.positions_input_name, + scale_factor=self.scale_factor, ) rotated_result = rotator(rotary_input, **side_inputs) return named_axes.concatenate(