
Add SDPA decomposition for Metal backend with unsupported head_dim#221

Open
seyeong-han wants to merge 1 commit into huggingface:main from seyeong-han:gemma3-metal-sdpa

Conversation

@seyeong-han

The Metal SDPA kernel only supports head_dim in {64, 96, 128}. Models with other head dimensions, such as Gemma3 (head_dim=256), crash at runtime. This PR decomposes SDPA into matmul + softmax when the model's head_dim is unsupported, following the pattern from voxtral_realtime's Metal export.

Changes:

  • metal.py: Add _sdpa_decomposition and _linear_bias_decomposition, applied via run_decompositions() before lowering. Conditional on head_dim not in {64, 96, 128}. Force use_custom_sdpa=False for Metal.
  • integrations.py: Guard get_custom_sdpa_for_ring_kv_cache() and RemoveRedundantTransposes imports behind use_custom_sdpa check to avoid triggering torchao import chain when custom SDPA is not used.
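For reference, the decomposition described above is mathematically equivalent to SDPA and uses only ops (matmul, softmax) that work for any head_dim. A minimal sketch, assuming fp32 tensors in (batch, heads, seq, head_dim) layout; the function and argument names are illustrative, not the actual code in metal.py:

```python
import math
import torch

def sdpa_decomposed(q, k, v, attn_mask=None, scale=None):
    # scaled_dot_product_attention rewritten as matmul + softmax + matmul,
    # so no fused SDPA kernel (with its head_dim restrictions) is needed.
    d = q.size(-1)
    scale = scale if scale is not None else 1.0 / math.sqrt(d)
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # (B, H, S, S)
    if attn_mask is not None:
        scores = scores + attn_mask  # additive mask, same as SDPA semantics
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, v)  # (B, H, S, head_dim)
```

With head_dim=256 (the Gemma3 case) this produces the same result as `torch.nn.functional.scaled_dot_product_attention` up to floating-point tolerance, while lowering to ops the Metal backend already supports.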

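The integrations.py change follows a standard deferred-import pattern. A hedged sketch, where the function name is illustrative and the stdlib `statistics` module stands in for the real imports (`get_custom_sdpa_for_ring_kv_cache`, `RemoveRedundantTransposes`) that would pull in the torchao chain:

```python
def get_custom_sdpa_passes(use_custom_sdpa: bool):
    # Sketch of the guard: heavy imports are deferred until custom SDPA
    # is actually requested, so backends that disable it never pay the
    # import-time cost (or hit import errors) of the torchao chain.
    if not use_custom_sdpa:
        # Metal path (use_custom_sdpa=False): return early, nothing imported.
        return []
    # 'statistics' is a stand-in for the real guarded imports.
    import statistics
    return [statistics.mean]
```

Because Python executes imports at the point they appear, moving them inside the `if use_custom_sdpa:` branch means the Metal export path never evaluates them at all.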

This PR was authored with the assistance of Claude.

Made-with: Cursor
