
Add Q-K normalization and scaled embeddings for Gemma-3-1b-it #45

Open

sdeeptan-aws wants to merge 1 commit into aws-neuron:main from sdeeptan-aws:gemma3

Conversation

@sdeeptan-aws
Contributor

Description

Updated the Gemma-3-1b-it contrib model with Q-K normalization (RMSNorm applied to Q and K after reshaping to heads), embeddings scaled by sqrt(hidden_size), and 4 RMSNorm layers per decoder block (pre/post norms for both attention and MLP). The model produces correct, coherent outputs, but BF16 precision causes style divergence on open-ended prompts; code completion prompts such as "def fibonacci(n):" are deterministic and achieve a 100% token match.

Model Information

Model Name: Gemma-3-1b-it
Model Architecture: Decoder-only transformer (Q-K normalization, scaled embeddings, 4 RMSNorm per layer)
Purpose: Text generation / instruction following

Checklist

Required Components

  • Accuracy Test (test/integration/test_model.py)
    • Multi-prompt integration test validating token match accuracy
    • Uses code completion prompts for deterministic validation
    • Test can compile and run the model on Neuron
  • README.md with the following sections:
    • Usage Example: Clear code example showing how to use the model
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types
    • Example Checkpoints: Links to compatible model checkpoints
    • Testing Instructions: Command to run the test suite for the model
  • Source Code (src/)
    • Modeling code following NxD Inference patterns

Optional Components

  • Unit Tests (CPU or Neuron-based)

Folder Structure

/contrib/models/gemma-3-1b-it/
  README.md
  /src
    modeling_gemma_3_1b_it.py
  /test
    /integration
      test_model.py

Testing

Model was compiled and tested with TP=1, batch_size=1, seq_len=128, bfloat16. Three key architectural features validated:

  1. Q-K normalization: RMSNorm applied to Q and K projections after reshape to head dimensions, before RoPE. Without this, attention scores are unnormalized and accuracy degrades.
  2. Scaled embeddings: Embedding output multiplied by sqrt(hidden_size) (sqrt(1152) ≈ 33.94). Missing this scaling causes the model to produce incoherent output.
  3. 4 RMSNorm layers per decoder block: input_layernorm, post_attention_layernorm, pre_feedforward_layernorm, post_feedforward_layernorm — unlike standard LLaMA, which has 2 norms per layer (see the sketch after this list).
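
A minimal sketch of the layer structure these checks target, in plain PyTorch. The class names (RMSNorm, SketchAttention, SketchDecoderLayer), the simplified non-gated MLP, and the head-count defaults are illustrative assumptions, not the contrib module's actual code; RoPE and KV-head grouping are omitted for brevity.

  import torch
  import torch.nn as nn

  class RMSNorm(nn.Module):
      def __init__(self, dim, eps=1e-6):
          super().__init__()
          self.weight = nn.Parameter(torch.ones(dim))
          self.eps = eps

      def forward(self, x):
          var = x.float().pow(2).mean(-1, keepdim=True)
          return (x.float() * torch.rsqrt(var + self.eps)).type_as(x) * self.weight

  class SketchAttention(nn.Module):
      def __init__(self, hidden_size, num_heads, head_dim):
          super().__init__()
          self.num_heads, self.head_dim = num_heads, head_dim
          self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
          self.k_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
          self.v_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
          self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)
          # Feature 1: Q-K normalization -- per-head RMSNorm over head_dim
          self.q_norm = RMSNorm(head_dim)
          self.k_norm = RMSNorm(head_dim)

      def forward(self, x):
          b, s, _ = x.shape
          q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim)
          k = self.k_proj(x).view(b, s, self.num_heads, self.head_dim)
          v = self.v_proj(x).view(b, s, self.num_heads, self.head_dim)
          # Q-K norm is applied after the reshape to heads and before RoPE
          # (RoPE itself is omitted from this sketch).
          q, k = self.q_norm(q), self.k_norm(k)
          attn = nn.functional.scaled_dot_product_attention(
              q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
          return self.o_proj(attn.transpose(1, 2).reshape(b, s, -1))

  class SketchDecoderLayer(nn.Module):
      # num_heads/head_dim/ffn_dim defaults are illustrative; read them from the checkpoint config
      def __init__(self, hidden_size=1152, num_heads=4, head_dim=256, ffn_dim=6912):
          super().__init__()
          self.self_attn = SketchAttention(hidden_size, num_heads, head_dim)
          self.mlp = nn.Sequential(nn.Linear(hidden_size, ffn_dim, bias=False),
                                   nn.GELU(approximate="tanh"),
                                   nn.Linear(ffn_dim, hidden_size, bias=False))
          # Feature 3: four RMSNorm layers per decoder block
          self.input_layernorm = RMSNorm(hidden_size)
          self.post_attention_layernorm = RMSNorm(hidden_size)
          self.pre_feedforward_layernorm = RMSNorm(hidden_size)
          self.post_feedforward_layernorm = RMSNorm(hidden_size)

      def forward(self, x):
          # Post norms are applied to the block output before the residual add.
          h = x + self.post_attention_layernorm(self.self_attn(self.input_layernorm(x)))
          return h + self.post_feedforward_layernorm(self.mlp(self.pre_feedforward_layernorm(h)))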

Test Results:

  Test             Status    Result
  Smoke Test       ✅ PASS   Model loads successfully
  Token Matching   ✅ PASS   100% match (code completion prompt)

Multi-Prompt Accuracy:

  Prompt                        Match Rate   Notes
  "def fibonacci(n):"           100%         Code completion is deterministic
  "The capital of France is"    92.3%        Diverges after ~24 tokens
  "The largest planet..."       69.0%        Diverges after ~22 tokens
  "1+1="                        12.5%        Both produce "2" but diverge on explanation style

Lower-scoring prompts reflect BF16 style divergence — both HF and Neuron produce correct outputs but differ in explanation phrasing when logits are close.
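
As a point of reference, a minimal sketch of how a position-wise greedy-token match rate can be computed is below. The function name and the example IDs are illustrative and not taken from test/integration/test_model.py, which may measure accuracy differently.

  def token_match_rate(reference_ids, neuron_ids):
      # Compare greedy token IDs from a reference run (e.g. HF on CPU) against
      # the Neuron run, position by position, over the shorter of the two outputs.
      n = min(len(reference_ids), len(neuron_ids))
      if n == 0:
          return 0.0
      matches = sum(1 for r, t in zip(reference_ids[:n], neuron_ids[:n]) if r == t)
      return matches / n

  # A deterministic code-completion prompt is expected to score 1.0 (100% match).
  assert token_match_rate([1, 5, 9, 2], [1, 5, 9, 2]) == 1.0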

Compatibility

Tested with:

  • Instance Type(s): Trn1
  • Configuration: TP=1, batch_size=1, seq_len=128, bfloat16

Additional Information

  • Q-K normalization: Gemma-3 applies per-head RMSNorm to Q and K after projection and reshape, before RoPE application. This stabilizes attention scores across different head dimensions.
  • Embedding scaling: All Gemma models scale embeddings by sqrt(hidden_size). This is a Gemma-family convention, not standard in LLaMA-style models (a minimal sketch follows this list).
  • 4 norms per layer: Pre/post norms for both attention and MLP blocks. The post_attention_layernorm and post_feedforward_layernorm are applied to the block output before adding the residual.
  • Code prompts for validation: Code completion prompts produce deterministic outputs and are more reliable for accuracy validation than open-ended text prompts under BF16.
  • BF16 style divergence: Both outputs can be semantically correct but differ in style (e.g., "The answer is 2 because..." vs "Explanation: 1 + 1 = 2..."). This is expected BF16 behavior, not an implementation error.
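
A minimal sketch of the embedding scaling convention, assuming a plain nn.Embedding; the vocabulary size shown is an illustrative assumption and should be read from the checkpoint config.

  import math
  import torch.nn as nn

  hidden_size = 1152                           # Gemma-3-1b-it hidden size; sqrt(1152) ≈ 33.94
  embed = nn.Embedding(262144, hidden_size)    # vocab size is an assumption here

  def embed_tokens(input_ids):
      # Gemma-family convention: scale the embedding output by sqrt(hidden_size)
      return embed(input_ids) * math.sqrt(hidden_size)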

Related Issues

N/A

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included
