Add ShardedRMSNorm for Q-K normalization under tensor parallelism #47

Open
sdeeptan-aws wants to merge 1 commit into aws-neuron:main from sdeeptan-aws:olmo2

Conversation

@sdeeptan-aws
Contributor

Description

Updated OLMo-2-1124-7B contrib model with ShardedRMSNorm for Q-K normalization under tensor parallelism, post-layer normalization architecture (RMSNorm after attention/MLP, not before), and correct Q-K norm placement before head reshape. The critical fix was computing RMSNorm variance over the full hidden dimension (4096) rather than the sharded dimension (512 with TP=8) — a naive TP implementation uses an 8x smaller denominator, causing Q/K values to differ by up to 1.64. Achieves 100% token match with TP=8.
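The denominator mismatch described above can be demonstrated with plain PyTorch. This is a hedged illustration, not the PR's modeling code: the dimensions match the PR's TP=8 configuration, but sharding is simulated on a single device rather than with Neuron collectives.

```python
import torch

torch.manual_seed(0)
hidden, tp = 4096, 8
shard = hidden // tp  # 512 elements per rank
x = torch.randn(hidden)

# Correct: variance computed over the full hidden dimension (4096)
full_rms = x / torch.sqrt(x.pow(2).mean() + 1e-6)

# Naive TP: each rank normalizes its own 512-element shard independently,
# so the variance denominator is 512 instead of 4096
shards = x.view(tp, shard)
naive = torch.cat([s / torch.sqrt(s.pow(2).mean() + 1e-6) for s in shards])

# Nonzero: per-shard variances disagree with the global variance
print((full_rms - naive).abs().max())
```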

Model Information

Model Name: OLMo-2-1124-7B
Model Architecture: Decoder-only transformer (~7B params, post-layer RMSNorm, Q-K normalization, RoPE theta=500000)
Purpose: Text generation

Checklist

Required Components

  • Accuracy Test (test/integration/test_model.py)
    • Token match accuracy validation
    • Performance metrics (TTFT, throughput)
    • Test can compile and run the model on Neuron
  • README.md with the following sections:
    • Usage Example: Clear code example showing how to use the model
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types
    • Example Checkpoints: Links to compatible model checkpoints
    • Testing Instructions: Command to run the test suite for the model
  • Source Code (src/)
    • Modeling code following NxD Inference patterns

Optional Components

  • Unit Tests (CPU or Neuron-based)

Folder Structure

/contrib/models/OLMo-2-1124-7B/
  README.md
  /src
    modeling_olmo2.py
  /test
    /integration
      test_model.py

Testing

Model was compiled and tested with TP=8, batch_size=1, seq_len=128, bfloat16. Two key architectural features validated:

  1. ShardedRMSNorm for Q-K normalization: OLMo-2 applies RMSNorm to Q and K projections before reshaping to heads. With TP=8, each rank sees 512 elements but variance must be computed over all 4096. ShardedRMSNorm computes local sum of squares, all-reduces across TP ranks via reduce_from_tensor_model_parallel_region, then divides by the full dimension size. Without this, Q/K values differ by up to 1.64 and accuracy drops to 0%.
  2. Post-layer normalization: RMSNorm applied after attention and MLP outputs (not before like LLaMA). This is an OLMo-2-specific architectural choice that affects the residual connection pattern.
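Point 1 above can be sketched end-to-end. This is a single-device simulation under stated assumptions: each entry of `x_shards` stands in for one TP rank's slice, and the all-reduce is modeled as a plain sum over ranks where the real code would call `reduce_from_tensor_model_parallel_region` on the local sum of squares.

```python
import torch

def sharded_rmsnorm(x_shards, full_dim, eps=1e-6):
    """Simulated ShardedRMSNorm over a list of per-rank shards."""
    # 1. Each rank computes its local sum of squares over its shard
    local_sq = [s.pow(2).sum(dim=-1, keepdim=True) for s in x_shards]
    # 2. All-reduce across TP ranks (simulated here as a stack-and-sum;
    #    on Neuron: reduce_from_tensor_model_parallel_region)
    global_sq = torch.stack(local_sq).sum(dim=0)
    # 3. Divide by the FULL dimension (4096), not the shard size (512)
    inv_rms = torch.rsqrt(global_sq / full_dim + eps)
    return [s * inv_rms for s in x_shards]

torch.manual_seed(0)
hidden, tp = 4096, 8
x = torch.randn(2, hidden)  # [batch, hidden]
shards = list(x.chunk(tp, dim=-1))

out = torch.cat(sharded_rmsnorm(shards, hidden), dim=-1)
ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
# The sharded result matches the unsharded reference norm
print(torch.allclose(out, ref, atol=1e-5))
```

Omitting step 3's full-dimension divisor (i.e., dividing each shard's sum of squares by 512) reproduces the accuracy failure the PR describes.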

Test Results:

| Test | Status | Result |
| --- | --- | --- |
| Smoke Test | ✅ PASS | Model loads successfully |
| Token Matching | ✅ PASS | 100% match |
| TTFT (P50) | ✅ PASS | ~55ms (threshold: 100ms) |
| Throughput | ✅ PASS | ~18 tok/s (threshold: 10 tok/s) |

Compatibility

Tested with:

  • Instance Type(s): Trn1
  • Configuration: TP=8, batch_size=1, seq_len=128, bfloat16

Additional Information

  • TP changes reduction semantics: Operations like mean/variance that reduce over sharded dimensions need special handling. Naive RMSNorm computes sum(x²) / 512 per rank instead of sum(x²) / 4096 globally — an 8x error in variance.
  • Use reduce_from_tensor_model_parallel_region: This is XLA-compatible, unlike raw torch.distributed.all_reduce. Required for correct cross-rank variance computation.
  • Post-layer norm ≠ pre-layer norm: OLMo-2 applies RMSNorm after the attention/MLP block output, before adding to the residual. LLaMA applies it before the block input.
  • Q-K norm before head reshape: RMSNorm is applied to the full Q/K projection output (shape [batch, seq, hidden_size]), not per-head. This is what makes TP sharding problematic.
  • Test with TP > 1 early: TP=1 tests pass perfectly and miss sharding-related bugs entirely.
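The pre-norm vs. post-norm distinction can be contrasted in a small sketch. The `block` here is a hypothetical stand-in (a plain linear layer) for the attention/MLP sublayer, not the PR's actual modeling code:

```python
import torch
import torch.nn as nn

def rmsnorm(x, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def pre_norm_step(x, block):
    # LLaMA style: normalize the block INPUT
    return x + block(rmsnorm(x))

def post_norm_step(x, block):
    # OLMo-2 style: normalize the block OUTPUT, then add the residual
    return x + rmsnorm(block(x))

torch.manual_seed(0)
x = torch.randn(1, 4, 64)
mlp = nn.Linear(64, 64)  # hypothetical sublayer
with torch.no_grad():
    a = pre_norm_step(x, mlp)
    b = post_norm_step(x, mlp)
# The two orderings are not equivalent, so reusing a pre-norm residual
# pattern for OLMo-2 silently produces wrong activations
print(torch.allclose(a, b))
```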

Related Issues

N/A

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included
