
Add LongRoPE and fix state dict conversion for Phi-3.5-mini-instruct#44

Open
sdeeptan-aws wants to merge 1 commit into aws-neuron:main from sdeeptan-aws:phi3.5

Conversation

@sdeeptan-aws
Contributor

Description

Updates the Phi-3.5-mini-instruct contrib model with a LongRoPE implementation (learned per-dimension scaling factors enabling 128k context), splits the fused QKV and gate_up projections during state dict conversion, and corrects the o_proj mapping to avoid double-prefixing by preshard_hook. The model uses Phi3LongRoPEScaledRotaryEmbedding with separate short_factor (sequences ≤ 4096) and long_factor (longer sequences) arrays instead of standard RoPE. Initial accuracy was 17.19% with repetition loops; after these fixes, the model achieves a 100% token match.

Model Information

Model Name: Phi-3.5-mini-instruct
Model Architecture: Decoder-only transformer (Phi-3 with LongRoPE, MHA, 32 heads, 32 layers, hidden_size=3072)
Purpose: Text generation / instruction following

Checklist

Required Components

  • Accuracy Test (test/integration/test_model.py)
    • Multi-prompt integration test validating token match accuracy
    • Test can compile and run the model on Neuron
  • README.md with the following sections:
    • Usage Example: Clear code example showing how to use the model
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types
    • Example Checkpoints: Links to compatible model checkpoints
    • Testing Instructions: Command to run the test suite for the model
  • Source Code (src/)
    • Modeling code following NxD Inference patterns

Optional Components

  • Unit Tests (CPU or Neuron-based)

Folder Structure

/contrib/models/Phi-3.5-mini-instruct/
  README.md
  /src
    modeling_phi3.py
  /test
    /integration
      test_model.py

Testing

The model was compiled and tested with TP=2, batch_size=1, seq_len=128, bfloat16. Three key fixes were validated:

  1. LongRoPE scaling: Implemented Phi3LongRoPEScaledRotaryEmbedding with learned short_factor and long_factor arrays. Standard RoPE uses fixed inverse frequencies; LongRoPE rescales each inverse frequency by a learned per-dimension factor (selected by sequence length) and applies a sqrt(1 + log(scale)/log(orig_max_pos)) correction to cos/sin values; see the sketch after this list.
  2. Fused weight splitting: qkv_proj.weight → split into qkv_proj.q_proj, qkv_proj.k_proj, qkv_proj.v_proj. gate_up_proj.weight → split into gate_proj and up_proj.
  3. o_proj double-prefix fix: Initially added o_proj.o_proj in conversion, but GroupQueryAttention_O.preshard_hook also adds the prefix, resulting in o_proj.o_proj.o_proj (broken). Fix: leave o_proj as-is and let preshard_hook handle it.
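
A minimal sketch of fix 1, assuming HF-style config fields (rope_scaling.short_factor / rope_scaling.long_factor, original_max_position_embeddings=4096, max_position_embeddings=131072). The class below is illustrative, not the exact contrib code:

```python
import math
import torch

class Phi3LongRoPEScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, short_factor, long_factor,
                 original_max_position_embeddings=4096,
                 max_position_embeddings=131072, base=10000.0):
        super().__init__()
        self.dim = dim
        self.base = base
        self.original_max_position_embeddings = original_max_position_embeddings
        # Learned per-dimension factors from config.json (48 values for head_dim=96).
        self.short_factor = torch.tensor(short_factor, dtype=torch.float32)
        self.long_factor = torch.tensor(long_factor, dtype=torch.float32)
        # Amplitude correction: sqrt(1 + log(scale) / log(orig_max_pos)).
        scale = max_position_embeddings / original_max_position_embeddings
        self.scaling_factor = math.sqrt(
            1.0 + math.log(scale) / math.log(original_max_position_embeddings))

    def forward(self, position_ids):
        seq_len = int(position_ids.max()) + 1
        # short_factor covers seq <= 4096; long_factor covers extended context.
        factor = (self.short_factor
                  if seq_len <= self.original_max_position_embeddings
                  else self.long_factor)
        # Standard RoPE inverse frequencies, rescaled per dimension.
        inv_freq = 1.0 / (factor * self.base ** (
            torch.arange(0, self.dim, 2, dtype=torch.float32) / self.dim))
        freqs = torch.outer(position_ids.flatten().float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos() * self.scaling_factor, emb.sin() * self.scaling_factor
```

For Phi-3.5-mini's values, scale = 131072 / 4096 = 32, so the correction works out to sqrt(1 + ln 32 / ln 4096) ≈ 1.19, a small uniform boost to the cos/sin amplitudes at extended context.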

Test Results:

Test           | Status  | Result
-------------- | ------- | --------------------------------------
Smoke Test     | ✅ PASS | Model compiles and loads successfully
Token Matching | ✅ PASS | 100% match

Multi-Prompt Accuracy:

Prompt                     | Match Rate
-------------------------- | ----------
"The capital of France is" | 100%

Compatibility

Tested with:

  • Instance Type(s): Trn1
  • Configuration: TP=2, batch_size=1, seq_len=128, bfloat16

Additional Information

  • LongRoPE ≠ standard RoPE: Check rope_scaling.type in config. LongRoPE uses learned per-dimension scaling factors stored in config.json, not fixed interpolation.
  • Scaling factors are learned: short_factor (48 values for seq ≤ 4096) and long_factor (48 values for longer sequences) are trained parameters, not computed at runtime.
  • preshard_hook handles o_proj: GroupQueryAttention_O.preshard_hook expects o_proj.weight and maps it to o_proj.o_proj.weight. Adding o_proj.o_proj in your conversion causes triple-nesting.
  • Fused QKV: a single qkv_proj layer concatenates Q, K, and V; it must be split during state dict conversion with correct head-dimension slicing (see the conversion sketch after this list).
  • Fused gate_up: Single gate_up_proj layer concatenating gate and up projections — split at intermediate_size boundary.
  • MHA not GQA: 32 attention heads, 32 KV heads (full multi-head attention).
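
A minimal sketch of the conversion described above, assuming Phi-3.5-mini shapes (32 heads, head_dim=96, hidden_size=3072, intermediate_size=8192). The function name and key handling here are illustrative rather than the exact code in src/modeling_phi3.py:

```python
import torch

def convert_phi3_state_dict(state_dict, num_heads=32, head_dim=96,
                            intermediate_size=8192):
    """Illustrative splitter for Phi-3.5-mini fused weights (hypothetical helper)."""
    new_sd = {}
    q_size = num_heads * head_dim  # 3072; MHA, so K and V slices are the same size
    for key, tensor in state_dict.items():
        if key.endswith("qkv_proj.weight"):
            # Fused rows are [Q; K; V]: split into qkv_proj.{q,k,v}_proj.weight.
            q, k, v = torch.split(tensor, [q_size, q_size, q_size], dim=0)
            prefix = key[:-len("weight")]
            new_sd[prefix + "q_proj.weight"] = q
            new_sd[prefix + "k_proj.weight"] = k
            new_sd[prefix + "v_proj.weight"] = v
        elif key.endswith("gate_up_proj.weight"):
            # Fused rows are [gate; up]: split at the intermediate_size boundary.
            gate, up = torch.split(
                tensor, [intermediate_size, intermediate_size], dim=0)
            new_sd[key.replace("gate_up_proj", "gate_proj")] = gate
            new_sd[key.replace("gate_up_proj", "up_proj")] = up
        else:
            # Leave o_proj.weight alone: GroupQueryAttention_O.preshard_hook adds
            # the o_proj.o_proj nesting itself; renaming here double-prefixes it.
            new_sd[key] = tensor
    return new_sd
```

Splitting at load time keeps the checkpoint on disk unchanged while producing the per-projection layout the sharded modules expect.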

Related Issues

N/A

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included
