
Enable OnDeviceSamplingConfig for compiler accuracy fix#37

Open
sdeeptan-aws wants to merge 1 commit into aws-neuron:main from sdeeptan-aws:vaultgemma

Conversation

@sdeeptan-aws
Contributor

Description

Updated the VaultGemma-1B contrib model with a compiler accuracy fix, validated modeling code, and an updated README. The model initially had 0% accuracy because Neuron compiler optimizations caused numerical divergence. The pure PyTorch implementation matched HuggingFace; the fix was enabling OnDeviceSamplingConfig, which changes the XLA graph structure and prevents the aggressive kernel fusions that destroyed numerical accuracy. Validation now achieves 100% token match.
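For reference, a minimal configuration sketch of how on-device sampling is typically enabled in NxD Inference. The import path and constructor arguments below are assumptions based on common NxDI patterns, not taken from this PR's code, and the snippet requires a Neuron instance to actually compile and run:

```python
# Hypothetical sketch -- import paths and argument names are assumptions
# based on typical NxD Inference usage, not verified against this PR.
from neuronx_distributed_inference.models.config import (
    NeuronConfig,
    OnDeviceSamplingConfig,
)

neuron_config = NeuronConfig(
    tp_degree=1,
    batch_size=1,
    seq_len=128,
    torch_dtype="bfloat16",
    # Enabling on-device sampling changes the XLA graph structure,
    # which avoids the kernel fusions that broke numerical accuracy.
    on_device_sampling_config=OnDeviceSamplingConfig(do_sample=False),
)
```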

Model Information

Model Name: VaultGemma-1B
Model Architecture: Decoder-only transformer (Gemma-style with (1+w) RMSNorm)
Purpose: Text generation

Checklist

Required Components

  • Accuracy Test (test/integration/test_model.py)
    • Validates model accuracy with logit comparison
    • Test can compile and run the model on Neuron
  • README.md with the following sections:
    • Usage Example: Clear code example showing how to use the model
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types
    • Example Checkpoints: Links to compatible model checkpoints
    • Testing Instructions: Command to run the test suite for the model
  • Source Code (src/)
    • Modeling code following NxD Inference patterns

Optional Components

  • Unit Tests (CPU or Neuron-based)

Folder Structure

Confirm your contribution follows this structure:

/contrib/models/vaultgemma-1b/
    README.md
    /src
        modeling_vaultgemma.py
    /test
        /integration
            test_model.py

Testing

The model was compiled and tested end-to-end. Key debugging finding: the pure PyTorch implementation matched HF with 0.99 correlation, but the compiled Neuron model diverged to 0.61 correlation. Enabling OnDeviceSamplingConfig resolved the compiler optimization issue.

Test Results:

| Test | Without ODS | With ODS |
|---|---|---|
| "The capital of France is" | ' in' (wrong) | ' Paris' (correct) |
| "The largest planet is" | ' Saturn' (wrong) | ' Jupiter' (correct) |
| Correlation | 0.61 | 1.0 |
| HF Match Rate | 0% | 100% |
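The correlation and match-rate metrics above can be computed with a straightforward comparison of reference (HF) and compiled-model outputs. A minimal stdlib-only sketch (function names are illustrative, not from the PR's test code):

```python
from math import sqrt

def pearson(a, b):
    """Pearson correlation between two equal-length logit sequences."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    va = sqrt(sum((x - ma) ** 2 for x in a))
    vb = sqrt(sum((y - mb) ** 2 for y in b))
    return cov / (va * vb)

def match_rate(ref_tokens, test_tokens):
    """Fraction of positions where greedy token choices agree."""
    hits = sum(r == t for r, t in zip(ref_tokens, test_tokens))
    return hits / len(ref_tokens)
```

A 1.0 correlation and 100% match rate (as in the "With ODS" column) means the compiled model's logits track the HF reference and every greedy token agrees.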

Compatibility

Tested with:

  • Instance Type(s): Trn1
  • Configuration: TP=1, batch_size=1, seq_len=128, bfloat16

Additional Information

  • Uses (1+w) RMSNorm pattern (Gemma-style): norm weights have mean near 0, applied as output = self._norm(x) * (1.0 + self.weight)
  • May have embedding normalization (sqrt(hidden_size))
  • OnDeviceSamplingConfig is required for correct compiled model accuracy — without it, XLA kernel fusions cause numerical divergence
  • Pure PyTorch implementation is correct; the issue is purely in the compilation path
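The (1+w) RMSNorm pattern noted above can be sketched in NumPy (illustrative only; the actual modeling code uses PyTorch and NxD Inference patterns):

```python
import numpy as np

def gemma_rmsnorm(x, weight, eps=1e-6):
    """Gemma-style RMSNorm: normalize, then scale by (1 + weight).

    Gemma checkpoints store norm weights centered near 0, so the
    effective per-channel scale (1 + weight) is near 1. Applying the
    weight directly (output = norm * weight), as in Llama-style
    RMSNorm, would be wrong for these checkpoints.
    """
    norm = x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return norm * (1.0 + weight)
```

With zero-initialized weights, the output is purely the RMS-normalized input, so the mean of its squared entries is approximately 1.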

Related Issues

N/A

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included
