13 changes: 10 additions & 3 deletions docs/guides/howto_reference.md
@@ -217,9 +217,9 @@ You can read more about Penzai's conventions for layers in ["How to Think in Pen

## Loading Pretrained Models

### Loading Gemma or Gemma 2
### Loading Gemma (1, 2, or 3)

Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2)) into the correct form. You can load it using:
Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2), [Gemma 3](https://www.kaggle.com/models/google/gemma-3)) into the structure Penzai expects. You can load a model using:

```python
import kagglehub
@@ -236,13 +236,20 @@ flax_params_dict = checkpointer.restore(ckpt_path)
model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
```

To load Gemma 2, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use:
To load Gemma 2 or Gemma 3, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use:

```python
weights_dir = kagglehub.model_download('google/gemma-2/flax/gemma2-9b')
ckpt_path = os.path.join(weights_dir, 'gemma2_9b_pt')
```

To load the Gemma 3 4B model, you can use:

```python
weights_dir = kagglehub.model_download('google/gemma-3/flax/gemma3-4b')
ckpt_path = os.path.join(weights_dir, 'gemma3_4b_pt')
```

See the "Model Variations" section on the Kaggle model pages for details about the names and paths for each checkpoint. (You may also need to create a Kaggle account and request access to each model before you can download the checkpoints.)

If you are using multiple accelerator devices (e.g. for a TPU v2 Colab kernel), you may want to shard the parameters over the devices while loading them. To do so, you can pass a sharding specification to `orbax.checkpoint`. For instance, to shard over the last axis of every parameter, you can use
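a per-parameter `restore_args` tree that tells Orbax which sharding to apply to each array. A minimal sketch of this approach (assuming recent `orbax.checkpoint` and `jax.sharding` APIs such as `PyTreeCheckpointer.metadata`, `ArrayRestoreArgs`, and `NamedSharding`, plus a hypothetical one-axis device mesh named `"devices"`):

```python
import jax
import orbax.checkpoint

checkpointer = orbax.checkpoint.PyTreeCheckpointer()

# Hypothetical one-dimensional mesh over all available devices.
mesh = jax.sharding.Mesh(jax.devices(), axis_names=("devices",))

def shard_last_axis(array_metadata):
  # Shard the trailing axis across the mesh; replicate every other axis.
  ndim = len(array_metadata.shape)
  if ndim == 0:
    spec = jax.sharding.PartitionSpec()
  else:
    spec = jax.sharding.PartitionSpec(*([None] * (ndim - 1) + ["devices"]))
  return orbax.checkpoint.ArrayRestoreArgs(
      sharding=jax.sharding.NamedSharding(mesh, spec)
  )

restore_args = jax.tree_util.tree_map(
    shard_last_axis, checkpointer.metadata(ckpt_path)
)
flax_params_dict = checkpointer.restore(ckpt_path, restore_args=restore_args)
model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
```

This splits each parameter's last axis across the devices and replicates the remaining axes; parameters whose last axis does not divide evenly across the device count would need a different `PartitionSpec`.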
162 changes: 150 additions & 12 deletions penzai/models/transformer/variants/gemma.py
@@ -14,13 +14,14 @@

"""The Gemma architecture transformer variant.

Supports both the Gemma 1 and Gemma 2 architectures. Based on the Flax
reference implementation at https://github.com/google-deepmind/gemma.
Supports the Gemma 1, Gemma 2, and Gemma 3 architectures. Based on the Flax
reference implementation at https://github.com/google-deepmind/gemma.

See the Gemma technical reports for more information:

* Gemma 1: https://arxiv.org/abs/2403.08295
* Gemma 2: https://arxiv.org/abs/2408.00118
* Gemma 3: https://arxiv.org/abs/2503.19786
"""

from __future__ import annotations
@@ -33,6 +34,20 @@
from penzai.models.transformer.variants import llamalike_common


def _make_attention_layers_types(
pattern: tuple[llamalike_common.AttentionType, ...],
*,
num_layers: int,
) -> tuple[llamalike_common.AttentionType, ...]:
"""Returns the list of attention types for every layers."""

pattern_size = len(pattern)
out = pattern * (num_layers // pattern_size)
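# Finish with a truncated copy of the pattern when num_layers is not an
# exact multiple of the pattern length.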
if num_layers % pattern_size != 0:
out += pattern[: num_layers % pattern_size]
return tuple(out)


_GEMMA_PRESETS = {
"gemma_2b": dict(
num_decoder_blocks=18,
@@ -105,13 +120,102 @@
final_logit_softcap=30.0,
attn_logits_soft_cap=50.0,
),
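# Gemma 3 presets. Each interleaves five local sliding-window attention
# layers with one global attention layer, normalizes queries and keys
# (use_qk_norm), and uses a 1M-wavelength RoPE for the global layers with a
# 10k-wavelength RoPE for the local layers.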
"gemma3_1b": dict(
num_decoder_blocks=26,
vocab_size=262_144,
num_kv_heads=1,
query_head_multiplier=4,
embedding_dim=1152,
projection_dim=256,
mlp_hidden_dim=6 * 1152,
attention_type=_make_attention_layers_types(
pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(512),)
* 5
+ (llamalike_common.AttentionTypeGlobalCausal(),),
num_layers=26,
),
use_qk_norm=True,
use_post_attn_norm=True,
use_post_ffw_norm=True,
rope_wavelength=1_000_000,
local_rope_wavelength=10_000,
),
"gemma3_4b": dict(
num_decoder_blocks=34,
vocab_size=262_144,
num_kv_heads=4,
query_head_multiplier=2,
embedding_dim=2560,
projection_dim=256,
mlp_hidden_dim=2560 * 8 // 2,
attention_type=_make_attention_layers_types(
pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),)
* 5
+ (llamalike_common.AttentionTypeGlobalCausal(),),
num_layers=34,
),
use_qk_norm=True,
use_post_attn_norm=True,
use_post_ffw_norm=True,
global_scale_factor=8.0,
rope_wavelength=1_000_000,
local_rope_wavelength=10_000,
),
"gemma3_12b": dict(
num_decoder_blocks=48,
vocab_size=262_144,
num_kv_heads=8,
query_head_multiplier=2,
embedding_dim=30 * 128,
projection_dim=256,
mlp_hidden_dim=8 * 30 * 128 // 2,
attention_type=_make_attention_layers_types(
pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),)
* 5
+ (llamalike_common.AttentionTypeGlobalCausal(),),
num_layers=48,
),
use_qk_norm=True,
use_post_attn_norm=True,
use_post_ffw_norm=True,
global_scale_factor=8.0,
rope_wavelength=1_000_000,
local_rope_wavelength=10_000,
),
"gemma3_27b": dict(
num_decoder_blocks=62,
vocab_size=262_144,
num_kv_heads=16,
query_head_multiplier=2,
embedding_dim=5376,
projection_dim=128,
mlp_hidden_dim=5376 * 8 // 2,
# query scaling factor: 1/sqrt(embedding_dim / num_query_heads)
query_scaling_factor=(5376 // 32) ** -0.5,
attention_type=_make_attention_layers_types(
pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),)
* 5
+ (llamalike_common.AttentionTypeGlobalCausal(),),
num_layers=62,
),
use_qk_norm=True,
use_post_attn_norm=True,
use_post_ffw_norm=True,
global_scale_factor=8.0,
rope_wavelength=1_000_000,
local_rope_wavelength=10_000,
),
}
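# Whether each preset's Flax checkpoint stores the MLP gating weights
# transposed relative to the layout penzai expects; Gemma 2 9B/27B and all
# Gemma 3 checkpoints need the transpose.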
_NEEDS_GATING_TRANSPOSE = {
"gemma_2b": False,
"gemma_7b": False,
"gemma2_2b": False,
"gemma2_9b": True,
"gemma2_27b": True,
"gemma3_1b": True,
"gemma3_4b": True,
"gemma3_12b": True,
"gemma3_27b": True,
}


@@ -120,7 +224,16 @@ def gemma_from_pretrained_checkpoint(
upcast_activations_to_float32: bool = False,
use_layer_stack: bool = False,
preset_name: Literal[
"gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", "auto"
"gemma_2b",
"gemma_7b",
"gemma2_2b",
"gemma2_9b",
"gemma2_27b",
"gemma3_1b",
"gemma3_4b",
"gemma3_12b",
"gemma3_27b",
"auto",
] = "auto",
) -> model_parts.TransformerLM:
"""Builds a Gemma model from a pretrained checkpoint.
@@ -144,7 +257,8 @@ def gemma_from_pretrained_checkpoint(
without consuming additional memory for parameters.
use_layer_stack: Whether to use a layer stack for the decoder blocks.
preset_name: Preset name, used to determine model config. If "auto", uses
the number of layers in the checkpoint to determine the configuration.
the number of layers in the checkpoint, and whether it contains query/key
norm parameters, to determine the configuration.

Returns:
A Transformer model containing the loaded parameters.
@@ -155,15 +269,27 @@
num_layers = 0
while f"layer_{num_layers}/mlp/linear" in params:
num_layers += 1
preset_by_num_layers = {
kwargs["num_decoder_blocks"]: preset_name
for preset_name, kwargs in _GEMMA_PRESETS.items()
}
if num_layers not in preset_by_num_layers:
qk_norm = (
"layer_0/attn/_query_norm" in params
and "layer_0/attn/_key_norm" in params
)
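# Some presets share a layer count, so also check whether the checkpoint
# contains query/key norm scales (present only for the Gemma 3 presets) to
# pick the right preset.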
is_match = False
for gemma_preset_name, kwargs in _GEMMA_PRESETS.items():
if kwargs["num_decoder_blocks"] == num_layers:
if qk_norm and "use_qk_norm" in kwargs:
if kwargs["use_qk_norm"]:
is_match = True
preset_name = gemma_preset_name
break
if (not qk_norm) and ("use_qk_norm" not in kwargs):
is_match = True
preset_name = gemma_preset_name
break
if not is_match:
raise ValueError(
f"Could not determine preset for model with {num_layers} layers."
f"Could not determine preset for model with {num_layers} layers and"
f" qk norm {qk_norm}."
)
preset_name = preset_by_num_layers[num_layers]

preset_kwargs = _GEMMA_PRESETS[preset_name]
preset_needs_gating_transpose = _NEEDS_GATING_TRANSPOSE[preset_name]
@@ -179,7 +305,6 @@
**preset_kwargs,
parameter_dtype=parameter_dtype,
mlp_variant="geglu_approx",
rope_wavelength=10_000,
tie_embedder_and_logits=True,
activation_dtype=activation_dtype,
use_layer_stack=use_layer_stack,
@@ -207,6 +332,19 @@
1 + params[f"layer_{i}/pre_attention_norm"]["scale"]
).tag("embedding")
)
# Add query/key norm scales when the config uses QK norm (Gemma 3).
if config.use_qk_norm:
cur_block_params["attention/query_norm/scale.weights"] = (
pz.nx.NamedArray.wrap(
1 + params[f"layer_{i}/attn/_query_norm"]["scale"]
).tag("projection")
)
cur_block_params["attention/key_norm/scale.weights"] = (
pz.nx.NamedArray.wrap(
1 + params[f"layer_{i}/attn/_key_norm"]["scale"]
).tag("projection")
)

if config.use_post_attn_norm:
cur_block_params["post_attention_norm/scale.weights"] = (
pz.nx.NamedArray.wrap(